【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet

news2024/12/23 12:12:52

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. Residual(残差连接)

__init__(初始化)

forward(前向传播)

2. resnet_block(残差网络块)

3. ResNet(网络模型)

__init__(初始化)

forward(前向传播)

4. 代码整合


一、实验介绍

        本实验实现了实现深度残差神经网络ResNet

        残差网络(ResNet)是一种深度神经网络架构,用于解决深层网络训练过程中的梯度消失和梯度爆炸问题。通过引入残差连接(residual connection)来构建网络层与层之间的跳跃连接,使得网络可以更好地优化深层结构。

        残差网络的一个重要应用是在图像识别任务中,特别是在深度卷积神经网络(CNN)中。通过使用残差模块,可以构建非常深的网络,例如ResNet,其在ILSVRC 2015图像分类挑战赛中取得了非常出色的成绩。

        在ResNet中,每个残差块由一个或多个卷积层组成,其中包含了跳跃连接。跳跃连接将输入直接添加到残差块的输出中,从而使得网络可以学习残差函数,即残差块只需学习将输入的变化部分映射到输出,而不需要学习完整的映射关系。这种设计有助于减轻梯度消失问题,使得网络可以更深地进行训练。

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

from torch import nn
import torch.nn.functional as F

1. Residual(残差连接)

class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        # 批量归一化层,将会在第7章讲到
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

__init__(初始化)

  • 参数:
    • 输入通道数`input_channels`
    • 输出通道数`num_channels`
    • 是否使用1x1卷积`use_1x1conv`
    • 步幅`strides`
  • 在初始化过程中,创建了两个卷积层`conv1`和`conv2`,分别使用不同的输入和输出通道数,并指定了卷积核的大小、填充和步幅。
  • 如果`use_1x1conv`为True,则创建一个1x1卷积层`conv3`,用于进行维度匹配;
  • 否则,将`conv3`设为None。
  • 创建两个批量归一化层`bn1`和`bn2`,用于对卷积层的输出进行批量归一化操作。

forward(前向传播)

  • 将输入`X`通过`conv1`进行卷积操作,然后经过批量归一化层`bn1`和ReLU激活函数。
  • 将输出通过`conv2`进行卷积操作,再经过批量归一化层`bn2`。
  • 如果`conv3`不为None,则将输入`X`通过`conv3`进行卷积操作,用于进行维度匹配。
  • 最后,将经过卷积和批量归一化的结果与输入相加,得到残差连接的输出。
  • 通过ReLU激活函数处理输出,并返回结果。

2. resnet_block(残差网络块)

        生成由多个残差块组成的残差网络块。

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk
  • 参数
    • input_channels:输入通道数,即每个残差块的输入的通道数。
    • num_channels:每个残差块中卷积层的输出通道数,也是每个残差块内部卷积层的通道数。
    • num_residuals:残差块的数量。
    • first_block:一个布尔值,表示是否为整个 ResNet 中的第一个残差块。
  • 创建一个空列表 blk,用于存储构建的残差块。
  • 通过一个循环迭代 num_residuals 次,每次迭代都构建一个残差块并将其添加到 blk 列表中。
    • 在每个迭代中,首先检查是否为第一个残差块且 first_block 为 False。
      • 如果是,则创建一个具有下采样(strides=2)的残差块,并将其添加到 blk 列表中。这是为了在整个 ResNet 中的第一个残差块中进行下采样。
      • 如果不是第一个残差块或者 first_block 为 True,则创建一个普通的残差块,并将其添加到 blk 列表中。
  • 返回构建好的残差块列表 blk

3. ResNet(网络模型

        ResNet 网络模型,包含了多个残差块,用于实现图像分类任务。

class ResNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                                nn.BatchNorm2d(64), nn.ReLU(),
                                nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
        self.b3 = nn.Sequential(*resnet_block(64, 128, 2))
        self.b4 = nn.Sequential(*resnet_block(128, 256, 2))
        self.b5 = nn.Sequential(*resnet_block(256, 512, 2))
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))

    def forward(self, x):
        net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)

        return net(x)

__init__(初始化)

  • 参数:
    • num_classes,表示分类的类别数目
  • 调用父类的构造函数 `super().__init__()`。
  • self.b1是一个包含了卷积层、批归一化层、ReLU激活函数和最大池化层的序列。它对输入数据进行卷积操作,然后进行批归一化、ReLU激活和最大池化,用于提取输入图像的特征。
    • nn.Conv2d,使用 7x7 的卷积核对输入进行卷积操作,输出通道数为 64,步长为 2,填充为 3。
    • nn.BatchNorm2d 层,用于进行批归一化操作。
    •  ReLU 激活函数层 nn.ReLU()。
    • nn.MaxPool2d`层,使用 3x3 的池化核进行最大池化操作,步长为 2,填充为 1。
  • self.b2self.b3self.b4self.b5分别是几个残差块(resnet_block)的序列。这些残差块包含了卷积层、批归一化层和ReLU激活函数,用于进一步提取输入数据的特征。
    • self.b2使用构建了 2 个残差块,输入通道数为 64,输出通道数也为 64,并且指定 `first_block=True`,表示它是第一个残差块;
    • ……
  • self.head是一个包含自适应平均池化层(AdaptiveAvgPool2d)、展平层(Flatten)和全连接层(Linear)的序列。它将输入数据进行自适应平均池化,然后展平为一维向量,并通过全连接层将特征映射到分类的类别数目上:
    • 自适应平均池化层nn.AdaptiveAvgPool2d:将输入的特征图池化为大小为 1x1 的特征图。
    • 展平层nn.Flatten,将池化后的特征图展平成一维向量。
    • 全连接层nn.Linear,将展平后的特征映射到输出类别的数量。

forward(前向传播)

        输入数据通过上述序列模块self.b1self.b2self.b3self.b4self.b5self.head进行处理,最终输出分类结果

4. 代码整合

# 导入必要的工具包
from torch import nn
import torch.nn.functional as F

#  残差连接, 输入和输出的维度有时是相同的, 有时是不同的, 所以需要 use_1x1conv来判断是否需要
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        # 批量归一化层,将会在第7章讲到
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)


# 残差网络是由几个不同的残差块组成的
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk


class ResNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                                nn.BatchNorm2d(64), nn.ReLU(),
                                nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
        self.b3 = nn.Sequential(*resnet_block(64, 128, 2))
        self.b4 = nn.Sequential(*resnet_block(128, 256, 2))
        self.b5 = nn.Sequential(*resnet_block(256, 512, 2))
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))

    def forward(self, x):
        net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)

        return net(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1073971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

9+代谢+分型,基于代谢通路对肝癌进行分型从而开展实验。

今天给同学们分享一篇代谢分型的生信文章“Bulk and single-cell transcriptome profiling reveal extracellular matrix mechanical regulation of lipid metabolism reprograming through YAP/TEAD4/ACADL axis in hepatocellular carcinoma”,这篇文章于2023年04…

【Linux 下 MySQL5.7 中文编码设置】

前言 原本要使用 Sqoop 把我 MySQL 的数据导入到 HBase 中,习惯了使用 windows 下的 MySQL 8.0 版本,但是用 Sqoop 从windows 传到 linux 下有点复杂,就索性用我自己之前没用过的 linux 下的 MySQL 5.7,结果果然一堆问题&#xff…

爱国者的润学日记-十月

首先需要科学的准备面试和润。如何进行科学的准备工作呢? 高效的按照面试考察内容进行针对性训练,按 Machine-learning-interview准备保证处于专注的心态,如今互联网娱乐发达,之前即使比赛时我也是一边比赛一边看视频。之后准备面…

MySQL:读写分离-amoeba(7)

环境介绍 mysql主服务器 192.168.254.1 mysql从服务器(1)192.168.254.2 mysql从服务器(2)192.168.254.3 amoeba代理服务器 192.168.254.4 测试服务器 192.168.254.5 此技术搭配主从复制,我的主服务器和从服务器都…

TS类中属性的封装

我们在如下的代码中,我们在类中设置属性,创建的对象可以随意修改自身的属性,对象中的属性可以任意被修改导致对象中的数据非常不安全。 // 创建一个Person类 class Person {name: string;age: number;constructor(name: string, age: number…

通道剪枝channel pruning

1、相关定义 过参数化:主要是指在训练阶段,在数学上需要进行大量的微分求解,去捕捉数据中微小的变化信息,一旦完成迭代式的训练之后,网络模型在推理的时候就不需要这么多参数。剪枝算法:核心思想就是减少网…

【【萌新的SOC学习之小水文系列】】

萌新的SOC学习之小水文系列 SD卡读写TXT文本实验 SD 卡共有 9 个引脚线,可工作在 SDIO 模式或者 SPI 模式。在 SDIO 模式下,共用到 CLK、CMD、DAT[3:0]六根信号线;在 SPI 模式下,共用到 CS(SDIO_DAT[3])、…

栅形状的影响及可靠性的优化

栅形状的影响 VD-MOSFET单元结构采用平面栅极拓扑结构,栅极电极位于半导体的平坦上表面。虽然在这种结构中,在平面结处会发生电场增强,但在栅极电极处不会发生电场增强,因为栅极电极的边缘与高度掺杂的N源区重叠。栅极电极的边缘被…

新能源+低代码:百数服务商新领域,跨行业结合所碰撞出的新火花

新能源行业的兴起主要是在最近几年,特别是“双碳”目标提出后,中国的新能源行业迎来了快速发展的阶段。在政策支持和资本加持下,各种新能源和绿色发展基金设立,以新能源为主体的新型电力系统也得到了深化改革,大力推动…

Qt中QTimer定时器的用法

Qt中提供了两种定时器的方式一种是使用Qt中的事件处理函数,另一种就是Qt中的定时器类QTimer。 使用QTimer类,需要创建一个QTimer类对象,然后调用其start()方法开启定时器,此后QTimer对象就会周期性的发出timeout()信号。 1.QTimer…

十五、异常(6)

本章概要 Try-With-Resources 用法 揭示细节 异常匹配 Try-With-Resources 用法 在考虑所有可能失败的方法时,找出放置所有 try-catch-finally 块的位置变得令人生畏。确保没有任何故障路径,使系统远离不稳定状态,这非常具有挑战性。 Inp…

Unity ToLua热更框架使用教程(1)

从本篇开始将为大家讲解ToLua在unity当中的使用教程。 Tolua的框架叫LuaFramework,首先附上下载链接: https://github.com/jarjin/LuaFramework_UGUI_V2 这个地址的是UGUI的。 下载完之后导入项目,首先,我们要先让这个项目跑起…

老卫带你学---Datagrip连接clickhouse

Datagrip连接clickhouse Datagrip是一个DB可视化特别方便的软件,因为一些业务需要采用clickhouse,然而在download相关driver的时候出现各种问题,于是整理一下方案 1.需要下载clickhouse-jdbc的jar包,可以直接在sonatype上去下载…

C# 人像卡通化

效果 项目 代码 using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using System; using System.Collections.Generic; using System.Drawing; using System.Linq; using System.Threading.Tasks; using System.Windows.Forms;nam…

图像分割-Segment Anything实践

一、模型介绍 Segment Anything 模型是一种新的图像分割模型,它可以在不需要大量标注数据的情况下,对图像中的任何物体进行分割。这种方法可以帮助计算机视觉领域的研究人员和开发人员更轻松地训练模型,从而提高计算机视觉应用程序的性能。该…

超前预告 | 云原生?大模型?这届乌镇双态IT大会亮点有点多

石道旁的水面,轻轻泛着微光,几片墨绿缓缓飘下,荡起柔和的波纹,向对岸游去。这儿不似北方秋阳如火的躁动,这儿的秋色是安静的,里便是江南水乡乌镇…… 2023年,第六届双态IT乌镇用户大会将于10月…

不再为文件名大小写烦恼:批量转换,一招搞定

在电脑使用过程中,我们经常需要处理各种文件,有时需要对文件名进行大小写转换以符合特定要求或便于管理。手动修改不仅费时还容易出错,那么有没有一种方法可以批量转换文件名大小写呢?答案是肯定的,下面就为大家介绍如…

DC电源模块在电容滤波器上的设计

BOSHIDA DC电源模块在电容滤波器上的设计 DC电源模块在电容滤波器上的设计是电源管理系统中非常重要的一部分,其目的是为了确保电源输出电压的稳定性和纹波尽可能小。在设计中,需要考虑到电源负载的变化和变压器等电源配件的电磁干扰等因素。下面我们详细…

基于Java的民宿管理系统设计与实现(源码+lw+部署文档+讲解等)(民宿预约、民宿预订、民宿管理、酒店预约通用)

文章目录 前言具体实现截图论文参考详细视频演示代码参考源码获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技…

【数据结构】二叉树--堆排序

目录 一 降序(建小堆) 二 升序 (建大堆) ​三 优化(以升序为例) 四 TOP-K问题 一 降序(建小堆) void Swap(int* x, int* y) {int tmp *x;*x *y;*y tmp; }//降序 建小堆 void AdjustUp(int* a, int child) {int parent (child - 1) / 2;while (child > 0){if (a[chil…