分类神经网络2:ResNet模型复现

news2024/12/23 23:12:00

目录

ResNet网络架构

ResNet部分实现代码


ResNet网络架构

论文原址:https://arxiv.org/pdf/1512.03385.pdf

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的,通过引入残差学习解决了深度网络训练中的退化问题,同时该网络结构特点主要表现为3点,超深的网络结构(超过1000层)、提出residual(残差结构)模块和使用Batch Normalization 加速训练(丢弃dropout)。ResNet网络的模型结构如下:

ResNet网络通过“跳跃连接”,将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。即神经网络学习到该层的参数是冗余的时候,它可以选择直接走这条“跳接”曲线(shortcut connection),跳过这个冗余层,而不需要再去拟合参数使得H(x)=F(x)=x。同时通过这种连接方式不仅保护了信息的完整性(避免卷积层堆叠存在的信息丢失),整个网络也只需要学习输入、输出差别的部分,这克服了由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题(梯度消失或梯度爆炸)。

残差模块如下图示:

残差结构有两种,常规残差和瓶颈残差常规残差:由2个3x3卷积层堆叠而成,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度(随着网络深度的加深,这种残差模块在实践中并不十分有效);瓶颈残差:依次由1x1 、3x3 、1x1个卷积层构成,这里1x1卷积,能够对通道数channel起到升维或降维的作用,从而令3x3 的卷积,以相对较低维度的输入进行卷积运算,提高计算效率。

ResNet网络的具体配置信息如下:

在构建神经网络时,首先采用了步长为2的卷积层进行图像尺寸缩减,即下采样操作,紧接着是多个残差结构,在网络架构的末端,引入了一个全局平均池化层,用于整合特征信息,最后是一个包含1000个类别的全连接层,并在该层后应用了softmax激活函数以进行多分类任务。值得注意的是,通过引入残差连接模块,其最深的网络结构达到了152层,同时在50层后均使用的是瓶颈残差结构。

ResNet部分实现代码

老样子,直接上代码,建议大家阅读代码时结合网络结构理解

import torch
import torch.nn as nn

__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class ConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding:捕捉局部特征和空间相关性,学习更复杂的特征和抽象表示"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_channels, out_channels, stride=1):
    """1x1 convolution:实现降维或升维,调整通道数和执行通道间的线性变换"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.convbnrelu1 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride)
        self.convbn1 = ConvBN(out_channels, out_channels, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.conv_down = nn.Sequential(
            conv1x1(in_channels, out_channels * self.expansion, self.stride),
            nn.BatchNorm2d(out_channels * self.expansion),
        )

    def forward(self, x):
        residual = x
        out = self.convbnrelu1(x)
        out = self.convbn1(out)

        if self.downsample:
            residual = self.conv_down(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        groups = 1
        base_width = 64
        dilation = 1

        width = int(out_channels * (base_width / 64.)) * groups   # wide = out_channels
        self.conv1 = conv1x1(in_channels, width)       # 降维通道数
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, out_channels * self.expansion)   # 升维通道数
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.conv_down = nn.Sequential(
            conv1x1(in_channels, out_channels * self.expansion, self.stride),
            nn.BatchNorm2d(out_channels * self.expansion),
        )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            residual = self.conv_down(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64):
        super(ResNet, self).__init__()

        self.inplanes = 64
        self.dilation = 1
        replace_stride_with_dilation = [False, False, False]
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = False
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = True

        layers = nn.ModuleList()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return layers

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.layer1:
            x = layer(x)
        for layer in self.layer2:
            x = layer(x)
        for layer in self.layer3:
            x = layer(x)
        for layer in self.layer4:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

def resnet18(num_classes, **kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

def resnet34(num_classes, **kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet50(num_classes, **kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet101(num_classes, **kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)

def resnet152(num_classes, **kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)


if __name__=="__main__":
    import torchsummary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.ones(2, 3, 224, 224).to(device)
    net = resnet50(num_classes=4)
    net = net.to(device)
    out = net(input)
    print(out)
    print(out.shape)
    torchsummary.summary(net, input_size=(3, 224, 224))
    # Total params: 134,285,380

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1613196.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Spring Cloud] (4)搭建Vue2与网关、微服务通信并配置跨域

文章目录 前言gatway网关跨域配置取消微服务跨域配置 创建vue2项目准备一个原始vue2项目安装vue-router创建路由vue.config.js配置修改App.vue修改 添加接口访问安装axios创建request.js创建index.js创建InfoApi.js main.jssecurityUtils.js 前端登录界面登录消息提示框 最终效…

【八股】Spring Boot

SpringBoot是如何实现自动装配的? 首先,SpringBoot的核心注解SpringBootApplication里面包含了三个注解,SpringBootConfigurationEnableAutoConfigurationComponentScan,其中EnableAutoConfiguration是实现自动装配的注解&#x…

VUE运行找不到pinia模块

当我们的VUE运行时报错Module not found: Error: Cant resolve pinia in时 当我们出现这个错误时 可能是 没有pinia模块 此时我们之要下载一下这个模块就可以了 npm install pinia

AD高速板设计-DDR(笔记)

【一】二极管 最高工作频率: 定义:二极管的最高工作频率,即二极管在电路中能够正常工作的最高频率。常见的硅二极管的最高工作频率通常在几十MHz到几百MHz之间。在高频下,二极管可能无法有效地阻止反向电流,但也不会…

C# WPF布局

布局&#xff1a; 1、Grid: <Window x:Class"WpfApp2.MainWindow" xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation" xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml" xmlns:d"http://schemas.microsoft.com…

【大语言模型LLM】-使用大语言模型搭建点餐机器人

关于作者 行业&#xff1a;人工智能训练师/LLM 学者/LLM微调乙方PM发展&#xff1a;大模型微调/增强检索RAG分享国内大模型前沿动态&#xff0c;共同成长&#xff0c;欢迎关注交流… 大语言模型LLM基础-系列文章 【大语言模型LLM】-大语言模型如何编写Prompt?【大语言模型LL…

地质图、地质岩性数据、地质灾害分布、土壤理化性质数据集、土地利用数据、土壤重金属含量分布、植被类型分布

地质图是将沉积岩层、火成岩体、地质构造等的形成时代和相关等各种地质体、地质现象&#xff0c;用一定图例表示在某种比例尺地形图上的一种图件。 是表示地壳表层岩相、岩性、地层年代、地质构造、岩浆活动、矿产分布等的地图的总称。 地质图的编制多以实测资料为基础&#xf…

Eclipse+Java+Swing实现学生信息管理系统-TXT存储信息

一、系统介绍 1.开发环境 操作系统&#xff1a;Win10 开发工具 &#xff1a;Eclipse2021 JDK版本&#xff1a;jdk1.8 存储方式&#xff1a;Txt文件存储 2.技术选型 JavaSwingTxt 3.功能模块 4.工程结构 5.系统功能 1.系统登录 管理员可以登录系统。 2.教师-查看学生…

初学者如何选择ARM开发硬件?

1&#xff0e; 如果你有做硬件和单片机的经验,建议自己做个最小系统板&#xff1a;假如你从没有做过ARM的开发&#xff0c;建议你一开始不要贪大求全&#xff0c;把所有的应用都做好&#xff0c;因为ARM的启动方式和dsp或单片机有所不同&#xff0c;往往会碰到各种问题&#xf…

【天龙怀旧服】攻略day7

关键字&#xff1a; 新星1.49、金针渡劫、10灵 1】新星&#xff08;苍山破煞&#xff09; 周三周六限定副本&#xff0c;19.00-24.00 通常刷1.49w&#xff0c;刷149点元佑碎金 boss选择通常为狂鬼难度&#xff0c;八风不动即放大不选&#xff0c;第二排第一个也不选&#xf…

【Hadoop】- MapReduce YARN的部署[8]

目录 一、部署说明 二、集群规划 三、MapReduce配置文件 四、YARN配置文件 五、分发配置文件 六、集群启动命令 七、查看YARN的WEB UI 页面 一、部署说明 Hadoop HDFS分布式文件系统&#xff0c;我们会启动&#xff1a; NameNode进程作为管理节点DataNode进程作为工作节…

lua整合redis

文章目录 lua基础只适合lua连接操作redis1.下载lua依赖2.导包,连接3.常用的命令1.set,get,push命令 2.自增管道命令命令集合4.使用redis操作lua1.实现秒杀功能synchronized关键字 分布式锁 lua 基础只适合 1.编译 -- 编译 luac a.lua -- 运行 lua a.lua2.命名规范 -- 多行注…

【Hadoop】- MapReduce YARN 初体验[9]

目录 提交MapReduce程序至YARN运行 1、提交wordcount示例程序 1.1、先准备words.txt文件上传到hdfs&#xff0c;文件内容如下&#xff1a; 1.2、在hdfs中创建两个文件夹&#xff0c;分别为/input、/output 1.3、将创建好的words.txt文件上传到hdfs中/input 1.4、提交MapR…

Dynamic Wallpaper for Mac激活版:视频动态壁纸软件

Dynamic Wallpaper for Mac 是一款为Mac电脑量身打造的视频动态壁纸应用&#xff0c;为您的桌面带来无限生机和创意。这款应用提供了丰富多样的视频壁纸选择&#xff0c;涵盖了自然风景、抽象艺术、科幻奇观等多种主题&#xff0c;让您的桌面成为一幅活生生的艺术画作。 Dynami…

ES中文检索须知:分词器与中文分词器

ElasticSearch (es)的核心功能即为数据检索&#xff0c;常被用来构建内部搜索引擎或者实现大规模数据在推荐召回流程中的粗排过程。 ES分词 分词即为将doc通过Analyzer切分成一个一个Term&#xff08;关键字&#xff09;&#xff0c;es分词在索引构建和数据检索时均有体现&…

(避雷指引:管理页面超时问题)windows下载安装RabbitMQ

一、背景&#xff1a; 学习RabbitMQ过程中&#xff0c;由于个人电脑性能问题&#xff0c;直接装在windows去使用RabbitMQ&#xff0c;根据各大网友教程&#xff0c;去下载安装完之后&#xff0c;使用web端进行简单的入门操作时&#xff0c;总是一直提示超时&#xff0c;要么容…

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现高并发服务器(TcpServer板块)

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现⾼并发服务器&#xff08;TcpServer板块&#xff09; 一、思路图二、模式关系图三、定时器的设计1、Linux本身给我们的定时器2、我们自己实现的定时器&#xff08;1&#xff09;代码部分&#xff08;2&#xff09;思…

图论——基础概念

文章目录 学习引言什么是图图的一些定义和概念图的存储方式二维数组邻接矩阵存储优缺点 数组模拟邻接表存储优缺点 边集数组优缺点排序前向星优缺点链式前向星优缺点 学习引言 图论&#xff0c;是 C 里面很重要的一种算法&#xff0c;今天&#xff0c;就让我们一起来了解一下图…

使用docker搭建GitLab个人开发项目私服

一、安装docker 1.更新系统 dnf update # 最后出现这个标识就说明更新系统成功 Complete!2.添加docker源 dnf config-manager --add-repohttps://download.docker.com/linux/centos/docker-ce.repo # 最后出现这个标识就说明添加成功 Adding repo from: https://download.…

【数据结构】顺序表:与时俱进的结构解析与创新应用

欢迎来到白刘的领域 Miracle_86.-CSDN博客 系列专栏 数据结构与算法 先赞后看&#xff0c;已成习惯 创作不易&#xff0c;多多支持&#xff01; 目录 一、数据结构的概念 二、顺序表&#xff08;Sequence List&#xff09; 2.1 线性表的概念以及结构 2.2 顺序表分类 …