《深度学习》——ResNet网络

news2025/2/21 18:01:05

文章目录

  • ResNet网络
    • ResNet网络实例
        • 导入所需库
        • 下载训练数据和测试数据
        • 设置每个批次的样本个数
        • 判断是否使用GPU
        • 定义残差模块
        • 定义ResNet网络
        • 模型导入GPU
        • 定义训练函数
        • 定义测试函数
        • 创建损失函数和优化器
        • 训练测试数据
        • 结果

ResNet网络

ResNet(Residual Network,残差网络)是深度学习领域中非常重要且具有影响力的一种卷积神经网络(CNN)架构,由何恺明等人于 2015 年提出,在图像识别、目标检测等诸多计算机视觉任务中取得了巨大成功。
1. 产生背景:在深度学习发展过程中,随着网络深度的增加,会出现梯度消失或梯度爆炸的问题,导 致网络难以训练。即使通过归一化等方法解决了梯度问题,还会面临退化问题,即网络深度增加时,模型的训 练误差和测试误差反而增大。ResNet 的提出就是为了解决深度神经网络中的退化问题。
在这里插入图片描述
在这里插入图片描述

  • ResNet-18:是 ResNet 家族中相对较浅的网络,由 4 个残差块组构成,每个残差块组包含不同数量的残差块。它的结构简单,计算量相对较小,适合计算资源有限或对模型复杂度要求不高的场景,如一些小型图像数据集的分类任务。它在一些对实时性要求较高的应用中,如移动设备上的图像识别,也有一定的应用。
  • ResNet-34:同样由 4 个残差块组组成,但相比 ResNet-18,它在某些残差块组中包含更多的残差块,网络深度更深,因此能够学习到更复杂的特征表示。它在中等规模的图像数据集上表现良好,在一些对模型性能有一定要求但又不过分追求极致精度的任务中较为常用。
  • ResNet-50:是一个比较常用的 ResNet 模型,在许多计算机视觉任务中都有广泛应用。它使用了瓶颈结构(Bottleneck)的残差块,这种结构通过先降维、再卷积、最后升维的方式,在减少计算量的同时保持了模型的表达能力。该模型在图像分类、目标检测、语义分割等任务中,都能作为性能不错的骨干网络,为后续的任务提供有效的特征提取。
  • ResNet-101:比 ResNet-50 的网络层数更多,拥有更强大的特征提取能力。它适用于大规模图像数据集和复杂的计算机视觉任务,如在大型目标检测数据集中,能够更好地捕捉目标的细节特征,提升检测的准确性。由于其深度和复杂度,在处理高分辨率图像或需要精细特征表示的任务时表现出色。
  • ResNet-152:是 ResNet 系列中深度较深的网络,具有极高的特征提取能力。但由于其深度很大,计算量和参数量也相应增加,训练和推理所需的时间和资源较多。它通常用于对精度要求极高的场景,如学术研究中的图像识别挑战、大规模图像搜索引擎的图像特征提取等。

18层残差网络:

在这里插入图片描述

ResNet网络实例

项目需求:对手写数字进行识别。
数据集:此项目数据集来自MNIST 数据集由美国国家标准与技术研究所(NIST)整理而成,包含手写数字的图像,主要用于数字识别的训练和测试。该数据集被分为两部分:训练集和测试集。训练集包含 60,000 张图像,用于模型的学习和训练;测试集包含 10,000 张图像,用于评估训练好的模型在未见过的数据上的性能。
图像格式:数据集中的图像是灰度图像,即每个像素只有一个值表示其亮度,取值范围通常为 0(黑色)到 255(白色)。
图像尺寸:每张图像的尺寸为 28x28 像素,总共有 784 个像素点。
标签信息:每个图像都有一个对应的标签,标签是 0 到 9 之间的整数,表示图像中手写数字的值。

导入所需库
import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很对与图像相关的模型,数据集
from torchvision.transforms import ToTensor  # 数据转换,张量,将其他类型的数据转换成tensor张量
import torch.nn.functional as F # 用于应用 ReLU 激活函数
下载训练数据和测试数据
'''下载训练数据集(包含训练集图片+标签)'''
training_data = datasets.MNIST(  # 跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击
    root='data',  # 表示下载的手写数字 到哪个路径。60000
    train=True,  # 读取下载后的数据中的数据集
    download=True,  # 如果你之前已经下载过了,就不用再下载了
    transform=ToTensor(),  # 张量,图片是不能直接传入神经网络模型
    # 对于pytorch库能够识别的数据一般是tensor张量
)

'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),  # Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如pytorch,TensorFlow)
)  # numpy数组只能在cpu上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。
print(len(training_data))
print(len(test_data))
设置每个批次的样本个数
train_dataloader = DataLoader(training_data, batch_size=64)  # 建议用2的指数当作一个包的数量
test_dataloader = DataLoader(test_data, batch_size=64)
判断是否使用GPU
'''判断是否支持GPU'''
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
定义残差模块
# 定义残差块类,继承自 nn.Module
class ResBlock(nn.Module):
    def __init__(self, channels_in):
        # 调用父类的构造函数
        super().__init__()
        # 定义第一个卷积层,输入通道数为 channels_in,输出通道数为 30,卷积核大小为 5,填充为 2
        self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)
        # 定义第二个卷积层,输入通道数为 30,输出通道数为 channels_in,卷积核大小为 3,填充为 1
        self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)

    def forward(self, x):
        # 输入数据通过第一个卷积层
        out = self.conv1(x)
        # 经过第一个卷积层的输出再通过第二个卷积层
        out = self.conv2(out)
        # 将输入 x 与卷积输出 out 相加,并通过 ReLU 激活函数
        return F.relu(out + x)
定义ResNet网络
# 定义 ResNet 网络类,继承自 nn.Module
class ResNet(nn.Module):
    def __init__(self):
        # 调用父类的构造函数
        super().__init__()
        # 定义第一个卷积层,输入通道数为 1,输出通道数为 20,卷积核大小为 5
        self.conv1 = torch.nn.Conv2d(1, 20, 5)
        # 定义第二个卷积层,输入通道数为 20,输出通道数为 15,卷积核大小为 3
        self.conv2 = torch.nn.Conv2d(20, 15, 3)
        # 定义最大池化层,池化核大小为 2
        self.maxpool = torch.nn.MaxPool2d(2)
        # 定义第一个残差块,输入通道数为 20
        self.resblock1 = ResBlock(channels_in=20)
        # 定义第二个残差块,输入通道数为 15
        self.resblock2 = ResBlock(channels_in=15)
        # 定义全连接层,输入特征数为 375,输出特征数为 10
        self.full_c = torch.nn.Linear(375, 10)

    def forward(self, x):
        # 获取输入数据的批次大小
        size = x.shape[0]
        # 输入数据通过第一个卷积层,然后进行最大池化,最后通过 ReLU 激活函数
        x = F.relu(self.maxpool(self.conv1(x)))
        # 经过第一个卷积和池化的输出通过第一个残差块
        x = self.resblock1(x)
        # 经过第一个残差块的输出通过第二个卷积层,然后进行最大池化,最后通过 ReLU 激活函数
        x = F.relu(self.maxpool(self.conv2(x)))
        # 经过第二个卷积和池化的输出通过第二个残差块
        x = self.resblock2(x)
        # 将输出数据展平为一维向量
        x = x.view(size, -1)
        # 展平后的向量通过全连接层
        x = self.full_c(x)
        return x
模型导入GPU
model = ResNet().to(device)
定义训练函数
# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    # 将模型设置为训练模式,这会影响一些层(如 Dropout、BatchNorm 等)的行为
    model.train()
    # 初始化批次编号
    batch_size_num = 1
    # 遍历数据加载器中的每个批次
    for x, y in dataloader:
        # 将输入数据和标签移动到指定设备(如 GPU)
        x, y = x.to(device), y.to(device)
        # 前向传播,计算模型的预测结果
        pred = model.forward(x)
        # 通过交叉熵损失函数计算预测结果与真实标签之间的损失值
        loss = loss_fn(pred, y)
        # 反向传播步骤:
        # 清零优化器中的梯度信息,防止梯度累积
        optimizer.zero_grad()
        # 反向传播计算每个参数的梯度
        loss.backward()
        # 根据计算得到的梯度更新模型的参数
        optimizer.step()
        # 从张量中提取损失值的标量
        loss_value = loss.item()
        # 每 100 个批次打印一次损失值
        if batch_size_num % 100 == 0:
            print(f'loss:{loss_value:7f}  [number:{batch_size_num}]')
        # 批次编号加 1
        batch_size_num += 1

定义测试函数
# 定义测试函数
def test(dataloader, model, loss_fn):
    # 获取数据集的总样本数
    size = len(dataloader.dataset)
    # 获取数据加载器中的批次数量
    num_batches = len(dataloader)
    # 将模型设置为评估模式,这会影响一些层(如 Dropout、BatchNorm 等)的行为
    model.eval()
    # 初始化测试损失和正确预测的样本数
    test_loss, correct = 0, 0
    # 上下文管理器,关闭梯度计算,减少内存消耗
    with torch.no_grad():
        # 遍历数据加载器中的每个批次
        for x, y in dataloader:
            # 将输入数据和标签移动到指定设备(如 GPU)
            x, y = x.to(device), y.to(device)
            # 前向传播,计算模型的预测结果
            pred = model.forward(x)
            # 累加每个批次的损失值
            test_loss += loss_fn(pred, y).item()
            # 计算每个批次中预测正确的样本数并累加
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # 计算平均测试损失
    test_loss /= num_batches
    # 计算平均准确率
    correct /= size
    # 打印测试结果
    print(f'Test result: \n Accuracy:{(100 * correct)}%,Avg loss:{test_loss}')
创建损失函数和优化器
# 创建交叉熵损失函数对象
loss_fn = nn.CrossEntropyLoss()
# 创建 Adam 优化器,用于更新模型的参数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)
训练测试数据
# 定义训练的轮数
epochs = 26
# 开始训练循环
for t in range(epochs):
    print(f'epoch{t + 1}\n--------------------')
    # 调用训练函数进行一轮训练
    train(train_dataloader, model, loss_fn, optimizer)
print('Done!')
# 调用测试函数进行测试
test(test_dataloader, model, loss_fn)
结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2301133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Windows软件 - HeidiSQL】导出数据库

HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …

【达梦数据库】dblink连接[SqlServer/Mysql]报错处理

目录 背景问题1:无法测试以ODBC数据源方式访问的外部链接!问题分析&原因解决方法 问题2:DBLINK连接丢失问题分析&原因解决方法 问题3:DBIINK远程服务器获取对象[xxx]失败,错误洋情[[FreeTDS][SQL Server]Could not find stored proce…

java断点调试(debug)

在开发中,新手程序员在查找错误时, 这时老程序员就会温馨提示,可以用断点调试,一步一步的看源码执行的过程,从而发现错误所在。 重要提示: 断点调试过程是运行状态,是以对象的运行类型来执行的 断点调试介绍 断点调试是…

最新智能优化算法:牛优化( Ox Optimizer,OX)算法求解经典23个函数测试集,MATLAB代码

一、牛优化算法 牛优化( OX Optimizer,OX)算法由 AhmadK.AlHwaitat 与 andHussamN.Fakhouri于2024年提出,该算法的设计灵感来源于公牛的行为特性。公牛以其巨大的力量而闻名,能够承载沉重的负担并进行远距离运输。这种…

Redis7——基础篇(四)

前言:此篇文章系本人学习过程中记录下来的笔记,里面难免会有不少欠缺的地方,诚心期待大家多多给予指教。 基础篇: Redis(一)Redis(二)Redis(三) 接上期内容&…

Git备忘录(三)

设置用户信息: git config --global user.name “itcast” git config --global user.email “ helloitcast.cn” 查看配置信息 git config --global user.name git config --global user.email $ git init $ git remote add origin gitgitee.com:XXX/avas.git $ git pull or…

MySQL 之INDEX 索引(Index Index of MySQL)

MySQL 之INDEX 索引 1.4 INDEX 索引 1.4.1 索引介绍 索引:是排序的快速查找的特殊数据结构,定义作为查找条件的字段上,又称为键 key,索引通过存储引擎实现。 优点 大大加快数据的检索速度; 创建唯一性索引,保证数…

Linux基础24-C语言之分支结构Ⅰ【入门级】

分支结构 问题抛出 我们在程序设计中往往会遇到如下问题,比如下面的函数计算: 也就是我们必须要通过一个条件的结果来选择下一步的操作,算法上属于一个分支结构,处于严重实现分支结构主要使用if语句。 条件判断 根据某个条件成…

LeetCode47

LeetCode47 目录 题目描述示例思路分析代码段代码逐行讲解复杂度分析总结的知识点整合总结 题目描述 给定一个可包含重复数字的整数数组 nums,按任意顺序返回所有不重复的全排列。 示例 示例 1 输入: nums [1, 1, 2]输出: [[1, 1, 2],[1, 2, 1],[2, 1, 1] ]…

【Unity动画】导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动。

导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动,但我只想要角色在原地播放动画。比如:播放一个角色Run动画,希望角色在原地奔跑,而不是产生了移动距离。 问题排查: 1.是否勾选…

图解循环神经网络(RNN)

目录 1.循环神经网络介绍 2.网络结构 3.结构分类 4.模型工作原理 5.模型工作示例 6.总结 1.循环神经网络介绍 RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列数据的神经网络结构。与传统的神经网络不同&#xff0c…

【数据结构】(9) 优先级队列(堆)

一、优先级队列 优先级队列不同于队列,队列是先进先出,优先级队列是优先级最高的先出。一般有两种操作:返回最高优先级对象,添加一个新对象。 二、堆 2.1、什么是堆 堆也是一种数据结构,是一棵完全二叉树&#xff0c…

4、IP查找工具-Angry IP Scanner

在前序文章中,提到了多种IP查找方法,可能回存在不同场景需要使用不同的查找命令,有些不容易记忆,本文将介绍一个比较优秀的IP查找工具,可以应用在连接树莓派或查找IP的其他场景中。供大家参考。 Angry IP Scanner下载…

【Linux】命令操作、打jar包、项目部署

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:Xshell下载 1:镜像设置 二:阿里云设置镜像Ubuntu 三&#xf…

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…

Unity Shader学习6:多盏平行光+点光源 ( 逐像素 ) 前向渲染 (Built-In)

0 、分析 在前向渲染中,对于逐像素光源来说,①ForwardBase中只计算一个平行光,其他的光都是在FowardAdd中计算的,所以为了能够渲染出其他的光照,需要在第二个Pass中再来一遍光照计算。 而有所区别的操作是&#xff0…

tailwindcss学习01

系列教程 01 入门 02 vue中接入 入门 # 注意使用cmd不要powershell npm init -y # 如果没有npx则安装 npm install -g npx npm install -D tailwindcss3.4.17 --registry http://registry.npm.taobao.org npx tailwindcss init修改tailwind.config.js /** type {import(tai…

DIN:引入注意力机制的深度学习推荐系统,

实验和完整代码 完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main 引言 在电商与广告推荐场景中,用户兴趣的多样性和动态变化是核心挑战。传统推荐模型(如Embedding &…

【前端】如何安装配置WebStorm软件?

文章目录 前言一、前端开发工具WebStorm和VS Code对比二、官网下载三、安装1、开始安装2、选择安装路径3、安装选项4、选择开始菜单文件夹5、安装成功 四、启动WebStorm五、登录授权六、开始使用 前言 WebStorm 是一款由 JetBrains 公司开发的专业集成开发环境(IDE…

【Golang学习之旅】Go 语言微服务架构实践(gRPC、Kafka、Docker、K8s)

文章目录 1. 前言:为什么选择Go语言构建微服务架构1.1 微服务架构的兴趣与挑战1.2 为什么选择Go语言构建微服务架构 2. Go语言简介2.1 Go 语言的特点与应用2.2 Go 语言的生态系统 3. 微服务架构中的 gRPC 实践3.1 什么是 gRPC?3.2 gRPC 在 Go 语言中的实…