WRN: 宽度残差网络(论文复现)

news2024/10/12 9:34:01

WRN: 宽度残差网络(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • WRN: 宽度残差网络(论文复现)
        • 概述
        • 模型结构
        • 核心逻辑
        • 实验
        • 训练与测试
        • 在线部署
        • 使用方式

概述

本文复现论文 Wide Residual Networks提出的深度神经网络模型。

为了解决深度神经网络梯度消失的问题,深度残差网络(Residual Network[2])被提出。然而,仅为了提高千分之一的准确率,也要将网络的层数翻倍,这使得网络的训练变得非常缓慢。为了解决这些问题,该论文对ResNet基本块的架构进行了改进并提出了一种新颖的架构——宽度残差网络(Wide Residual Network),其减少了深度并增加了残差网络的宽度。

我基于Pytorch复现了该网络并在CIFAR-10[3]、CIFAR-100[3]和SVHN[4]数据集上进行试验。此外,我提供了一个基于SVHN数据集训练的数字识别系统用于体验

模型结构

宽度残差网络共包含四组结构。其中,第一组固定为一个卷积神经网络,第二、三、四组都包含 n 个基本残差块。

基本残差块的结构如图所示

在这里插入图片描述

与普通的残差块不同的地方在于,普通残差块中的批归一化层和激活层都放在卷积层之后,而该论文将批归一化层和激活层都放在卷积层之前,该做法一方面加快了计算,另一方面使得该网络可以不需要用于特征池化的瓶颈层。此外,宽度残差网络成倍地增加了普通残差网络的特征通道数。

宽度残差网络在第三、四组的第一个卷积层进行下采样,即设置卷积步长为2

核心逻辑

Wide Residual Network 的模型代码如下所示

import torch
import torch.nn as nn
import torch.nn.functional as F


class WideBasicBlock(nn.Module):
    """Wide Residual Network的基本单元"""
    def __init__(self, in_channels, out_channels, stride, dropout):
        super(WideBasicBlock, self).__init__()
        self.stride = stride
        # 批归一化层、激活层、卷积层、Dropout层
        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        ) 
    
    def forward(self, x):
        out = self.layers(x)
        if self.stride != 1:
            residual = F.adaptive_avg_pool2d(x, (out.size(2), out.size(3)))
        else:
            residual = x
        if out.size(1) != residual.size(1):
            # 对池化和升维的特殊处理
            if out.size(1) % residual.size(1) == 0:
                residual = residual.repeat(1, out.size(1) // residual.size(1), 1, 1)
            else:
                padding = torch.zeros(residual.size(0), out.size(1) - residual.size(1), residual.size(2), residual.size(3)).to(residual.device)
                residual = torch.cat((residual, padding), dim=1)
        out = out + residual
        return out
        
        
    
class WideResidualNetwork(nn.Module):
    """Wide Residual Network"""
    def __init__(self, in_channels, out_channels, depth, width, dropout=0):
        super(WideResidualNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = self.add_block(
            in_channels = 16,
            out_channels = 16 * width,
            depth = depth,
            stride = 1,
            dropout = dropout
        )
        self.conv3 = self.add_block(
            in_channels = 16 * width,
            out_channels = 32 * width,
            depth = depth,
            stride = 2,
            dropout = dropout
        )
        self.conv4 = self.add_block(
            in_channels = 32 * width,
            out_channels = 64 * width,
            depth = depth,
            stride = 2,
            dropout = dropout
        )
        self.linear = nn.Linear(64 * width, out_channels)
        
    def add_block(self, in_channels, out_channels, depth, stride, dropout):
        """添加一个基本单元的组合"""
        layers = nn.Sequential()
        layers.add_module(
            name = '0',
            module = WideBasicBlock(
                in_channels = in_channels, 
                out_channels = out_channels, 
                stride = stride,
                dropout = dropout
            )
        )
        for i in range(1, depth):
            layers.add_module(
                name = str(i),
                module = WideBasicBlock(
                    in_channels = out_channels, 
                    out_channels = out_channels, 
                    stride = 1,
                    dropout = dropout
                )
            )
        return layers
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out
实验
训练与测试

所有实验基于WRN-37-2进行且使用SGD进行优化。对于CIFAR-10和CIFAR-100,学习率为0.01并在第60、120、160轮衰减到20%,dropout采用0.3,weight_decay和momentum分别为0.0005和0.9。对于SVHN,学习率为0.01并在第80、120轮衰减到10%,dropout为0,weight_decay和momentum分别为0.0005和0.9。三个数据集的batch size均为128。

此外,CIFAR-10和CIFAR-100使用了数据增强操作,具体为随机水平翻转和随机裁剪。

具体的实验结果如下表所示

数据集准确率
CIFAR-1094.16%
CIFAR-10074.12%
SVHN96.95%
在线部署

我从网络上随机截取了10张大小、颜色、形状、背景各异的数字图像。这些图片的来源包括:车牌(6、8、9)、扑克牌(3)、广告(1、2、4、5、7)、腰带卡扣(0)。测试结果显示正确率为100%

在这里插入图片描述

使用方式

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令

unzip Wide-Residual-Networks.zip
cd Wide-Residual-Networks

代码的运行环境可通过如下命令进行配置

pip install -r requirements.txt

如果希望在本地训练模型,请运行如下命令

python main.py -d ['CIFAR-10''CIFAR-100''SVHN'三者其中之一]

如果希望在线部署,请运行如下命令

python main-flask.py

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2207632.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件狗加密的高安全性

软件狗加密,即使用软件加密狗对软件进行加密保护的过程,是一种软硬件结合的加密方式。以下是对软件狗加密的详细解析: 一、软件加密狗的基本概念 软件加密狗,也称为硬件加密锁或USB密钥,是一种用于保护软件和数据安全的…

IEC104规约的秘密之十----令人眼花缭乱的各种限定词,品质描述词

当我们已经能用104通讯完成各种通讯也能解决帧序号等各种问题后,我们就更加关心报文的细节。 各种报文中的限定词就可以进行仔细分析了。 下面以单点遥信做为例子进行分析: SIQ是英文Single-point information with quality descriptor的缩写&#xff0…

HTML+CSS排行榜实现代码,复制粘贴可使用

如何用HTML和CSS创建一个具有吸引力的创作者排行榜 在数字化时代,排行榜是吸引用户注意的绝佳方式。无论是展示最受欢迎的产品、文章还是创作者,一个设计精良的排行榜都能提升用户的参与度和兴趣。本文将指导你如何使用HTML和CSS创建一个具有吸引力的创…

Rider + xmake DX12 开发环境

Rider xmake DX12 开发环境 背景 如题,想要接近 UE 的开发流程 正文 大的流程就是 xmake 生成 vs 的 sln,用 Rider 进行开发 intellisense,断点调试 加了个脚本手动刷新 sln xmake project -k vsxmake -m "debug;release" -…

msvcr100.dll丢失的解决方法,如何安全下载 msvcr100.dll 文件:完全指南

在使用 Windows 操作系统的电脑上运行某些程序或游戏时,可能会遇到一个常见的错误消息,提示缺少 msvcr100.dll 文件。这个 DLL 文件是 Microsoft Visual C 2010 Redistributable Package 的一部分,对于运行依赖于 C 的软件来说至关重要。如果…

Linux等保测评与加固

Linux三级系统测评及加固方法 身份鉴别 应对登录的用户进行身份标识和鉴别,身份标识具有唯一性,身份鉴别信息具有复杂度要求并定期更换 测评方法: ①一般采用用户名口令进行身份鉴别,身份标识具有唯一性无法创建相同用户名 通…

WPF 手撸插件 八 操作数据库一

1、本文将使用SqlSugar创建Sqlite数据库,进行入门的增删改查等操作。擦,咋写着写着凌乱起来了。 SqlSugar官方文档:简单示例,1分钟入门 - SqlSugar 5x - .NET果糖网 2、环境SqlSugar V5.0版本需要.Net Framework 4.6 &#xff0…

MySQL 创建子账号

1. 使用 root 账号登录 MySQL 使用 root 账号登录 MySQL,登录成功如图所示: 新建一个 MySQL 子账号,新建子账号命令如下: 命令 : CREATE USER testlocalhost IDENTIFIED BY 123456;若出现如下图所示,则表示新建 MySQL…

技术总结(三)

Checked Exception 和 Unchecked Exception 有什么区别? Checked Exception 即 受检查异常 ,Java 代码在编译过程中,如果受检查异常没有被 catch或者throws 关键字处理的话,就没办法通过编译。 比如下面这段 IO 操作的代码&…

设计模式---责任链模式快速demo

Handler(处理者): 定义一个处理请求的接口。通常包括一个处理请求的方法。它可以是抽象类或接口,也可以是具体类,具体类中包含了对请求的处理逻辑。处理者通常包含一个指向下一个处理者的引用。ConcreteHandler&#x…

esp32-c3 Supermini 模块下载

1.此模块自带usb 功能,可以直接利用数据线连接模块与pc进行下载。此模块不带uart to usb 集成块。 2. 此模块下载只能用自带type c 数据口与pc usb 连接进行。不能用usb 转 uart 对模块下载,但可以通讯 3. 利用idf.py 对模块下载前,必…

Thread类的用法练习

目录 1.继承 Thread, 重写 run 2.实现 Runnable, 重写 run 3.继承 Thread, 重写 run, 使用匿名内部类 4.实现 Runnable, 重写 run, 使用匿名内部类 5.使用 lambda 表达式 6.请回答以下代码的输出, 并解释原因 1.继承 Thread, 重写 run 2.实现 Runnable, 重写 run 3.继承 Th…

四、远程登录到Linux服务器

说明 linux 服务器是开发小组共享,正式上线的项目是运行在公网,因此需要远程登录到 Linux 进行项目管理或者开发 Xshell 1、特点 Xshell 是目前最好的远程登录到 Linux 操作的软件,流畅的速度并且完美解决了中文乱码的问题, 是目…

计算机毕业设计 基于Python+Django的旅游景点数据分析与推荐系统的设计与实现 Python毕业设计 Python毕业设计选题【附源码+安装调试】

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

黑马程序员C++提高编程学习笔记

黑马程序员C提高编程 提高阶段主要针对泛型编程和STL技术 文章目录 黑马程序员C提高编程一、模板1.1 函数模板1.1.1 函数模板基础知识 案例一: 数组排序1.2.1 普通函数与函数模板1.2.2 函数模板的局限性 1.2 类模板1.2.1 类模板的基础知识1.2.2 类模板与函数模板1.…

ssm基于SSM框架的餐馆点餐系统的设计+VUE

系统包含:源码论文 所用技术:SpringBootVueSSMMybatisMysql 免费提供给大家参考或者学习,获取源码请私聊我 需要定制请私聊 目 录 摘要 I Abstract II 1绪论 1 1.1研究背景与意义 1 1.1.1研究背景 1 1.1.2研究意义 1 1.2国内外研究…

有什么方法可以保护ppt文件不被随意修改呢?

在工作或学习中,我们常常需要制作powerpoint演示文稿,担心自己不小心改动了或者不想他人随意更改,我们可以如何保护PPT呢?下面小编就来分享两个常用的方法。 方法一:为PPT设置打开密码 为PPT设置打开密码是最直接有效…

Prim算法实现最小生成树

Prim算法是一种用来寻找图的最小生成树的贪心算法。最小生成树是连接图中所有顶点的边的子集,这些边的权重总和最小,且形成一个树形结构,包含所有顶点。 Prim算法的基本步骤如下: 初始化: 选择任意一个顶点作为起始点…

全栈开发要掌握的技术

文章目录 1、前端开发2、后台开发2.1 编程语言2.2 网络框架 3、数据库开发3.1 RDBMS3.2 NoSQL 数据库 4、移动开发4.1 本地开发4.2 跨平台开发 5、云计算5.1 云平台5.2 容器化和协调 6、用户界面/用户体验设计6.1 设计工具6.2 原型和线框图 7、基础设施和 DevOps7.1 基础设施即…

C语言读取data.json文件并存入MySQL数据库小案例

本地有一个data.json文件 data.json [{"id": 1,"name": "Alice","age": 30},{"id": 2,"name": "Bob","age": 25} ]要将 data.json 文件中的数据存储到 MySQL 数据库中,首先需要…