PyTorch从零开始实现ResNet

news2024/11/24 14:41:06

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn

# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):
    def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
        super(block, self).__init__()
        self.expansion = 4  
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
    
    def forward(self, x):
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x
    
class ResNet(nn.Module):
    def __init__(self, block, layers, image_channels, num_classes):
        super(ResNet, self).__init__()
        # 初始化的层
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet layers
        self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
        self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)
        self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)
        self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*4, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

	# 核心函数:调用block基础残差块,构造ResNet的每一层
    def _make_layer(self, block, num_residual_blocks, out_channels, stride):
        identity_downsample = None
        layers = []

        if stride != 1 or self.in_channels != out_channels * 4:
            identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,
                                                          stride=stride),                                               
                                                nn.BatchNorm2d(out_channels*4))
            
        layers.append(block(self.in_channels, out_channels, identity_downsample, stride))
        self.in_channels = out_channels * 4

        for i in range(num_residual_blocks - 1):
            layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) again

        return nn.Sequential(*layers)
 
 # 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):
    return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)

 # 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):
    return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)

 # 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):
    return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)

# 测试输出y的形状是否满足1000类
def test():
    net = resnet152()
    x = torch.randn(2, 3, 224, 224)
    y = net(x)
    print(y.shape) # [2, 1000]

test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/882204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用rollup打包 第三方库 @sentry/browser 为 umd 格式

背景 最近 老项目 多入口打包 引入sentry监控 ,由于不想 在各个入口 都去加sentry 相关逻辑,最后 在 统一的模版文件 html 中 通过 script 标签 引入sentry ,并 初始化操作。 要想保证 script 标签 引入 sentry文件能使用。 需要 保证sentry…

学习ts(一)数据类型(基础类型和任意类型)

运行 起步安装 npm install typescript -g 运行tsc index.ts生成对应的js文件,然后使用node index.js执行js文件 为了方便运行还可以安装插件,ts-node index.ts运行即可 npm i ts-node -g npm init -y npm i types/node -D基本数据类型 // 1.字符…

让光存在,探索光耦继电器的魔力

光耦合器继电器是电路中的无名英雄,正在改变我们实现电气安全和控制的方式。这些卓越的设备(也称为光电耦合器继电器)由于其在电气隔离电路上传输信号和功率的独特能力而在各个行业中广受欢迎。今天,我们深入探讨光耦合器继电器背…

大厂面试篇--2023软件测试八股文最全文档,有它直接大杀四方

前言 已经到了金九银十的黄金招聘季节了,还在准备面试跳槽涨薪的小伙伴们可以看看本篇文章哟,这里呢笔者就不多说废话了直接上干货!答案已整理好,文末拿去即可!非常好用! 一、字节跳动测试面经篇 1、在搜…

【管理运筹学】第 5 章 | 整数规划 (1,问题提出与分支定界法)

文章目录 引言一、整数规划问题的提出1.1 整数规划的数学模型1.2 整数规划问题的求解 二、分支定界法2.1 分支与定界2.2 基本求解步骤(一)初始化(二)分支与分支树(三)定界与剪枝(四)…

正中优配:2023新股上市涨跌幅规则?新股上市涨跌幅限制为几天?

A股与美股不同,股票存在涨跌幅限制,那么,2023新股上市涨跌幅规矩?新股上市涨跌幅限制为几天?下面正中优配为我们预备了相关内容,以供参阅。 2023年新股上市涨跌幅存在以下规矩: 1、主板初次公开…

mock.js引发的报错Corrupted zip: missing xxx bytes

背景: 之前项目没引入mock.js,出于产品要宣传售卖该项目,后端那套服务需要真实场景,和产品经理商量下前端出个假数据的页面,所以复制几个页面mock数据用于产品宣传 首先了解下mock.js Mock 是一种用于模拟数据和行为的…

【QT】 QFileQFileInfo文件操作

很高兴在雪易的CSDN遇见你 ,给你糖糖 欢迎大家加入雪易社区-CSDN社区云 前言 本文分享QT对文件的操作技术,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞关注,小易会继续努力分享,一起进步! 你的点…

恢复已被删除,但是能然有进程调用的文件

当Linux计算机受到入侵时,常见的情况是日志文件被删除,以掩盖攻击者的踪迹。管理错误也可能导致意外删除重要的文件,比如在清理旧日志时,意外地删除了数据库的活动事务日志。有时可以通过lsof来恢复这些文件。 当进程打开了某个文…

【枚举边+MST+组合计数】CF1857G

Problem - 1857G - Codeforces 题意: 思路: 首先观察一下样例: 可以发现对于每一对点,贡献是 s - 这对点对应的环的最大边 1 那么这样就有了 n^2 的做法 然后,根据惯用套路,枚举树上的点对问题可以转…

深入解析美颜SDK:算法、效果与实现

在当今数字化社会中,图像处理和美化技术已经成为了许多应用领域的重要组成部分,尤其在视频直播领域,美颜技术更是无处不在。直播美颜SDK作为一种集成的软件工具包,为开发者和应用提供了强大的美颜功能。 一、算法原理 磨皮算法…

理解RDMA SGL

1. 前言 在使用RDMA操作之前,我们需要了解一些RDMA API中的一些需要的值。其中在ibv_send_wr我们需要一个sg_list的数组,sg_list是用来存放ibv_sge元素,那么什么是SGL以及什么是sge呢?对于一个使用RDMA进行开发的程序员来说&…

Python教程(9)——Python变量类型列表list的用法介绍

列表操作 创建列表访问列表更改列表元素增加列表元素修改列表元素删除列表元素 删除列表 在Python中,列表(list)是一种有序、可变的数据结构,用于存储多个元素。列表可以包含不同类型的元素,包括整数、浮点数、字符串等…

服务器如何防止cc攻击

对于搭载网站运行的服务器来说,cc攻击应该并不陌生,特别是cc攻击的攻击门槛非常低,有个代理IP工具,有个cc攻击软件就可以轻易对任何网站发起攻击,那么服务器如何防止cc攻击?请看下面的介绍。 服务器如何防止cc攻击&a…

只需要自动售货机,商业模式立马大变样!

随着互联网、大数据和人工智能的蓬勃发展,商业模式正以前所未有的方式融合,其中自动售货机作为新零售模式的一颗璀璨明珠,正引领着购物体验的革命。这个巧妙的结合将消费者的便利、数据的智能分析以及科技的创新融为一体,重新定义…

【日常积累】HTTP和HTTPS的区别

背景 在运维面试中,经常会遇到面试官提问http和https的区别,今天咱们先来简单了解一下。 超文本传输协议HTTP被用于在Web浏览器和网站服务器之间传递信息,HTTP协议以明文方式发送内容,不提供任何方式的数据加密,如果…

你的汽车充电桩控制板可能比你的智能手机还要智能?

你是否想过,你的汽车充电桩控制板可能比你的智能手机还要智能?今天我们就来聊聊这个话题。 汽车充电桩控制板的智能性让充电过程更加高效、安全。首先,它具备自检功能,就像你的手机一样,不仅能检查出设备的工作状态,还…

[杂症]PLSQL很卡

PLSQL症状如下: 1、PLSQL隔段时间没用再执行语句时会很卡2、就是很卡,干啥都卡 目前网上的方法汇总如下: 1、Tools >> Preferesnces >> Oracle >> Connection 打开自动连接 勾选检查连接、勾选所有会话,设置成…

正中优配:牛市旗手“又行了”

8月15日早盘,A股首要指数呈弱势盘整态势,截至记者发稿时,沪指小幅翻红,深证成指、创业板指依然飘绿。 中拉升;周一活泼的酒店、旅游板块则震荡调整;房地产板块盘中震荡,体现较弱。 “牛市旗手”…

软件测试简历撰写与优化,让你面试邀约率暴增99%!

如何撰写一份优秀的简历呢??这是一个求职者都会遇到的问题,今天就来详细带大家写一份软件测试工程师职位的简历!希望能给各位软件测试求职者一个带来帮助! 个人简历是求职者给招聘单位发的一份简要介绍。包含自己的基本…