【深度学习笔记】稠密连接网络(DenseNet)

news2024/10/5 17:20:13

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.12 稠密连接网络(DenseNet)

ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个:稠密连接网络(DenseNet) [1]。 它与ResNet的主要区别如图5.10所示。

在这里插入图片描述

图5.10 ResNet(左)与DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结

图5.10中将部分前后相邻的运算抽象为模块 A A A和模块 B B B。与ResNet的主要区别在于,DenseNet里模块 B B B的输出不是像ResNet那样和模块 A A A的输出相加,而是在通道维上连结。这样模块 A A A的输出可以直接传入模块 B B B后面的层。在这个设计里,模块 A A A直接跟模块 B B B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。

DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。

5.12.1 稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block函数里实现这个结构。

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels), 
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk

稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X

在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为 3 + 2 × 10 = 23 3+2\times 10=23 3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

5.12.2 过渡层

由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过 1 × 1 1\times1 1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。

def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
            nn.BatchNorm2d(in_channels), 
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2))
    return blk

对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

5.12.3 DenseNet模型

我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽,并减半通道数。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # 上一个稠密块的输出通道数
    num_channels = DB.out_channels
    # 在稠密块之间加入通道数减半的过渡层
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2

同ResNet一样,最后接上全局池化层和全连接层来输出。

net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

我们尝试打印每个子模块的输出维度确保网络无误:

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 48, 48])
1  output shape:	 torch.Size([1, 64, 48, 48])
2  output shape:	 torch.Size([1, 64, 48, 48])
3  output shape:	 torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:	 torch.Size([1, 224, 12, 12])
transition_block_1  output shape:	 torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:	 torch.Size([1, 240, 6, 6])
transition_block_2  output shape:	 torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:	 torch.Size([1, 248, 3, 3])
BN  output shape:	 torch.Size([1, 248, 3, 3])
relu  output shape:	 torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:	 torch.Size([1, 248, 1, 1])
fc  output shape:	 torch.Size([1, 10])

5.12.4 获取数据并训练模型

由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec
epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec
epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec
epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec
epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec

小结

  • 在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。
  • DenseNet的主要构建模块是稠密块和过渡层。

参考文献

[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1500384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WiFi模块助力少儿编程:创新学习与实践体验

随着科技的飞速发展,少儿编程已经成为培养孩子们创造力和问题解决能力的重要途径之一。在这个过程中,WiFi模块的应用为少儿编程领域注入了新的活力,使得学习编程不再是单一的代码教学,而是一个充满创新与实践的综合性体验。 物联网…

E4991A 射频阻抗/材料分析仪

新利通 E4991A 射频阻抗/材料分析仪 —— 1 MHz到 3 GHz —— 简述 E4991A 射频阻抗/材料分析提供极限阻抗测量性能和功能强大的内置分析功能。它将为元器件和电路设计人员测量 3 GHz 以内的元器件提供创新功能,帮助他们进行研发工作。与反射测量技术不同&#x…

Java对接腾讯云直播示例

首先是官网的文档地址 云直播 新手指南 可以发现它这个主要是按流量和功能收费的 价格总览 流量这里还只收下行的费用,就是只收观看消耗的流量费 其它的收费就是一些增值业务费 (包括直播转码、直播录制、直播截图、直播审核、智能鉴黄、实时监播、移动直…

软件测试面试200问(附100W字文档)

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号【互联网杂货铺】,回复 1 ,免费获取软件测试全套资料,资料在手,涨薪更快 软件测试面试题:项目 1、简单介绍下最近做过的项目…

centos7 python3.12.1 报错 No module named _ssl

https://blog.csdn.net/Amio_/article/details/126716818 安装python cd /usr/local/src wget https://www.python.org/ftp/python/3.12.1/Python-3.12.1.tgz tar -zxvf Python-3.12.1.tgz cd Python-3.12.1/ ./configure -C --enable-shared --with-openssl/usr/local/opens…

问题解决 | vscode无法连接服务器而ssh和sftp可以

解决步骤 进入家目录删除.vscode-server rm -rf .vscode-server 然后再次用vscode连接服务器时,会重新安装,这时可能报出一些缺少依赖的错 需要联系管理员安装相关依赖,比如 sudo apt-get install libstdc6 至此问题解决

Vulnhub内网渗透Jangow01靶场通关

详细请见个人博客 靶场下载地址。 下载下来后是 .vmdk 格式,vm直接导入。 M1请使用UTM进行搭建,教程见此。该靶场可能出现网络问题,解决方案见此 信息搜集 arp-scan -l # 主机发现ip为 192.168.168.15 nmap -sV -A -p- 192.168.168.15 # 端…

多功能线缆光纤验证器-AEM CV-100

TestPro CV100 多功能电缆验证器 屡获殊荣的 TestPro CV100 多功能电缆验证器专为当今的现代智能建筑网络基础设施而设计。 它提供了当今可用的功能最丰富的测试平台,以及允许定制所需的确切测试套件的基于模块化的平台。 智能建筑测试套件(K60 和 K61…

BUUCTF-Misc3

LSB1 1.打开附件 得到一张图片,像是某个大学的校徽 2.Stegsolve工具 根据标题LSB,可能是LSB隐写 放到Stegsolve中,点Analyse在点Data Extract 数据提取 因为是LSB隐写,发现含以.png结尾的图片 3.保存图片 4.得到flag 扫描二维…

蓝桥杯-最长递增

思路及代码详解:(此题为容易题) #include <iostream> using namespace std; int main() {int a[1000]{0};int n,temp;int num0;int count0;cin>>n;for(int i0;i<n;i){cin>>a[i];}//输入数据tempa[0];//设置一个临时比较的存储变量for(int i1;i<n;i){i…

性别和年龄的视频实时监测项目

注意&#xff1a;本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 &#xff08;[www.aideeplearning.cn]&#xff09; 性别和年龄检测 Python 项目 首先介绍性别和年龄检测的高级Python项目中使用的专业术语 什么是计算机视觉&#xff1f; 计算机视觉是使计算机能…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:DataPanel)

数据面板组件&#xff0c;用于将多个数据占比情况使用占比图进行展示。 说明&#xff1a; 该组件从API Version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 无 接口 DataPanel(options: DataPanelOptions) 从API version …

MySQL基础-----约束

目录 前言 一、概述 二、约束演示 三、外键约束 1.介绍 2.语法 四、删除/更新行为 1.CASCADE 2.SET NULL 前言 本期我们开始MySQL约束的学习&#xff0c;约束一般是只数据键对本条数据的约束&#xff0c;通过约束我们可以保证数据库中数据的正确、有效性和完整性。 下面…

仿牛客项目Day1

SpringMVC 架构 spring的前端控制器是DispatcherServlet 模板引擎Thymeleaf 这个还不知道干嘛的 mvc演示 get请求 RequestMapping&#xff1a;声明访问路径和http方法get或set什么的 ResponseBody&#xff1a;java对象转为json格式的数据&#xff0c;表示该方法的返回结…

原创+顶级SCI优化!23年新算法PSA优化CNN-LSTM-Attention一键实现多变量回归预测!

声明&#xff1a;文章是从本人公众号中复制而来&#xff0c;因此&#xff0c;想最新最快了解各类智能优化算法及其改进的朋友&#xff0c;可关注我的公众号&#xff1a;强盛机器学习&#xff0c;不定期会有很多免费代码分享~ 目录 效果展示 数据介绍 创新点 模型流程 部…

7款前端实战型项目特效分享(附在线预览)

分享7款实用性的前端动画特效 其中有canvas特效、css动画、svg动画等等 下方效果图可能不是特别的生动 那么你可以点击在线预览进行查看相应的动画特效 同时也是可以下载该资源的 CSS春节灯笼特效 基于CSS实现的灯笼特效 灯笼会朝左右两个方向来回的摆动着 以下效果图只能体现…

调用Mybatis plus中的saveBatch方法报找不到表的问题

1.问题现象 在用Mybatis plus开发的项目中&#xff0c;用自带的API批量保存的方法saveBatch操作时&#xff0c;发现报没有找到表的错误。 错误日志截图如下&#xff1a; 表实际是存在的&#xff0c;且发现其他的方法都没有问题&#xff0c;包括save、update等单个的方法&…

Linux基础命令[13]-nl

文章目录 1. nl 命令说明2. nl 命令语法3. nl 命令示例3.1 不加参数3.2 -b&#xff08;依据样式显示行号&#xff09;3.3 -n&#xff08;格式化行号&#xff09;3.4 -w&#xff08;占位数长度&#xff09;3.5 -i&#xff08;依据数值增长行号&#xff09;3.6 -v&#xff08;定义…

【数据库】数据库学习使用总结

一、数据库介绍 二、数据库系统 1、DB——>存储数据的 2、DBMS——>用来管理数据的 DBMS&#xff1a; 1、DCL 用&#xff1b;用来创建和维护用户账户 2、DDL 数据定义语言 3、DML 用来操作数据 三、DDL 1、操作数据库&#xff08;创建和删除&#xff09; create d…

基于sprinbgoot的火锅店管理系统(程序+数据库+文档)

** &#x1f345;点赞收藏关注 → 私信领取本源代码、数据库&#x1f345; 本人在Java毕业设计领域有多年的经验&#xff0c;陆续会更新更多优质的Java实战项目&#xff0c;希望你能有所收获&#xff0c;少走一些弯路。&#x1f345;关注我不迷路&#x1f345;** 一、研究背景…