【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现

news2024/11/18 20:46:15

和之前介绍的批量归一化层作用类似,残差网络(ResNet)提出的主要目的也是为了优化深度神经网络中数值稳定性问题。

1. 残差块介绍

假设输入为 x \boldsymbol{x} x,希望学出的理想映射为 f ( x ) f(\boldsymbol{x}) f(x)。下图左右为普通网络结构与加入残差连接的网络对比。右侧是ResNet残差网络的基础块,即残差块(residual block)。在残差块中,输入可通过跨层的数据线路更快地向前传播。

在这里插入图片描述

ResNet网络沿用了VGG全 3 × 3 3\times 3 3×3卷积层的设计。残差块里首先有2个有相同输出通道数的 3 × 3 3\times 3 3×3卷积层。每个卷积层后接一个批量归一化层和ReLU激活函数。然后我们将输入跳过这两个卷积运算后直接加在最后的ReLU激活函数前。这样的设计要求两个卷积层的输出与输入形状一样,从而可以相加。如果想改变通道数,就需要引入一个额外的 1 × 1 1\times 1 1×1卷积层来将输入变换成需要的形状后再做相加运算。

残差块的实现如下。它可以设定输出通道数、是否使用额外的 1 × 1 1\times 1 1×1卷积层来修改通道数以及卷积层的步幅。

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Residual(nn.Module):  
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)

下面我们来查看输入和输出形状一致的情况。

blk = Residual(3, 3)
X = torch.rand((4, 3, 6, 6))
blk(X).shape # torch.Size([4, 3, 6, 6])

我们也可以在增加输出通道数的同时减半输出的高和宽。

blk = Residual(3, 6, use_1x1conv=True, stride=2)
blk(X).shape # torch.Size([4, 6, 3, 3])

2. 构建ResNet残差模型

ResNet的前两层跟之前介绍的GoogLeNet中的一样:在输出通道数为64、步幅为2的 7 × 7 7\times 7 7×7卷积层后接步幅为2的 3 × 3 3\times 3 3×3的最大池化层。不同之处在于ResNet每个卷积层后增加的批量归一化层。

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNet在后面接了4个由Inception块组成的模块。ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。第一个模块的通道数同输入通道数一致。由于之前已经使用了步幅为2的最大池化层,所以无须减小高和宽。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

下面我们来实现这个模块。注意,这里对第一个模块做了特别处理。

def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)

接着我们为ResNet加入所有残差块。这里每个模块使用两个残差块。

net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))

最后,与GoogLeNet一样,加入全局平均池化层后接上全连接层输出。

net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10))) 

这里每个模块里有4个卷积层(不计算 1 × 1 1\times 1 1×1卷积层),加上最开始的卷积层和最后的全连接层,共计18层。这个模型通常也被称为ResNet-18。通过配置不同的通道数和模块里的残差块数可以得到不同的ResNet模型,例如更深的含152层的ResNet-152。虽然ResNet的主体架构跟GoogLeNet的类似,但ResNet结构更简单,修改也更方便。这些因素都导致了ResNet迅速被广泛使用。

在训练ResNet之前,我们来观察一下输入形状在ResNet不同模块之间的变化。

X = torch.rand((1, 1, 224, 224))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 112, 112])
1  output shape:	 torch.Size([1, 64, 112, 112])
2  output shape:	 torch.Size([1, 64, 112, 112])
3  output shape:	 torch.Size([1, 64, 56, 56])
resnet_block1  output shape:	 torch.Size([1, 64, 56, 56])
resnet_block2  output shape:	 torch.Size([1, 128, 28, 28])
resnet_block3  output shape:	 torch.Size([1, 256, 14, 14])
resnet_block4  output shape:	 torch.Size([1, 512, 7, 7])
global_avg_pool  output shape:	 torch.Size([1, 512, 1, 1])
fc  output shape:	 torch.Size([1, 10])

3. 获取数据和训练ResNet模型

下面我们在Fashion-MNIST数据集上训练ResNet。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0015, train acc 0.853, test acc 0.885, time 31.0 sec
epoch 2, loss 0.0010, train acc 0.910, test acc 0.899, time 31.8 sec
epoch 3, loss 0.0008, train acc 0.926, test acc 0.911, time 31.6 sec
epoch 4, loss 0.0007, train acc 0.936, test acc 0.916, time 31.8 sec
epoch 5, loss 0.0006, train acc 0.944, test acc 0.926, time 31.5 sec

4. 总结

  • 残差块通过跨层的数据通道从而能够训练出有效的深度神经网络。

如果文章内容对你有帮助,感谢点赞+关注!

关注下方GZH:阿旭算法与机器学习,可获取更多干货内容~欢迎共同学习交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/125154.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【GO】 K8s 管理系统项目[API部分--Namespace]

K8s 管理系统项目[API部分–Namespace] 1. 接口实现 service/dataselector.go type namespaceCell corev1.Namespacefunc(n namespaceCell) GetCreation() time.Time {return n.CreationTimestamp.Time }func(n namespaceCell) GetName() string {return n.Name }2. Namespa…

景联文科技:赋能智能安防,详谈其中运用到的数据标注类型

“数据显示,2013-2020年我国智能安防行业市场规模由101亿元增长至511亿元。随着智能安防在多个领域的深化应用,预计2023年我国智能安防行业市场规模将超1000亿元。 智能安防领域中,数据标注主要应用于计算机视觉与语音识别两个主要领域,具体…

人口数据可视化,深圳是人口密度最高的城市,东莞上海位居二三名

进入2022年以来,人口问题频频引起热议,人口老龄化、生育意愿再创新低、男女比例失衡等等问题频出。具体的人口问题如何,跟随可视化互动平台的数据可视化大屏一起来了解吧! 我国各省人口数量从地图分布图看,广东省、山…

安装Pytorch

太难了 之前在学校就没安装好 各种报错 终于安装好了 浅浅记录一下 撒花撒花 菜鸡经验: 1.本地python 与 Anaconda 是两个独立的东西 2.可直接在Anaconda中创建不同新的虚拟环境以适配不同的需求 3.cuda 的版本与 NVIDIA版本需要一致,与Python环境也需要…

Echarts图表相关知识

一个基于 JavaScript 的开源可视化图表库。目前我们的前端框架中已经集成了Echarts库v5.3.2),使用的时候不需要再次安装,直接使用即可,具体安装方法不再赘述。 有些时候官网的例子不满足我们的需求,这个时候就要求我们…

cq:fast lookup argument

1. 引言 Ariel Gabizon等人2022年论文《cq: Cached quotients for fast lookups》。 lookup argument的核心思想为: 对于特定的quotient多项式,经某种预处理之后,将更易于计算其commitments。 当前的lookup argument系列方案主要有&#…

实拍视频、图片素材库,高质量、免费下载。

这几个网站的实拍素材,质量高,还可以免费下载。 1、菜鸟图库 https://www.sucai999.com/?vNTYwNDUx 菜鸟图库有超多设计类素材,像平面、UI、电商、办公类等等在这个网站都能找到,网站还有很多实拍视频素材,质量很高&a…

k8s集群部署01

k8s集群部署01Kubernetes简介Kubernetes部署节点部署关于yum缓存提示满了,Rhel7换源解决报错解决过程配置文件内容—要自己看链接是否过期集群初始化Kubernetes-kubectl命令出现错误【The connection to the server localhost:8080 was refused - did you specify t…

git chrry pickup

git chrry pickup目录概述需求:设计思路实现思路分析1.java2.转移分支3.git merge4.cherry pick.切换到 master 分支Cherry pick 操作参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,full busy,…

基于MVC的在线影票售卖系统/基于ASP.NET的电影院售票系统

摘 要 随着电影院规模的不断扩大,人流数量的急剧增加,有关电影院的各种信息量也在不断成倍增长。面对庞大的信息量,就需要有在线影票售卖系统来提高电影院工作的效率。通过这样的系统,我们可以做到信息的规范管理和快速查询&…

副业项目分享,旧衣回收项目怎么做

大家好,我是蝶衣王的小编,今天分享一个简单的项目 我们每个家庭都有多余的旧衣服。许多人会直接把它们扔进垃圾桶。然而,这里隐藏着巨大的商机。说到这里,每个人都应该想到:旧衣服的回收。 事实上,目前国…

《位图布隆过滤器》

【一】位图的概念 位图,就是用每一个比特位来存放某种状态,适用于海量数据,整数,数据无重复的场景,通常是用来判断某个数据存不存在的。例如:10个整数本应该存放四十个字节,此时用位图只需要十…

Ajax(JavaWebAjax、源生Ajax、跨域)

1.JavaWeb - Ajax 概念:AJAX(Asynchronous Java JavaScript And Xml ):异步的JavaScript和Xml AJAX作用: 与服务器进行数据交换:通过AJAX可以给服务器发送请求,并获取服务器响应的数据。 使用…

2022-12-28-面试题整理

1. Spring中Bean创建完成后执行指定代码的几种实现方式 实现ApplicationListener接口 实现ApplicationListener接口并实现方法onApplicationEvent()方法,Bean在创建完成后会执行onApplicationEvent()方法 Component public class DoByApplicationListener impleme…

Java操作redis数据库之读取csv文件

csv文件 要想对某个文件进行具体操作,首先要了解这个文件的结构。csv 全称“Comma-Separated Values”,是一种逗号分隔值格式的文件,是一种用来存储数据的纯文本格式文件。CSV 文件由任意数目的记录组成,记录间以某种换行符分隔&…

FPGA再入门——UART IP核调用

我的工作偏向硬件设计与调试,但是经过几年的发展,发觉不会调程序发展真的很受限制。最近越来越被这种限制折磨的很难受,所以开始学习调调程序。其实,本科与研究生阶段都有过做写代码的经历,算是入过门。但是&#xff0…

[3]ESP32连接MQTT服务端

MQTT库&#xff1a;PubSubClient 连接MQTT服务端 #include <Arduino.h> #include <WiFi.h> #include <PubSubClient.h>const char *ssid "613专属"; const char *password "613613613"; const char *mqttServer "test.ranye-…

CDGA|持续投入开展数据治理工作可以从这四大方向着手

数字化转型趋势下&#xff0c;外部监管以及内部数据使用都对数据治理提出更高效、更准确、更完备、更合规的要求&#xff0c;企业如何抓住新形势下的要求&#xff0c;开展自身数据治理工作&#xff1f; 纵观数据治理的发展历程&#xff0c;剖析数据治理的建设路径&#xff0c;持…

3. 中断向量是( )。 ————计算机组成原理

中断向量是&#xff08; &#xff09; A.子程序入口地址 B.中断向量表的首地址 C.终端服务程序入口地址 D.终端服务入口地址的地址 答案&#xff1a; C 知识点&#xff1a; 终端的概念&#xff1a; 1 机器出现了一些紧急事务&#xff0c;CPU不得不停下当前正在执行的程序&…

SQL经典练习:电脑商店

表结构 本文使用的表结构如下&#xff1a; 以下是创建表的语句&#xff1a; -- 厂商表 CREATE TABLE Manufacturers (Code INTEGER NOT NULL PRIMARY KEY, -- 编号&#xff0c;主键Name VARCHAR(255) NOT NULL, -- 名称 );-- 产品表 CREATE TABLE Products (Code INTEGER NO…