【模型复现】resnet,使用net.add_module()的方法构建模型。小小的改进大大的影响,何大神思路很奇妙,基础很扎实

news2025/2/25 16:03:22
  • 从经验来看,网络的深度对模型的性能至关重要,当增加网络层数后,网络可以进行更加复杂的特征模式的提取,所以当模型更深时理论上可以取得更好的结果。但是更深的网络其性能一定会更好吗?实验发现深度网络出现了退化问题(Degradation problem):网络深度增加时,网络准确度出现饱和,甚至出现下降。在 [1512.03385] Deep Residual Learning for Image Recognition (arxiv.org) 中表明56层的网络比20层网络效果还要差。这不会是过拟合问题,因为56层网络的训练误差同样高。我们知道深层网络存在着梯度消失或者爆炸的问题,这使得深度学习模型很难训练。但是现在已经存在一些技术手段如BatchNorm来缓解这个问题。因此,出现深度网络的退化问题是非常令人诧异的。【AI Talking】CVPR2016 最佳论文, ResNet 现场演讲 - 知乎 (zhihu.com)

  • 深度网络的退化问题【读点论文】Deep Residual Learning for Image Recognition 训练更深的网络_羞儿的博客-CSDN博客至少说明深度网络不容易训练。但是我们考虑这样一个事实:现在你有一个浅层网络,你想通过向上堆积新层来建立深层网络,一个极端情况是这些增加的层什么也不学习,仅仅复制浅层网络的特征,即这样新层是恒等映射(Identity mapping)。在这种情况下,深层网络应该至少和浅层网络性能一样,也不应该出现退化现象。好吧,你不得不承认肯定是目前的训练方法有问题,才使得深层网络很难去找到一个好的参数。后面何凯明提出了残差的结构融合到模型设计中。

  • 残差即观测值与估计值之间的差。假设我们要建立深层网络,当我们不断堆积新的层,但增加的层什么也不学习(极端情况),那么我们就仅仅复制浅层网络的特征,即新层是浅层的恒等映射(Identity mapping),这样深层网络的性能应该至少和浅层网络一样,那么退化问题就得到了解决。

  • 假设要求解的映射为 H(x),也就是观测值,假设上一层 resnet/上一个残差块输出的特征映射为 x(identity function跳跃连接),也就是估计值。那么我们就可以把问题转化为求解网络的残差映射函数 F(x) = H(x) - x。如果网络很深,出现了退化问题,那么我们就只需要让我们的残差映射F(x)等于 0,即要求解的映射 H(x)就等于上一层输出的特征映射 x,因为x是当前输出的最佳解,这样我们这一层/残差块的网络状态仍是最佳的一个状态。但是上面提到的是理想假设,实际情况下残差F(x)不会为0,x肯定是很难达到最优的,但是总会有那么一个时刻它能够无限接近最优解。采用ResNet的话,就只用小小的更新F(x)部分的权重值就行了,可以更好地学到新的特征。

  • ResNet网络是参考了VGG19网络,在其基础上进行了修改,并通过短路机制加入了残差单元,如下图所示。变化主要体现在ResNet直接使用stride=2的卷积做下采样,并且用global average pool层替换了全连接层。ResNet的一个重要设计原则是:当feature map大小降低一半时,feature map的数量增加一倍,这保持了网络层的复杂度。ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中虚线表示feature map数量发生了改变。下图展示的34-layer的ResNet,还可以构建更深的网络如下表所示。从表中可以看到,对于18-layer和34-layer的ResNet,其进行的两层间的残差学习,当网络更深时,其进行的是三层间的残差学习,三层卷积核分别是1x1,3x3和1x1,一个值得注意的是隐含层的feature map数量是比较小的,并且是输出feature map数量的1/4。

    • 在这里插入图片描述

    • 在这里插入图片描述

  • ResNet使用两种残差单元,如下图所示。左图对应的是浅层网络,而右图对应的是深层网络。对于短路连接,当输入和输出维度一致时,可以直接将输入加到输出上。但是当维度不一致时(对应的是维度增加一倍),这就不能直接相加。有两种策略:(1)采用zero-padding增加维度,此时一般要先做一个downsamp,可以采用strde=2的pooling,这样不会增加参数;(2)采用新的映射(projection shortcut),一般采用1x1的卷积,这样会增加参数,也会增加计算量。短路连接除了直接使用恒等映射,当然都可以采用projection shortcut。你必须要知道CNN模型:ResNet - 知乎 (zhihu.com)

    • 在这里插入图片描述

    • ResNet block有两种,一种两层结构,一种是三层的bottleneck结构,即将两个3x3的卷积层替换为1x1 + 3x3 + 1x1,它通过1x1 conv来巧妙地缩减或扩张feature map维度,从而使得我们的3x3 conv的filters数目不受上一层输入的影响,它的输出也不会影响到下一层。中间3x3的卷积层首先在一个降维1x1卷积层下减少了计算,然后在另一个1x1的卷积层下做了还原。既保持了模型精度又减少了网络参数和计算量,节省了计算时间。

  • 作者对比18-layer和34-layer的网络效果,如下图所示。可以看到普通的网络出现退化现象,但是ResNet很好的解决了退化问题。

    • 在这里插入图片描述
  • 何凯明作者在他的另外一个工作中又对不同的残差单元做了细致的分析与实验,这里直接抛出最优的残差结构,如下图所示。改进前后一个明显的变化是采用pre-activation,BN和ReLU都提前了。而且作者推荐短路连接采用恒等变换,这样保证短路连接不会有阻碍。[1603.05027] Identity Mappings in Deep Residual Networks (arxiv.org)

    • 在这里插入图片描述

    • 改进前后一个明显的变化是采用pre-activation,BN和ReLU都提前了。而且作者推荐短路连接采用恒等变换,这样保证短路连接不会有阻碍。

使用pytorch复现resnet模型

  • 导包,查看设备信息

  • import time
    import torch
    from torch import nn, optim
    import torch.nn.functional as F
    import sys
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(torch.__version__)
    print(device)
    
  • 1.13.1
    cpu
    
  • 模块构建

  • class Residual(nn.Module):  
        def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
            super(Residual, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            if use_1x1conv:
                self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            else:
                self.conv3 = None  # 不做任何操作
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
        def forward(self, X):
            Y = F.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3:  # 残差部分,有的需要经过1*1的卷积,有的直接连接过来不做任何操作
                X = self.conv3(X)
            return F.relu(Y + X)
    blk = Residual(3, 3)
    X = torch.rand((4, 3, 6, 6))
    print(blk(X).shape)
    blk = Residual(3, 6, use_1x1conv=True, stride=2)
    print(blk(X).shape)
    
  • torch.Size([4, 3, 6, 6])
    torch.Size([4, 6, 3, 3])
    
  • 模型搭建

  • net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), # stem
            nn.BatchNorm2d(64), 
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    class GlobalAvgPool2d(nn.Module):
        # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
        def __init__(self):
            super(GlobalAvgPool2d, self).__init__()
        def forward(self, x):
            return F.avg_pool2d(x, kernel_size=x.size()[2:])
    class FlattenLayer(torch.nn.Module):
        def __init__(self):
            super(FlattenLayer, self).__init__()
        def forward(self, x): # x shape: (batch, *, *, ...)
            return x.view(x.shape[0], -1)
    def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
        if first_block:
            assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
            else:
                blk.append(Residual(out_channels, out_channels))
        return nn.Sequential(*blk)
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True)) # 构建网络的另一种方式
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(512, 10))) 
    X = torch.rand((1, 1, 224, 224))
    for name, layer in net.named_children():
        X = layer(X)
        print(name, ' output shape:\t', X.shape)
    
  • 0  output shape:     torch.Size([1, 64, 112, 112])
    1  output shape:     torch.Size([1, 64, 112, 112])
    2  output shape:     torch.Size([1, 64, 112, 112])
    3  output shape:     torch.Size([1, 64, 56, 56])
    resnet_block1  output shape:     torch.Size([1, 64, 56, 56])
    resnet_block2  output shape:     torch.Size([1, 128, 28, 28])
    resnet_block3  output shape:     torch.Size([1, 256, 14, 14])
    resnet_block4  output shape:     torch.Size([1, 512, 7, 7])
    global_avg_pool  output shape:     torch.Size([1, 512, 1, 1])
    fc  output shape:     torch.Size([1, 10])
    
  • 数据加载及模型训练

  • import torchvision
    def evaluate_accuracy(data_iter, net, device=None):
        if device is None and isinstance(net, torch.nn.Module):
            # 如果没指定device就使用net的device
            device = list(net.parameters())[0].device 
        acc_sum, n = 0.0, 0
        with torch.no_grad():
            for X, y in data_iter:
                if isinstance(net, torch.nn.Module):
                    net.eval() # 评估模式, 这会关闭dropout
                    acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                    net.train() # 改回训练模式
                else: 
                    if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                        # 将is_training设置成False
                        acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                    else:
                        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
                n += y.shape[0]
        return acc_sum / n
    def train_mnist(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
        net = net.to(device)
        print("training on ", device)
        loss = torch.nn.CrossEntropyLoss()
        for epoch in range(num_epochs):
            train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
            for X, y in train_iter:
                X = X.to(device)
                y = y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                train_l_sum += l.cpu().item()
                train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
                n += y.shape[0]
                batch_count += 1
            test_acc = evaluate_accuracy(test_iter, net)
            print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
                  % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
    def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):
        """Download the fashion mnist dataset and then load into memory."""
        trans = []
        if resize:
            trans.append(torchvision.transforms.Resize(size=resize))
        trans.append(torchvision.transforms.ToTensor())
        transform = torchvision.transforms.Compose(trans)
        mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
        mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
        if sys.platform.startswith('win'):
            num_workers = 0  # 0表示不用额外的进程来加速读取数据
        else:
            num_workers = 4
        train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_iter, test_iter
    batch_size = 256
    # 如出现“out of memory”的报错信息,可减小batch_size或resize
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)
    lr, num_epochs = 0.001, 5
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    train_mnist(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
    
  • training on  cpu
    epoch 1, loss 0.4161, train acc 0.847, test acc 0.845, time 1232.7 sec
    epoch 2, loss 0.2490, train acc 0.908, test acc 0.904, time 1202.3 sec
    epoch 3, loss 0.2074, train acc 0.924, test acc 0.905, time 1212.4 sec
    epoch 4, loss 0.1831, train acc 0.932, test acc 0.915, time 1142.8 sec
    epoch 5, loss 0.1570, train acc 0.942, test acc 0.876, time 1127.6 sec
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/413230.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git的安装与基本使用

Git是一个分布式版本控制工具,可以快速高效地处理从小型到大型的各种项目。 1.Git的安装 官网下载地址 :https://git-scm.com/ 安装过程 选择 Git 安装位置,要求是非中文并且没有空格的目录,然后下一步。 Git 选项配置&#xf…

ChatGPT搭建语音智能助手

环境 python:3 ffmpeg:用于处理视频和语音 gradio:UI界面和读取语音 概述 我们的目的是做一个语音智能助手 下面我们开始 准备工作 下载Visual Studio Code Visual Studio Code 因为需要写python代码,用Visual Studio Code比较方便。 安装pytho…

( “树” 之 DFS) 101. 对称二叉树 ——【Leetcode每日一题】

101. 对称二叉树 给你一个二叉树的根节点 root , 检查它是否轴对称。 示例 1: 输入:root [1,2,2,3,4,4,3] 输出:true 示例 2: 输入:root [1,2,2,null,3,null,3] 输出:false 提示&#xff1a…

webgl-画任意多边形

注意: let canvas document.getElementById(webgl) canvas.width window.innerWidth canvas.height window.innerHeight let radio window.innerWidth/window.innerHeight; let ctx canvas.getContext(webgl) 由于屏幕长宽像素不一样,导致了长宽像素…

移远云服务QuecCloud正式发布,一站式为全球客户提供创新有效的解决方案

4月12日,在“万物智联共数未来”移远通信物联网生态大会上,移远通信宣布正式推出其物联网云服务——QuecCloud。QuecCloud具备智能硬件开发、物联网开放平台、行业解决方案三大能力,可为开发者和企业用户提供从硬件接入到软件应用的全流程解决…

Java 进阶(5) Java IO流

⼀、File类 概念:代表物理盘符中的⼀个⽂件或者⽂件夹。 常见方法: 方法名 描述 createNewFile() 创建⼀个新文件。 mkdir() 创建⼀个新⽬录。 delete() 删除⽂件或空⽬录。 exists() 判断File对象所对象所代表的对象是否存在。 getAbsolute…

4.2 方差

学习目标: 我认为学习方差需要以下几个步骤: 确定学习目标:在开始学习方差之前,需要明确学习的目标和意义,例如,理解方差的定义、掌握方差的计算方法、了解方差在实际问题中的应用等。 学习相关数学概念&…

宝塔Linux面板安装命令脚本大全(Centos/Ubuntu/Debian/Fedora/Deepin)

宝塔面板Linux服务器操作系统安装命令大全,包括Centos、Alibaba Cloud Linux、Ubuntu、TencentOS Server、Deepin、Debian和Fedora安装脚本,云服务器吧分享宝塔面板Linux服务器系统安装命令大全: 目录 宝塔面板Linux系统安装命令 Centos安…

【Vue】学习笔记-事件处理

事件的基本用法 使用v-on:xxx 或xxx 绑定事件,其中xxx是事件名事件的回调需要配置在methods对象中,最终会在vm上methods中配置的函数,不要用箭头函数,否则this就不是vm了methods中配置的函数,都是被vue所管理的函数。…

Pandas库:从入门到应用(三)——多表连接操作

一 、concat数据连接 1.1、concat()函数参数 pd.concat(objs, axis0, joinouter, ignore indexFalse, keysNone, levelsNone, namesNoneverify integrityFalse, sort False, copyTrue)objs:多个 DataFrame 或者 Series axis:0-行拼接 1-列拼接 join&am…

011:Mapbox GL两种方式隐藏logo和版权,个性化版权的声明

第011个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+mapbox中用两种方式隐藏logo和版权,并个性化版权的声明 。 直接复制下面的 vue+mapbox源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源代码(共91行)相关API参考:专栏目标示例效果 配置方式…

2023高性价比学生手机选购攻略,预算不多入手这3款超值

学生党在预算不多的情况,想要换颜值高的新手机,应该选什么样的手机才实惠? 手机已经成为生活中的必需品,市场上的手机品牌和型号多种多样,价格逐年攀升,对于预算有限的学生党来说,在保证性能和…

编译原理期末速成笔记

哈喽大家好,又要考试了,在这里分享一下我的两天速成笔记,参考视频为哔站 Deeplei_ 的《编译原理期末速成》。本文仅是知识点总结,至于考试内容待我研究一下,后续我会再发文对考试的各个模块做详细分析,欢迎…

JavaWeb开发 —— Ajax

目录 一、介绍 二、原生Ajax 三、Axios 四、案例分析 一、介绍 ① 概念:Asynchronous JavaScript And XML,异步的JavaScript和XML。 ② 作用: 数据交换:通过Ajax可以给服务器发送请求,并获取服务器响应的数据。…

多元函数的基本概念——“高等数学”

各位CSDN的uu们你们好呀,今天,小雅兰的内容是多元函数的基本概念,下面,让我们一起进入多元函数的世界吧 平面点集 多元函数的概念 多元函数的极限 多元函数的连续性 有界闭区域上多元连续函数的性质 平面点集 第一个是坐标平…

中间表示- 到达定义分析

基本概念 定义(def):对变量的赋值 使用(use):对变量值的读取 问题:能把上图中的y替换为3吗?如果能,这称之为“常量传播”优化。 该问题等价于,有哪些对变量y…

R730服务器热插拔换磁盘(raid阵列)

r730服务器发现磁盘闪橙等,说明磁盘报警了,这时候我们就要换磁盘了。 由于本服务器磁盘是raid5的阵列磁盘,所以要采用热插拔的方式换磁盘。 这边要注意的是,不能关机的时候,直接来换磁盘。 因为关机换磁盘&#xff0c…

golang指针相关

指针相关的部分实在是没有搞太明白,抽时间来总结下。 1.指针相关基础知识 比如现在有一句话:『谜底666』,这句话在程序中一启动,就要加载到内存中,假如内存地址0x123456,然后我们可以将这句话复制给变量A&…

什么是服务架构?微服务架构的优势又是什么?

文章目录1.1 单体架构1.2 分布式架构1.3 微服务架构1.4 单体架构和分布式架构的区分1.4 服务架构的优劣点1.4.1 单体架构1.4.2 分布式架构1.4.3 微服务架构1.5 总结1.1 单体架构 单体架构(Monolithic Architecture)是一种传统的软件架构,它将…

算法学习day56

算法学习day561.力扣583. 两个字符串的删除操作1.1 题目描述1.2分析1.3 代码2.力扣72. 编辑距离2.1 题目描述2.2 分析2.3 代码3.参考资料1.力扣583. 两个字符串的删除操作 1.1 题目描述 题目描述: 给定两个单词word1和word2,找到使得word1和word2相同…