《动手学深度学习 Pytorch版》 7.6 残差网络(ResNet)

news2024/12/23 20:21:23
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

7.6.1 函数类

如果把模型看作一个函数,我们设计的更强大的模型则可以看作范围更大的函数。为了使函数能逐渐靠拢到最优解,应尽量使函数嵌套,以减少不必要的偏移。

如下图,更复杂的非嵌套函数不一定能保证更接近真正的函数。只有当较复杂的函数类包含较小的函数类时,我们才能确保提高它们的性能。

在这里插入图片描述

7.6.2 残差块

何恺明等人针对上述问题提出了残差网络。它在2015年的ImageNet图像识别挑战赛夺魁,并深刻影响了后来的深度神经网络的设计。残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。

假设原始输入是 x x x,而希望学习的理想映射为 f ( x ) f(x) f(x),则残差块需要拟合的便是残差映射 f ( x ) − x f(x)-x f(x)x。残差映射在现实中更容易优化,也更容易捕获恒等函数的细微波动。之后再和 x x x 进行加法从而使整个模型重新变成 f ( x ) f(x) f(x),这里的加法会更有益于靠近数据端的层的训练,因为乘法中的梯度波动会极大的影响链式法则的结果,而在残差块中输入可以通过加法通路更快的前向传播。

此即为正常块和残差块的区别:

在这里插入图片描述

ResNet 沿用了 VGG 完整的卷积层设计。残差块里首先有2个有相同输出通道数的 3 × 3 3\times 3 3×3 卷积层。每个卷积层后接一个批量规范化层和 ReLU 激活函数。然后通过跨层数据通路跳过这 2 个卷积运算,将输入直接加在最后的 ReLU 激活函数前。这样的设计需要 2 个卷积层的输出与输入形状一样才能使它们可以相加。

class Residual(nn.Module):  #@save
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

残差块如果想改变通道数,就需要引入一个额外的 1 × 1 1\times1 1×1 卷积层来将输入变换成需要的形状后再做相加运算。上述类在 use_1x1conv=False 时,应用在 ReLU 非线性函数之前,将输入添加到输出;在当 use_1x1conv=True 时,添加通过 1 × 1 1\times1 1×1 卷积调整通道和分辨率。

在这里插入图片描述

blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape  # 输入形状和输出形状一致
torch.Size([4, 3, 6, 6])
blk = Residual(3, 6, use_1x1conv=True, strides=2)  # 增加输出通道数的同时 减半输出的高度和宽度
blk(X).shape
torch.Size([4, 6, 3, 3])

7.6.3 ResNet 模型

每个模块有 4 个卷积层(不包括恒等映射的 1 × 1 1\times 1 1×1 卷积层)。加上第一个 $ 7\times 7$ 卷积层和最后一个全连接层,共有18层。因此,这种模型通常被称为 ResNet-18。虽然 ResNet 的主体架构跟 GoogLeNet 类似,但 ResNet 架构更简单,修改也更方便。

在这里插入图片描述

ResNet 的前两层跟 GoogLeNet 一样,在输出通道数为 64、步幅为 2 的 7 × 7 7\times7 7×7 卷积层后,接步幅为 2 的 3 × 3 3\times3 3×3 最大汇聚层。不同之处在于 ResNet 每个卷积层后增加了批量规范化层。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

ResNet 在后面使用了 4 个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。第一个模块的通道数同输入通道数一致。由于之前已经使用了步幅为 2 的最大汇聚层,所以无须减小高和宽。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:  # 第一块特别处理
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

每个模块使用 2 个残差块,最后加入全局平均汇聚层,以及全连接层输出。

b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 128, 28, 28])
Sequential output shape:	 torch.Size([1, 256, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])

7.6.4 训练模型

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要十五分钟,慎跑
loss 0.010, train acc 0.998, test acc 0.913
731.5 examples/sec on cuda:0

在这里插入图片描述

练习

(1)图 7-5 中的 Inception 块与残差块之间的主要区别是什么?在删除了 Inception 块中的一些路径之后,它们是如何相互关联的?

残差块并没有像 Inception 那样使用太多并行路径。和 Inception 的相似之处在于都使用了并联的 1 × 1 1\times 1 1×1的卷积核。


(2)参考 ResNet 论文中的表 1,以实现不同的变体。

在这里插入图片描述

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b2 = nn.Sequential(*resnet_block(64, 64, 3, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 4))
b4 = nn.Sequential(*resnet_block(128, 256, 6))
b5 = nn.Sequential(*resnet_block(256, 512, 3))
net34 = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net34, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要二十五分钟,慎跑
loss 0.048, train acc 0.983, test acc 0.885
429.5 examples/sec on cuda:0

在这里插入图片描述

ResNet-34 还是阶梯状下降,只不过台阶变低了。起步就不如18,最终精度也不如 ResNet-18。


(3)对于更深层的网络,ResNet 引入了“bottleneck”架构来降低模型复杂度。请尝试它。

class Residual_bottleneck(nn.Module):
    def __init__(self, input_channels, mid_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        # 下面改成 bottleneck
        self.conv1 = nn.Conv2d(input_channels, mid_channels,
                               kernel_size=1)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv3 = nn.Conv2d(mid_channels, num_channels,
                               kernel_size=1)
        if use_1x1conv:
            self.conv4 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv4 = None
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = F.relu(self.bn2(self.conv2(Y)))
        Y = self.bn3(self.conv3(Y))
        if self.conv4:
            X = self.conv4(X)
        Y += X
        return F.relu(Y)

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

def resnet_block_bottleneck(input_channels, mid_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:  # 第一块特别处理
            blk.append(Residual_bottleneck(input_channels, mid_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual_bottleneck(num_channels, mid_channels, num_channels))
    return blk

b2 = nn.Sequential(*resnet_block_bottleneck(64, 16, 64, 3, first_block=True))
b3 = nn.Sequential(*resnet_block_bottleneck(64, 32, 128, 4))
b4 = nn.Sequential(*resnet_block_bottleneck(128, 64, 256, 6))
b5 = nn.Sequential(*resnet_block_bottleneck(256, 128, 512, 3))

net1 = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

lr, num_epochs, batch_size = 0.05, 10, 64
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net1, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要十五分钟,慎跑
loss 0.115, train acc 0.957, test acc 0.915
887.9 examples/sec on cuda:0

在这里插入图片描述

ResNet-50 跑不了一点,十分钟一个batch都跑不完。还是给 ResNet-34 强行换上 bottleneck 吧。

可以说提速效果显著,训练嘎嘎快,精度还反升了。


(4)在 ResNet 的后续版本中,作者将“卷积层、批量规范化层和激活层”架构更改为“批量规范化层、激活层和卷积层”架构。请尝试做这个改进。详见参考文献[57]中的图1。

class Residual_change(nn.Module):
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):  # 修改顺序
        Y = self.conv1(F.relu(self.bn1(X)))
        Y = self.conv2(F.relu(self.bn2(Y)))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return Y

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

def resnet_block_change(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual_change(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual_change(num_channels, num_channels))
    return blk

b2 = nn.Sequential(*resnet_block_change(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block_change(64, 128, 2))
b4 = nn.Sequential(*resnet_block_change(128, 256, 2))
b5 = nn.Sequential(*resnet_block_change(256, 512, 2))

net2 = nn.Sequential(b1, b2, b3, b4, b5, nn.BatchNorm2d(512), nn.ReLU(),  # 如果最后不再上个BatchNorm2d则会完全不收敛
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net2, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要二十五分钟,慎跑
loss 0.039, train acc 0.988, test acc 0.905
724.1 examples/sec on cuda:0

在这里插入图片描述

精度有所下降

在这里插入图片描述


(5)为什么即使函数类是嵌套的,我们也仍然要限制增加函数的复杂度呢?

限制复杂度是永远不变的主题,复杂度高更易过拟合,可解释性成吨下降。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1050324.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 基于 SpringBoot 的在线学习平台

1 简介 基于SpringBoot的Java学习平台,通过这个系统能够满足学习信息的管理及学生和教师的学习管理功能。系统的主要功能包括首页,个人中心,学生管理,教师管理,课程信息管理,类型管理,作业信息…

F12报错前端对应请求接口未在NetWork显示

问题背景 今天看到一个接口在部分情况下为正常渲染数据 发现是后端发送数据有问题,但是在NetWork里面怎么都找不到 问题原因 翻看代码,发现是一种异步请求 内部报错了,所以浏览器看不到接口 具体情况 翻看控制台: 发现属性未…

QT用户登录注册,数据库实现

登录窗口头文件 #ifndef LOGINUI_H #define LOGINUI_H#include <QWidget> #include <QLineEdit> #include <QPushButton> #include <QLabel> #include <QMessageBox>#include <QSqlDatabase> //数据库管理类 #include <QSqlQuery> …

【力扣每日一题】2023.9.28 花期内花的数目

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 给我们一个二维数组来表示花期&#xff0c;在一段花期之内花是开的。另外给我们一个一维数组表示来人的时间&#xff0c;要我们返回一个一…

使用Vue3+elementPlus的Tree组件实现一个拖拽文件夹管理

文章目录 1、前言2、分析3、实现4、踩坑4.1、拖拽辅助线的坑4.2、数据的坑4.3、限制拖拽4.4、样式调整 1、前言 最近在做一个文件夹管理的功能&#xff0c;要实现一个树状的文件夹面板。里面包含两种元素&#xff0c;文件夹以及文件。交互要求如下&#xff1a; 创建、删除&am…

三子棋小游戏(简单详细)

设计总体思路 实现游戏可以一直玩&#xff0c;先打印棋盘&#xff0c;玩家和电脑下棋&#xff0c;最后分出胜负。 如果编写较大的程序&#xff0c;我们可以分不同模块 例如这个三子棋&#xff0c;我们可以创建三个文件 分别为&#xff1a; game.h 函数的声明game.c 函数…

求臻医学:乳腺癌治疗与基因检测 探索个性化医疗的未来

乳腺癌是全球女性最常见的恶性肿瘤&#xff0c;2020年全球新发乳腺癌病例约为230万&#xff0c;发病率超过肺癌&#xff0c;位居全部恶性肿瘤首位&#xff01;本文将为您总结乳腺癌的治疗策略与基因检测&#xff0c;揭示个性化医疗的重要意义。 乳腺癌的诊疗 早期乳腺癌通常不…

小程序echarts折线图去除圆圈

如图&#xff0c;默认的折线图上面是有圆圈的&#xff0c;鼠标放上去或者手指触摸的话会有对应的文字出现&#xff0c;但很多时候我们不需要这个圆圈&#xff0c;怎么办呢&#xff0c;其实很简单&#xff0c;只要在 series 中设置属性 showSymbol 为false 就好啦 symbol: none,…

SpringCloud Gateway--Predicate/断言(详细介绍)下

&#x1f600;前言 本篇博文是关于SpringCloud Gateway–Predicate/断言&#xff08;详细介绍&#xff09;下&#xff0c;希望你能够喜欢 &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以…

(三)Python变量类型和运算符

所有的编程语言都支持变量&#xff0c;Python 也不例外。变量是编程的起点&#xff0c;程序需要将数据存储到变量中。 变量在 Python 内部是有类型的&#xff0c;比如 int、float 等&#xff0c;但是我们在编程时无需关注变量类型&#xff0c;所有的变量都无需提前声明&#x…

从C语言到C++:C++入门知识(2)

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下有关C的基础知识点&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; C 语 言 专 栏&#xff1a;C语言&#xff1a;从入门到精通 数据结…

云原生之使用Docker部署PDF多功能工具Stirling-PDF

云原生之使用Docker部署PDF多功能工具Stirling-PDF 一、Stirling-PDF介绍1.1 Stirling-PDF简介1.2 Stirling-PDF功能 二、本次实践规划2.1 本地环境规划2.2 本次实践介绍 三、本地环境检查3.1 检查Docker服务状态3.2 检查Docker版本3.3 检查docker compose 版本 四、下载Stirli…

全网最全面最精华的设计模式讲解,从程序员转变为工程师的第一步

前言 现代社会&#xff0c;技术日新月异&#xff0c;要想跟上技术的更新就必须不断学习&#xff0c;而学习技术最有效方式就是阅读优秀的源码&#xff0c;而优秀的源码都不是简单的逻辑堆积&#xff0c;而是有很灵活的设计模式应用其中&#xff0c;如果我们不懂设计模式&#…

移动机器人运动规划 --- 基于图搜索的A*算法

移动机器人运动规划 --- 基于图搜索的A*算法 A*算法A*算法伪代码A* 算法步骤示例A*算法分析启发函数设计 A*应用的更好方式 A*算法 A算法与Dijkstra算法的框架是完全一样的&#xff0c;**A算法就是有启发性的Dijkstra算法** 代价函数&#xff1a;g(n) 表示的是从开始节点到当…

python tempfile模块:生成临时文件和临时目录

嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 tempfile 模块专门用于创建临时文件和临时目录&#xff0c;它既可以在 UNIX 平台上运行良好&#xff0c;也可以在 Windows 平台上运行良好。 tempfile 模块中常用…

蓝牙技术|蓝牙在物联网产品上的功能,特别是苹果Find My中的应用

蓝牙技术经历了不同的迭代&#xff0c;引入了新功能和改进。最初的蓝牙版本于1999年推出。低功耗蓝牙(BLE)&#xff0c;也称为蓝牙4.0或蓝牙智能&#xff0c;于2010年发明&#xff0c;旨在最大限度地降低功耗。这使得它非常适合使用电池供电的物联网设备&#xff0c;从而延长电…

私有继承和虚函数私有化能用么?

源起 以前就知道private私有化声明关键字&#xff0c;和virtual虚函数关键字两者并不冲突&#xff0c;可以同时使用。 但是&#xff0c;它所表示的场景没有那么明晰&#xff0c;也觉得难以理解&#xff0c;直到近段时间遇到一个具体场景。 场景 借助ACE遇到的问题进行展示 …

深眸科技入局AI视觉行业,以深度学习赋能视觉应用推进智造升级

随着科技的飞速发展&#xff0c;人工智能技术已经成为改变我们生活的重要力量&#xff0c;而深度学习作为人工智能的一个重要分支&#xff0c;近年来随着卷积神经网络的突破和推广&#xff0c;取得了显著进展&#xff0c;并呈现爆发式增长势头。 目前AI技术已经被迅速引入到机…

数据集笔记:上海摩拜共享单车

2017年8月上海地区摩拜单车的数据&#xff0c;已脱敏处理 订单id、自行车id、用户id、起始时间、起始经纬度、终止时间、终止经纬度、路径 数据地址&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1LqL_VtCfgm3vv-NrVCoTkw 提取码&#xff1a;3d3y

各种不同情景的现场急救方法,正确急救的动作要领与操作步骤

一、教程描述 生活中的现场急救&#xff0c;应该是每个人必备的生活技能&#xff0c;可以成功挽救很多人的生命。本套教程为你讲解在各种不同情景下&#xff0c;针对宝宝、儿童与成人等不同群体&#xff0c;现场急救的操作步骤&#xff0c;正确急救的动作要领&#xff0c;以及…