动手学深度学习——残差网络ResNet(原理解释+代码详解)

news2024/11/28 10:39:15

残差网络ResNet

      • 1. 函数类
      • 2. 残差块
      • 3. ResNet模型
      • 4. 训练模型

ResNet为了解决“新添加的层如何提升神经网络的性能”,它在2015年的ImageNet图像识别挑战赛夺魁
在这里插入图片描述

它深刻影响了后来的深度神经网络的设计,ResNet的被引用量更是达到了19万+。
在这里插入图片描述

1. 函数类

假设有一类特定的神经网络架构F,它包括学习速率和其他超参数设置。对于所有f∈F,存在一些参数集(例如权重和偏置),这些参数可以通过在合适的数据集上进行训练而获得。

现在假设 f* 是我们真正想要找到的函数,如果是 f*∈F,那可以轻而易举的训练得到它。

给定一个具有X特性和y标签的数据集
在这里插入图片描述
在这里插入图片描述是我们要找的函数,为了使其更近似真正的 f* ,则需要更强的架构F’。

在这里插入图片描述
对于非嵌套函数类,较复杂的函数类并不总是向“真”函数 f* 靠拢(复杂度由F1向F6递增)。虽然F3比F1更接近 f*,但却离F6的更远了。

而右侧的嵌套函数可以避免上述问题。

2. 残差块

残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。
在这里插入图片描述
F(x) + x包含了原始元素。

ResNet沿用了VGG完整的3x3卷积层设计。

  • 残差块里首先有2个有相同输出通道数的3x3卷积层。
  • 每个卷积层后接一个批量规范化层和ReLU激活函数。
  • 然后通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Residual(nn.Module): #@save
    # use_ixiconv:残差连接是直接连接还是通过卷积层连接
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, 
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, 
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, 
                               kernel_size=3, padding=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X 
        return F.relu(Y)

此代码生成两种类型的网络:

  • 一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。
  • 另一种是当use_1x1conv=True时,添加通过1x1卷积调整通道和分辨率。
    在这里插入图片描述查看输入和输出形状一致的情况
# 输入、输出的情况
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape

在这里插入图片描述
在增加输出通道数的同时,减半输出的高和宽

# 增加输出通道的同时,减半输出的高度和宽度
blk = Residual(3, 6, use_1x1conv=True, strides=2)
blk(X).shape

在这里插入图片描述

3. ResNet模型

ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的7x7卷积层后,接步幅为2的3x3的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。

# ResNet在每个卷积层后增加了批量规范层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

ResNet使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。
第一个模块的通道数同输入通道数一致。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

# 实现残差连接模块:由4个残差连接块组成
def resnet_block(input_channels, num_channels, num_residuals, 
                first_block=False):
    # 定义空网络结构
    blk = []
    for i in range(num_residuals):
        # 第2,3,4个Inception块的第一个残差模块连接1x1卷积层
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, 
                               use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

接着在ResNet加入所有残差块,这里每个模块使用2个残差块。

# 在ResNet加入所有残差块,每个模版使用2个残差块
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

最后在ResNet中加入全局平均汇聚层,以及全连接层输出。

# 在ResNet加入全局平均汇聚层以及全连接层输出
net = nn.Sequential(b1, b2, b3, b4, b5,
                   nn.AdaptiveAvgPool2d((1, 1)),
                   nn.Flatten(), nn.Linear(512, 10))

每个模块有4个卷积层,加上第一个7x7卷积层和最后一个全连接层,共有18层,这种模型通常被称为ResNet-18。
在这里插入图片描述
观察一下ResNet中不同模块的输入形状是如何变化

# 每个模块有4个卷积层(不包括恒等映射的1x1卷积层)。 
# 加上第一个7x7卷积层和最后一个全连接层,共有18层。 
# 因此,这种模型通常被称为ResNet-18。
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output_shape:\t', X.shape)

在这里插入图片描述

4. 训练模型

在Fashion-MNIST数据集上训练ResNet
定义精度评估函数

"""
    定义精度评估函数:
    1、将数据集复制到显存中
    2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    # 判断net是否属于torch.nn.Module类
    if isinstance(net, nn.Module):
        net.eval()
        
        # 如果不在参数选定的设备,将其传输到设备中
        if not device:
            device = next(iter(net.parameters())).device
    
    # Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            # 将X, y复制到设备中
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            
            # 计算正确预测的数量,总预测的数量,并存储到metric中
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

定义GPU训练函数

"""
    定义GPU训练函数:
    1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;
    2、使用Xavier随机初始化模型参数;
    3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    # 定义初始化参数,对线性层和卷积层生效
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    
    # 在设备device上进行训练
    print('training on', device)
    net.to(device)
    
    # 优化器:随机梯度下降
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    
    # 损失函数:交叉熵损失函数
    loss = nn.CrossEntropyLoss()
    
    # Animator为绘图函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    # 调用Timer函数统计时间
    timer, num_batches = d2l.Timer(), len(train_iter)
    
    for epoch in range(num_epochs):
        
        # Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量
        metric = d2l.Accumulator(3)
        net.train()
        
        # enumerate() 函数用于将一个可遍历的数据对象
        for i, (X, y) in enumerate(train_iter):
            timer.start() # 进行计时
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将特征和标签转移到device
            y_hat = net(X)
            l = loss(y_hat, y) # 交叉熵损失
            l.backward() # 进行梯度传递返回
            optimizer.step()
            with torch.no_grad():
                # 统计损失、预测正确数和样本数
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 计时结束
            train_l = metric[0] / metric[2] # 计算损失
            train_acc = metric[1] / metric[2] # 计算精度
            
            # 进行绘图
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
                
        # 测试精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) 
        animator.add(epoch + 1, (None, None, test_acc))
        
    # 输出损失值、训练精度、测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
          f'test acc {test_acc:.3f}')
    
    # 设备的计算能力
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'
          f'on {str(device)}')

在这里插入图片描述

训练模型

# 训练模型
lr, num_epochs, batch_size = 0.05, 10 ,256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1168724.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android studio新版本多渠道打包配置

最近公司套壳app比较多 功能也都一样只有地址,和app名字还有icon不一样 签名文件也是一样的,所以就研究了多渠道打包 配置如下: 在app下build.gradle配置 因为最新版as中禁用了BuildConfig 所以我们需要手动配置一下 android { //TODO 其他省略buildFe…

网络编程套接字(二)

目录 简单的TCP网络程序服务端创建套接字服务端绑定服务端监听服务端获取连接服务端处理请求单执行流服务器的弊端 多进程版TCP网络程序捕捉SIGCHLD信号让孙子进程提供服务多线程版的TCP网络程序客户端创建套接字客户端链接服务器客户端发起请求 线程池版的TCP网络程序 简单的T…

SpringBoot整合数据库版本管理工具Liquibase,赶紧整起来!

SpringBoot整合数据库版本管理工具Liquibase 背景一、什么是数据库版本管理工具?数据库版本管理工具主要特性什么是数据库版本管理工具Flyway和Liquibase对比及选型 二、Liquibase整合步骤1.引入pom依赖2.配置application.yml3.新建master.xml(用于配置你…

使用pandas处理excel文件【Demo】

一、代码示例 import pandas as pd from pandas import Series,DataFrame from pandasql import sqldf import matplotlib.pyplotidInfos DataFrame(pd.read_excel(home_data.xlsx))print(idInfos.head(2))print(idInfos.dtypes)# print(idInfos[:][姓名]) # 自定义一个函数s…

QTreeView 常见节点操作

目录 1、节点遍历 2、设置当前选中项 3、树节点数据绑定 4、树节点自定义样式 5、数据检索 6、获取当前选中项 QTreeView作为项目最经常使用的空间,常用接口和操作必须熟悉熟悉在熟悉!!! 1、节点遍历 void ParamSettingDl…

【python VS vba】(3) 在python直接调用vba脚本

目录 0 前言 1 VBA 内容 1.1 EXCEL这边VBA的内容 1.2 VBA的测试代码 2 python 调用 2.1 python 调用VBA的过程和结果 2.2 代码 0 前言 前面写了这么多,没想到,其实py是可以直接支持VBA的 python的模块import xlwings,可以让python直…

Leetcode1128. 等价多米诺骨牌对的数量

Every day a Leetcode 题目来源&#xff1a;1128. 等价多米诺骨牌对的数量 解法1&#xff1a;暴力 代码&#xff1a; class Solution { public:int numEquivDominoPairs(vector<vector<int>> &dominoes){int n dominoes.size(), count 0;for (int i 0;…

理解训练深度前馈神经网络的难度【PMLR 2010】

论文地址&#xff1a;Excellent-Paper-For-Daily-Reading/summarize at main 类别&#xff1a;综述 时间&#xff1a;2023/11/03 摘要 这篇论文比较久了&#xff0c;但仍能从里面获得一些收获&#xff0c;论文主要是讨论并研究了不同的非线性激活函数的影响&#xff0c;sig…

Java与Redis的集成以及Redis中的项目应用

一、Java连接Redis Redis与MySQL都是数据库&#xff0c;java操作redis其实跟操作mysql的过程是一样的。 1.1 导入依赖 打开IDEA&#xff0c;进入Java项目&#xff0c;导入pom依赖&#xff0c;代码如下&#xff1a; <dependency><groupId>redis.clients</gro…

threejs 2. 辅助对象

ArrowHelper 箭头 const arrowHelper new THREE.ArrowHelper(new THREE.Vector3( 1, 2, 0 ).normalize(), // 方向向量必须是单位向量,默认值为(0,0,1)new THREE.Vector3( 0, 0, 0 ), // 起点,默认值为(0,0,0)1, // 长度,默认值为10xffff00, // 颜色,默认值为0xffff00undefine…

JavaSE基础 --- 类与对象

1.类与对象的定义 类是一种抽象的数据类型&#xff0c;它描述了一类对象的行为和状态。例如&#xff0c;我们可以定义一个名为“Dog”的类&#xff0c;它描述了狗这类动物的一般特性&#xff0c;如颜色、品种等状态&#xff0c;以及跑、叫等行为。 对象则是类的实例&#xff0c…

【金TECH频道】企业架构转型组合拳来袭,助力金融机构一臂之力

当前&#xff0c;数字化转型已经成为时代共性课题 在政策和技术的双重指引下 金融机构逐渐走向差异化竞争的格局 面对转型阵痛 以契合、明晰的战略规划及 企业架构调整来辅助业务变革 成为助力企业数字化转型的有效路径 金融机构也纷纷开始探索 企业架构转型的新思路、…

SQL Server2000mdf升级SQL Server2005数据库还原

SQL Server2000数据库还原sqlserver 2000mdf升级 sqlserver 2008数据库还原SQL Server2005数据库脚本 sqlserver数据库低版本升级成高版本 sqlserver数据库版本升级 数据库版本还原 如果本机安装了sqlserver2012或者sqlserver2019等高版本 怎么样才能运行sqlserver2000的数据库…

2003-2022年地级市-财政收支明细数据(企业、个人所得税、科学、教育、医疗等)

2003-2022年地级市-财政收支明细数据&#xff08;企业、个人所得税、科学、教育、医疗等&#xff09; 1、时间&#xff1a;2003-2022年 2、指标&#xff1a;行政区划代码、年份、地区、一般公共预算收入、一般公共预算-税收收入、一般公共预算-税收收入-增值税收入、一般公共…

软件测试面试题及答案2024

1、你们的缺陷等级如何划分的&#xff1f;☆☆☆☆☆ 我们的缺陷一般分为四个等级&#xff0c;致命级&#xff0c;严重级&#xff0c;一般级和轻微级。致命级指能够导致软件程序无法使用的缺陷&#xff0c;比如宕机&#xff0c;崩溃&#xff0c;手机APP的闪退&#xff0c;数据…

【AUTOSAR CANTP】深入理解CAN传输层:N-SDU数据接收与缓冲处理

1. 前言 CanTp是PDU路由器和CAN接口模块之间的那个模块。它的主要作用就是对超过8字节或者CAN FD情况下超过64字节的CAN I-PDU进行分段和重组啦。PDU路由器会把AUTOSAR COM和DCM I-PDU放到不同的通信协议上去,具体是用哪个网络系统类型(比如CAN、LIN和FlexRay)来路由,就看…

python图像处理 ——几种图像增强技术

图像处理 ——几种图像增强技术 前言一、几种图像增强技术1.直方图均衡化2.直方图适应均衡化3.灰度变换4.同态滤波5.对比拉伸6.对数变换7.幂律变换&#xff08;伽马变换&#xff09; 前言 图像增强是指通过各种算法和技术&#xff0c;改善或提高数字图像的质量、清晰度、对比度…

半阵法单脉冲测角

半阵法单脉冲测角 单脉冲测角的类型确知波束形成导向矢量半阵测向原理半阵测向仿真 单脉冲测角的类型 传统的单脉冲测向方法主要有3种&#xff0c;分别是半阵法、加权法和和差比幅法。在了解单脉冲测向之前&#xff0c;首先要知道确知波束形成&#xff0c;确知波束形成就是设计…

python GUI tkinter实战

筛选出列长度不为指定长度的列 from os import path from tkinter import (BOTH, BROWSE, EXTENDED, INSERT, Button, Frame, Label,Text, Tk, filedialog, mainloop, messagebox) import matplotlib.pyplot as plt import pandas as pd from PIL import Image, ImageTk from …

P3398 仓鼠找 sugar

Portal. LCA。 询问树上两条路径是否有交点。 画图发现无非两种情况&#xff1a; 发现一条路径的起点和终点的 LCA 经过另一条路径&#xff0c;是两路径相交的充要条件。 考虑如何判断这个 LCA 在不在路径上。若 d ( s , LCA ) d ( LCA , t ) d ( s , t ) d(s,\text{LCA…