5.11如何用PyTorch实现ResNet34

news2024/12/14 5:41:00

ResNet34是由16个残差块和一个全局平局池化层和一个全连接层组成,即32个卷积层+1个pooling层+1和fc层。
训练的数据集是cifar10数据集,训练次数5,损失函数为CrossEntropyLoss(),optimizer = torch.optim.SGD。

1.先定义残差块,每个块中有两个3×3卷积

class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
            self.bn3 = nn.BatchNorm2d(out_channels)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.conv3:
            x = self.bn3(self.conv3(x))

        return F.relu(y + x)

上面代码中使用1×1卷积的作用是修改输入数据x的通道数,使得x与y的形状相同。
2.定义由上面小残差块组成的大块
下图是resnet原论文中34层网络图
在这里插入图片描述
这一部分目的是构建由相同颜色的resnet块组成的大块,每个大块中小块个数为3,4,6,3,代码如下

def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block == True:
        assert in_channels == out_channels
    blk = []
    for i in range(num_residuals):
        if i==0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)

3.定义网络
在这里插入图片描述

下面是数据进入残差结构的卷积、BN等操作,类似于VGG16。

# 定义网络
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

然后向网络net中添加resnet模块

net.add_module('resnet_block1', resnet_block(64, 64, 3, first_block=True))
net.add_module('resnet_block2', resnet_block(64, 128, 4))
net.add_module('resnet_block3', resnet_block(128, 256, 6))
net.add_module('resnet_block4', resnet_block(256, 512, 3))

下面是池化层和全连接层,池化层将7×7特征图转成1×1,全连接层将(N,C,H,W)的输入转成(N,CHW),在对这个数据应用仿射线性变换,如下

net.add_module('avg', GlobalAvgPool2d())
net.add_module('fc', nn.Sequential(FlattenLayer(), nn.Linear(512, 10)))

4.加载数据集
这部分对数据集的处理包括,原图根据短边在[256,480]之间缩放,随即旋转裁剪固定224大小,再将数据转成Tensor格式。
再用torchvision读取数据dataloader加载数据

# 加载数据集
def load_data_cifar10(batch_size, root='~/Datasets/CIFAR10'):
    image_transform = torchvision.transforms.Compose([
        # Step 1: 随机调整短边大小到 [256, 480]
        torchvision.transforms.RandomResizedCrop(224, scale=(256 / 480, 1.0), ratio=(0.75, 1.33)),
        
        # Step 2: 随机水平翻转
        torchvision.transforms.RandomHorizontalFlip(),
        
        # Step 3: 转换为 Tensor 格式
        torchvision.transforms.ToTensor(),
    ])
    cifar_train = torchvision.datasets.CIFAR10(root=root, train=True, transform=image_transform, download=True)
    cifar_test = torchvision.datasets.CIFAR10(root=root, train=False, transform=image_transform, download=True)
    if sys.platform.startswith('win'):
        num_workers = 0  # 0表示不用额外的进程来加速读取数据
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(cifar_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(cifar_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    return train_iter, test_iter

5.进行模型的训练与评估

def evaluate_accuracy(data_iter, net, device=None):
    if device is None and isinstance(net, torch.nn.Module):
        # 如果没指定device就使用net的device
        device = list(net.parameters())[0].device 
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval() # 评估模式, 这会关闭dropout
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train() # 改回训练模式
            else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPU
                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
            n += y.shape[0]
    return acc_sum / n

# 训练
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

训练结果
在这里插入图片描述

完整代码

import torch
from torch import nn
import torchvision
import sys
import torch.nn.functional as F
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义残差块,每个块中有2个3×3卷积
class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
            self.bn3 = nn.BatchNorm2d(out_channels)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.conv3:
            x = self.bn3(self.conv3(x))

        return F.relu(y + x)

# 定义网络
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block == True:
        assert in_channels == out_channels
    blk = []
    for i in range(num_residuals):
        if i==0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)

class GlobalAvgPool2d(nn.Module):
    def __init__(self) -> None:
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])
    
class FlattenLayer(nn.Module):
    def __init__(self) -> None:
        super(FlattenLayer, self).__init__()
    def forward(self, x):
        return x.view(x.shape[0], -1)

net.add_module('resnet_block1', resnet_block(64, 64, 3, first_block=True))
net.add_module('resnet_block2', resnet_block(64, 128, 4))
net.add_module('resnet_block3', resnet_block(128, 256, 6))
net.add_module('resnet_block4', resnet_block(256, 512, 3))

net.add_module('avg', GlobalAvgPool2d())
net.add_module('fc', nn.Sequential(FlattenLayer(), nn.Linear(512, 10)))
 

# 加载数据集
def load_data_cifar10(batch_size, root='~/Datasets/CIFAR10'):
    image_transform = torchvision.transforms.Compose([
        # Step 1: 随机调整短边大小到 [256, 480]
        torchvision.transforms.RandomResizedCrop(224, scale=(256 / 480, 1.0), ratio=(0.75, 1.33)),
        
        # Step 2: 随机水平翻转
        torchvision.transforms.RandomHorizontalFlip(),
        
        # Step 3: 转换为 Tensor 格式
        torchvision.transforms.ToTensor(),
    ])
    cifar_train = torchvision.datasets.CIFAR10(root=root, train=True, transform=image_transform, download=True)
    cifar_test = torchvision.datasets.CIFAR10(root=root, train=False, transform=image_transform, download=True)
    if sys.platform.startswith('win'):
        num_workers = 0  # 0表示不用额外的进程来加速读取数据
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(cifar_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(cifar_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    return train_iter, test_iter

def evaluate_accuracy(data_iter, net, device=None):
    if device is None and isinstance(net, torch.nn.Module):
        # 如果没指定device就使用net的device
        device = list(net.parameters())[0].device 
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval() # 评估模式, 这会关闭dropout
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train() # 改回训练模式
            else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPU
                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
            n += y.shape[0]
    return acc_sum / n

# 训练
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
        
batch_size = 64
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = load_data_cifar10(batch_size)

lr, num_epochs = 0.01, 5
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2259137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

yolov,coco,voc标记的睡岗检测数据集,可识别在桌子上趴着睡,埋头睡觉,座椅上靠着睡,平躺着睡等多种睡姿的检测,6549张图片

yolov,coco,voc标记的睡岗检测数据集,可识别在桌子上趴着睡,埋头睡觉,座椅上靠着睡,平躺着睡等多种睡姿的检测,6549张图片 数据集分割 6549总图像数 训练组91% 5949图片 有效集9&#x…

【C++游记】string的使用和模拟实现

枫の个人主页 你不能改变过去,但你可以改变未来 算法/C/数据结构/C Hello,这里是小枫。C语言与数据结构和算法初阶两个板块都更新完毕,我们继续来学习C的内容呀。C是接近底层有比较经典的语言,因此学习起来注定枯燥无味&#xf…

【深度学习量化交易7】miniQMT快速上手教程案例集——使用xtQuant进行历史数据下载篇

我是Mr.看海,我在尝试用信号处理的知识积累和思考方式做量化交易,应用深度学习和AI实现股票自动交易,目的是实现财务自由~ 目前我正在开发基于miniQMT的量化交易系统。 在前几篇的文章中讲到,我正在开发的看海量化交易系统&#x…

相差不超过k的最多数,最长公共子序列(一),排序子序列,体操队形,青蛙过河

相差不超过k的最多数 链接:相差不超过k的最多数 来源:牛客网 题目描述: 给定一个数组,选择一些数,要求选择的数中任意两数差的绝对值不超过 𝑘 。问最多能选择多少个数? 输入描述: 第一行输入两个正整…

解决navicat 导出excel数字为科学计数法问题

一、原因分析 用程序导出的csv文件,当字段中有比较长的数字字段存在时,在用excel软件查看csv文件时就会变成科学技术法的表现形式。 其实这个问题跟用什么语言导出csv文件没有关系。Excel显示数字时,如果数字大于12位,它会自动转化…

C++3--内联函数、auto

1.内联函数 1.1概念 以inline修饰的函数叫做内联函数,编译时C编译器会在调用内联函数的地方展开,没有函数调用建立栈帧的开销,内联函数提升程序的效率 如果在上述函数前增加inline关键字将其改成内联函数,在编译期间编译器会用函…

AES 与 SM4 加密算法:深度解析与对比

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,…

视频怎么转音频mp3?5种视频转音频的方法

在视频剪辑时,将视频中的音频提取出来并转换为MP3格式已成为许多人的需求。无论是为了制作音乐播放列表、剪辑音频片段,还是为了在其他设备上更方便地播放,将视频转换为音频MP3都显得尤为重要。下面将介绍五种实用的方法,帮助你轻…

Maven学习(传统Jar包管理、Maven依赖管理(导入坐标)、快速下载指定jar包)

目录 一、传统Jar包管理。 (1)基本介绍。 (2)传统的Jar包导入方法。 1、手动寻找Jar包。并放置到指定目录下。 2、使用IDEA的库管理功能。 3、配置环境变量。 (3)传统的Jar包管理缺点。 二、Maven。 &#…

【机器学习】分类器

在机器学习(Machine Learning,ML)中,分类器泛指算法或模型,用于将输入数据分为不同的类别或标签。分类器是监督学习的一部分,它依据已知的数据集中的特征和标签进行训练,并根据这些学习到的知识对新的未标记数据进行分…

uni-app在image上绘制点位并回显

在 Uni-app 中绘制多边形可以通过使用 Canvas API 来实现。Uni-app 是一个使用 Vue.js 开发所有前端应用的框架,同时支持编译为 H5、小程序等多个平台。由于 Canvas 是 H5 和小程序中都支持的 API,所以通过 Canvas 绘制多边形是一个比较通用的方法。 1.…

【机器学习与数据挖掘实战】案例01:基于支持向量回归的市财政收入分析

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈机器学习与数据挖掘实战 ⌋ ⌋ ⌋ 机器学习是人工智能的一个分支,专注于让计算机系统通过数据学习和改进。它利用统计和计算方法,使模型能够从数据中自动提取特征并做出预测或决策。数据挖掘则是从大型数…

Spire.PDF for .NET【页面设置】演示:向 PDF 文档添加页码

在 PDF 文档中添加页码不仅实用,而且美观,因为它提供了类似于专业出版材料的精美外观。无论您处理的是小说、报告还是任何其他类型的长文档的数字副本,添加页码都可以显著提高其可读性和实用性。在本文中,您将学习如何使用Spire.P…

【iOS】OC高级编程 iOS多线程与内存管理阅读笔记——自动引用计数(三)

目录 ARC规则 概要 所有权修饰符 __strong修饰符 __weak修饰符 __unsafe_unretained修饰符 __autoreleasing修饰符 ARC规则 概要 “引用计数式内存管理”的本质部分在ARC中并没有改变,ARC只是自动地帮助我们处理“引用计数”的相关部分。 在编译单位上可以…

An error happened while trying to locate the file on the Hub and we cannot f

An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on. 关于上述comfy ui使用control net预处理器的报错问…

angular19-官方教程学习

周日了解到angular已经更新到19了,想按官方教程学习一遍,工欲善其事必先利其器,先更新工具: 安装新版版本 卸载老的nodejs 20.10.0,安装最新的LTS版本 https://nodejs.org 最新LTS版本已经是22.12.0 C:\Program File…

计算机毕业设计Python+Vue.js游戏推荐系统 Steam游戏推荐系统 Django Flask 游 戏可视化 游戏数据分析 游戏大数据 爬虫 机

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…

上海亚商投顾:创业板指震荡调整 机器人概念股再度爆发

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日冲高回落,深成指、创业板指盘中跌超1%,尾盘跌幅有所收窄。机器人概念股逆势爆…

粘贴可运行:Java调用大模型(LLM) 流式Flux stream 输出;基于spring ai alibaba

在Java中,使用Spring AI Alibaba框架调用国产大模型通义千问,实现流式输出,是一种高效的方式。通过Spring AI Alibaba,开发者可以轻松地集成通义千问模型,并利用其流式处理能力,实时获取模型生成的文本。这…

【CSS in Depth 2 精译_070】11.3 利用 OKLCH 颜色值来处理 CSS 中的颜色问题(下):从页面其他颜色衍生出新颜色

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第四部分 视觉增强技术 ✔️【第 11 章 颜色与对比】 ✔️ 11.1 通过对比进行交流 11.1.1 模式的建立11.1.2 还原设计稿 11.2 颜色的定义 11.2.1 色域与色彩空间11.2.2 CSS 颜色表示法 11.2.2.1 RGB…