李沐--动手学深度学习 ResNet

news2024/11/13 9:00:10

1.理论

2.残差块

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

#ResNet沿用了VGG完整的3*3卷积层设计.残差块的实现如下:
#此代码生成两种类型的网络:
#一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。
#另一种是当use_1x1conv=True时,添加通过1*1卷积调整通道和分辨率。
class Residual(nn.Module):
    def __init__(self,input_channels,num_channels,
                 use_1x1conv = False, strides = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels,num_channels,
                               kernel_size=3,padding=1,stride=strides)
        self.conv2 = nn.Conv2d(num_channels,num_channels,
                               kernel_size=3,padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels,num_channels,
                                   kernel_size=1,stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self,X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

#下面来查看输入和输出形状一致的情况。
b1k = Residual(3,3)
X = torch.rand(4,3,6,6)
Y = b1k(X)
print(Y.shape)
#也可以在增加输出通道数的同时,减半输出的高和宽
b1k = Residual(3,6,use_1x1conv=True,strides=2)
print(b1k(X).shape)

3.ResNet模型

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

#ResNet沿用了VGG完整的3*3卷积层设计.残差块的实现如下:
#此代码生成两种类型的网络:
#一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。
#另一种是当use_1x1conv=True时,添加通过1*1卷积调整通道和分辨率。
class Residual(nn.Module):
    def __init__(self,input_channels,num_channels,
                 use_1x1conv = False, strides = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels,num_channels,
                               kernel_size=3,padding=1,stride=strides)
        self.conv2 = nn.Conv2d(num_channels,num_channels,
                               kernel_size=3,padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels,num_channels,
                                   kernel_size=1,stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self,X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

'''
#下面来查看输入和输出形状一致的情况。
b1k = Residual(3,3)
X = torch.rand(4,3,6,6)
Y = b1k(X)
print(Y.shape)
#也可以在增加输出通道数的同时,减半输出的高和宽
b1k = Residual(3,6,use_1x1conv=True,strides=2)
print(b1k(X).shape)
'''

#ResNet的前两层跟之前介绍的GoogLeNet中的一样,在输出通道数为64、步幅为2的7*7卷积层后,接步幅为2的3*3的最大汇聚层
#不同之处在于ResNet每个卷积层后增加了批量规范化层。
b1 = nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
                   nn.BatchNorm2d(64),nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
#GoogLeNet在后面接了4个由Inception块组成的模块。
#ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。
#第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。
#之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。
def resnet_block(input_channels,num_channels,num_residuals,
                 first_block = False):
    b1k = []
    for i in range(num_residuals):
       if i == 0 and not first_block:
           b1k.append(Residual(input_channels,num_channels,
                               use_1x1conv=True,strides=2))
       else:
           b1k.append(Residual(num_channels,num_channels))
    return b1k
#接着在ResNet加入所有残差块,这里每个模块使用2个残差块。
b2 = nn.Sequential(*resnet_block(64,64,2,first_block=True))
b3 = nn.Sequential(*resnet_block(64,128,2))
b4 = nn.Sequential(*resnet_block(128,256,2))
b5 = nn.Sequential(*resnet_block(256,512,2))
#最后,与GoogLeNet一样,在ResNet中加入全局平均汇聚层,以及全连接层输出。
net = nn.Sequential(b1,b2,b3,b4,b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(),nn.Linear(512,10))


#在训练ResNet之前,让我们观察一下ResNet中不同模块的输入形状是如何变化的。
#在之前所有架构中,分辨率降低,通道数量增加,直到全局平均汇聚层聚集所有特征。
X = torch.rand(size=(1,1,224,224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)

4.ResNet模型训练

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

#ResNet沿用了VGG完整的3*3卷积层设计.残差块的实现如下:
#此代码生成两种类型的网络:
#一种是当use_1x1conv=False时,应用ReLU非线性函数之前,将输入添加到输出。
#另一种是当use_1x1conv=True时,添加通过1*1卷积调整通道和分辨率。
class Residual(nn.Module):
    def __init__(self,input_channels,num_channels,
                 use_1x1conv = False, strides = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels,num_channels,
                               kernel_size=3,padding=1,stride=strides)
        self.conv2 = nn.Conv2d(num_channels,num_channels,
                               kernel_size=3,padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels,num_channels,
                                   kernel_size=1,stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self,X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

'''
#下面来查看输入和输出形状一致的情况。
b1k = Residual(3,3)
X = torch.rand(4,3,6,6)
Y = b1k(X)
print(Y.shape)
#也可以在增加输出通道数的同时,减半输出的高和宽
b1k = Residual(3,6,use_1x1conv=True,strides=2)
print(b1k(X).shape)
'''

#ResNet的前两层跟之前介绍的GoogLeNet中的一样,在输出通道数为64、步幅为2的7*7卷积层后,接步幅为2的3*3的最大汇聚层
#不同之处在于ResNet每个卷积层后增加了批量规范化层。
b1 = nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
                   nn.BatchNorm2d(64),nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
#GoogLeNet在后面接了4个由Inception块组成的模块。
#ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。
#第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。
#之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。
def resnet_block(input_channels,num_channels,num_residuals,
                 first_block = False):
    b1k = []
    for i in range(num_residuals):
       if i == 0 and not first_block:
           b1k.append(Residual(input_channels,num_channels,
                               use_1x1conv=True,strides=2))
       else:
           b1k.append(Residual(num_channels,num_channels))
    return b1k
#接着在ResNet加入所有残差块,这里每个模块使用2个残差块。
b2 = nn.Sequential(*resnet_block(64,64,2,first_block=True))
b3 = nn.Sequential(*resnet_block(64,128,2))
b4 = nn.Sequential(*resnet_block(128,256,2))
b5 = nn.Sequential(*resnet_block(256,512,2))
#最后,与GoogLeNet一样,在ResNet中加入全局平均汇聚层,以及全连接层输出。
net = nn.Sequential(b1,b2,b3,b4,b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(),nn.Linear(512,10))


'''
#在训练ResNet之前,让我们观察一下ResNet中不同模块的输入形状是如何变化的。
#在之前所有架构中,分辨率降低,通道数量增加,直到全局平均汇聚层聚集所有特征。
X = torch.rand(size=(1,1,224,224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
'''

#在Fashion-MNIST数据集上训练ResNet。
lr,num_epochs,batch_size = 0.05,10,256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=96)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2079547.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

96.SAP MII功能详解(09)Workbench-Transaction Debugging

目录 1.About Transaction Debugging Use Features Activities 2.How to Debug Start Debugging Create Breakpoint Watch Variables Debugging logs 1.About Transaction Debugging Use You use this function to monitor and manipulate a transaction while it …

java框架第一课(mybatis认识)

一.关于mybatis 1.mybatis的背景 mybatis原来 是Apache的一个开源项目iBatis,2010年6月这个项目由ApacheSoftware Foundation 迁移到了 Google Code,并改名为mybatis。 2.mybitis的介绍 (1)MyBatis 是一款优秀的持久层框架(在与数据库交互,数据访问层,数据持久层)。…

深入学习SQL优化的第四天

目录 高级查询与连接 1731 每 位 经 理 的 下 属 员 工 数 量 1789 员 工 的 直 属 部 门 610 判 断 三 角 形 180 连 续 出 现 的 数 字 1164 指 定 日 期 的 产 品 价 格 1204 最 后 一 个 能 进 入 巴 士 的 人 1907 按 分 类 统 计 薪 水 子查询 1978 上…

机器学习入门(Datawhale X 李宏毅苹果书 AI夏令营-Task1)

📝本文介绍 本文为作者初探机器学习,读苹果书写下的笔记。 👋作者简介:一个正在积极探索的本科生 📱联系方式:943641266(QQ) 🚪Github地址:https://github.com/sankexilianhua &…

AcWing854. Floyd求最短路

注意&#xff1a;Floyd是求图里面任意两个点x&#xff0c;y之间的最短距离 #include <cstring> #include <iostream> #include <algorithm>using namespace std;const int N 210, INF 1e9;int n, m, Q; int d[N][N];void floyd() { //枚举1~k个中间节点&a…

书生.浦江大模型实战训练营——(十三)茴香豆:企业级知识库问答工具

最近在学习书生.浦江大模型实战训练营&#xff0c;所有课程都免费&#xff0c;以关卡的形式学习&#xff0c;也比较有意思&#xff0c;提供免费的算力实战&#xff0c;真的很不错&#xff08;无广&#xff09;&#xff01;欢迎大家一起学习&#xff0c;打开LLM探索大门&#xf…

国货之光|暴雨信创服务器亮相北京科博会

8月13-16日&#xff0c;由北京市人民政府主办的第二十六届中国北京国际科技产业博览会&#xff08;简称北京科博会&#xff09;在北京国际会议中心成功举办。作为汇聚全球科技创新成果与智慧交流的高端盛会&#xff0c;北京科博会是推动创新发展成果展示的重要舞台。 青海科技展…

html2canvas ios慎用和createImageBitmap ios慎用

好好好&#xff0c;排查几天&#xff0c;原来是你 小本本记下了[翻白眼][翻白眼][翻白眼] ​html2canvas ios慎用&#xff0c;用了记得设置字体 ​2. createImageBitmap ios慎用&#xff0c;14及以下不兼容&#xff0c;建议更换api

Vue3基础2

1.Hooks 就是进行数据的封装&#xff0c;同一种类型的 数据 方法 计算属性 &#xff0c;放在一起 命名规范 use功能名称.ts 或.js 创建一个文件夹 hooks 1.useDog.ts import { reactive,onMounted } from "vue"; import axios from "axios";export def…

Golang | Leetcode Golang题解之第375题猜数字大小II

题目&#xff1a; 题解&#xff1a; func getMoneyAmount(n int) int {f : make([][]int, n1)for i : range f {f[i] make([]int, n1)}for i : n - 1; i > 1; i-- {for j : i 1; j < n; j {f[i][j] j f[i][j-1]for k : i; k < j; k {cost : k max(f[i][k-1], f[…

Linux命令:创建新的目录的工具mkdir命令详解

目录 一、概述 二、语法 1、基本语法 2、常用选项 3、获取帮助 三、示例 1. 创建单个目录 2. 创建多个目录 3. 使用 -p 选项创建多级目录 4. 设置目录权限 5. 显示创建目录的信息 &#xff08;1&#xff09;一般目录创建 &#xff08;2&#xff09;复杂目录创建 …

大数据技术之Flume 企业开发案例——负载均衡和故障转移(6)

目录 负载均衡和故障转移 1&#xff09;案例需求 2&#xff09;需求分析 3&#xff09;实现步骤 负载均衡和故障转移 1&#xff09;案例需求 使用 Flume1 监控一个端口&#xff0c;其 sink 组中的 sink 分别对接 Flume2 和 Flume3&#xff0c;采用 FailoverSinkProcessor…

裁员后的逆袭:程序员变外卖小哥,AI绘画成就全新职业生涯

一、初代程序员的困境 曾几何时&#xff0c;我是一名初代程序员&#xff0c;投身于互联网行业&#xff0c;为我国信息化建设贡献自己的力量。然而&#xff0c;随着年龄的增长和行业竞争的加剧&#xff0c;我不可避免地遭遇了裁员。面对突如其来的变故&#xff0c;我不得不重新审…

Nginx反向代理B

http协议反向代理 反向代理配置参数 proxy_pass; #用来设置将客户端请求转发给的后端服务器的主机 #可以是主机名(将转发至后端服务做为主机头首部)、IP地址&#xff1a;端口的方式 #也可以代理到预先设置的主机群组&#xff0c;需要模块ngx_http_upstream_module支持 #示例:…

机械学习—零基础学习日志(如何理解概率论9)

大数定律与中心定律 来看一道习题&#xff1a; 这个题目看看&#xff0c;应该是什么呢~下一章来看看解析~ 《概率论与数理统计期末不挂科|考研零基础入门4小时完整版&#xff08;王志超&#xff09;》学习笔记 王志超老师 &#xff08;UP主&#xff09;

构造,CF 1290B - Irreducible Anagrams

目录 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 二、解题报告 1、思路分析 2、复杂度 3、代码详解 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 1290B - Irreducible Anagrams 二、解题报告 1、思路分析 首先根据样例特…

系统编程-消息队列

消息队列 目录 消息队列 引入 一、消息队列的特点 二、使用指令查看消息队列 三、使用消息队列进行通信的步骤 1、获取键值 2、创建或获取消息队列 id 3、使用消息队列进行数据的传输 4、msgrcv -- 从消息队列中读取数据 5、消息队列的多种操作函数 引入 -- 进程间…

一、undo log、Buffer Pool、WAL、redo log

目录 1、undo log2、Buffer Pool3、WAL4、redo log5、总结6、问题 1、undo log undo log日志是一种用于撤销回退的逻辑日志&#xff0c;在事务未提交前会记录相反的操作到undo log&#xff0c;当事务回滚&#xff0c;使用undo log 进行回滚&#xff0c;保证了事务的原子性。MV…

Golang 深入理解 GC

垃圾是指程序向堆栈申请的内存空间&#xff0c;随着程序的运行已经不再使用这些内存空间&#xff0c;这时如果不释放他们就会造成垃圾也就是内存泄漏。 垃圾回收 (Garbage Collection&#xff0c;GC) 是编程语言中提供的自动的内存管理机制&#xff0c;自动释放不需要的内存对象…

【Qt】Qt系统 | Qt事件 | 定时器

文章目录 定时器QTimerEventQTimer获取系统日期及时间 定时器 Qt 中在进行窗口程序的处理过程中&#xff0c;经常要周期性的执行某些动作&#xff0c;或者制作一些动画效果&#xff0c;使用定时器可以实现这些需求。 定时器&#xff0c;会在间隔一定时间后&#xff0c;执行某一…