【深度学习笔记】5_12稠密连接网络(DenseNet)

news2024/12/23 16:22:02

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.12 稠密连接网络(DenseNet)

ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个:稠密连接网络(DenseNet) [1]。 它与ResNet的主要区别如图5.10所示。

在这里插入图片描述

图5.10 ResNet(左)与DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结

图5.10中将部分前后相邻的运算抽象为模块 A A A和模块 B B B。与ResNet的主要区别在于,DenseNet里模块 B B B的输出不是像ResNet那样和模块 A A A的输出相加,而是在通道维上连结。这样模块 A A A的输出可以直接传入模块 B B B后面的层。在这个设计里,模块 A A A直接跟模块 B B B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。

DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。

5.12.1 稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block函数里实现这个结构。

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels), 
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk

稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X

在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为 3 + 2 × 10 = 23 3+2\times 10=23 3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

5.12.2 过渡层

由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过 1 × 1 1\times1 1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。

def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
            nn.BatchNorm2d(in_channels), 
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2))
    return blk

对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

5.12.3 DenseNet模型

我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽,并减半通道数。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # 上一个稠密块的输出通道数
    num_channels = DB.out_channels
    # 在稠密块之间加入通道数减半的过渡层
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2

同ResNet一样,最后接上全局池化层和全连接层来输出。

net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

我们尝试打印每个子模块的输出维度确保网络无误:

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 48, 48])
1  output shape:	 torch.Size([1, 64, 48, 48])
2  output shape:	 torch.Size([1, 64, 48, 48])
3  output shape:	 torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:	 torch.Size([1, 224, 12, 12])
transition_block_1  output shape:	 torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:	 torch.Size([1, 240, 6, 6])
transition_block_2  output shape:	 torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:	 torch.Size([1, 248, 3, 3])
BN  output shape:	 torch.Size([1, 248, 3, 3])
relu  output shape:	 torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:	 torch.Size([1, 248, 1, 1])
fc  output shape:	 torch.Size([1, 10])

5.12.4 获取数据并训练模型

由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec
epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec
epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec
epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec
epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec

小结

  • 在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。
  • DenseNet的主要构建模块是稠密块和过渡层。

参考文献

[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1512797.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【论文速读】| DeepGo:预测式定向灰盒模糊测试

本次分享论文为:DeepGo: Predictive Directed Greybox Fuzzing 基本信息 原文作者:Peihong Lin, Pengfei Wang, Xu Zhou, Wei Xie, Gen Zhang, Kai Lu 作者单位:国防科技大学计算机学院 关键词:Directed Greybox Fuzzing, Path…

Postman请求API接口测试步骤和说明

Postman请求API接口测试步骤 本文测试的接口是国内数智客(www.shuzike.com)的API接口手机三要素验证,验证个人的姓名,身份证号码,手机号码是否一致。 1、设置接口的Headers参数。 Content-Type:applicati…

2024蓝桥杯每日一题(区间合并)

一、第一题:挤牛奶 解题思路:区间合并 区间合并模板题 【Python程序代码】 n int(input()) a [] for i in range(n):l,r map(int,input().split())a.append([l,r]) def cmp(x):return x[0],x[1] a.sort(keycmp) res1,res20,0 st,ed a[0][0…

JS-12-关键字this、apply()、call()

一、对象的方法 在一个对象中绑定函数,称为这个对象的方法。 示例: 1、对象: var xiaoming {name: 小明,birth: 1990 }; 2、给xiaoming绑定一个函数。比如,写个age()方法,返回xiaoming的年龄: var x…

亲测抖音小程序备案流程,抖音小程序如何备案,抖音小程序备案所需准备资料

抖音小程序为什么要备案,抖音官方给出如下说明: 1、2024年3月15日后提交备案的小程序将不保证2024年3月31日前平台可初审通过; 2、2024年3月31日后未完成备案小程序将被下架处理。 一,备案前需准备资料 (一&#xff0…

Python 导入Excel三维坐标数据 生成三维曲面地形图(面) 1、线条折线曲面

环境和包: 环境 python:python-3.12.0-amd64包: matplotlib 3.8.2 pandas 2.1.4 openpyxl 3.1.2 代码: import pandas as pd import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.mplot3d import Axes3D from matplotlib.colors import ListedColor…

k8s+wordpress+zabbix+elastic+filebeat+kibana服务搭建以及测试

一,环境:docker,k8s,zabbix,以及搭建worpdress,elasticsearch,filebeat,kibana 二,主机分配: 名称host详述个人博客3192.168.142.133 搭配mysql8.0.36的数据…

Stable Diffusion 模型:从噪声中生成逼真图像

你好,我是郭震 简介 Stable Diffusion 模型是一种生成式模型,可以从噪声中生成逼真的图像。它由 Google AI 研究人员于 2022 年提出,并迅速成为图像生成领域的热门模型。 数学基础 Stable Diffusion模型基于一种称为扩散概率模型(Diffusion P…

【QT】文件流操作(QTextStream/QDataStream)

文本流/数据流&#xff08;二级制格式&#xff09; 文本流 &#xff08;依赖平台&#xff0c;不同平台可能乱码&#xff09;涉及文件编码 #include <QTextStream>操作的都是基础数据类型&#xff1a;int float string //Image Qpoint QRect就不可以操作 需要下面的 …

ES分片均衡策略分析与改进

从故障说起 某日早高峰收到 Elasticsearch 大量查询超时告警&#xff0c;不同于以往&#xff0c;查看 Elasticsearch 查询队列监控后发现&#xff0c;仅123节点存在大量查询请求堆积。 各节点查询队列堆积情况 查看节点监控发现&#xff0c;123节点的 IO 占用远高于其他节点。…

喜报!聚铭网络实力入选2024年度扬州市网络安全技术支撑服务机构

近日&#xff0c;中共扬州市委网络安全和信息化委员会办公室正式公布了“2024年度扬州市网络安全技术支撑服务机构”名单&#xff0c;聚铭网络凭借其卓越的技术实力与优质的安服能力&#xff0c;在众多竞争者中脱颖而出&#xff0c;光荣上榜&#xff01; 为了健全扬州市网络安…

仿12306校招项目业务五(敏感信息模块)

加密存储 数据加密背景 数据加密是指对某些敏感信息通过加密规则进行数据的变形&#xff0c;实现敏感隐私数据的可靠保护。 涉及客户安全数据或者一些商业性敏感数据&#xff0c;如身份证号、手机号、卡号、客户号等个人信息按照相关部门规定&#xff0c;都需要进行数据加密。…

程序人生——Java中基本类型使用建议

目录 引出Java中基本类型使用建议建议21&#xff1a;用偶判断&#xff0c;不用奇判断建议22&#xff1a;用整数类型处理货币建议23&#xff1a;不要让类型默默转换建议24&#xff1a;边界、边界、还是边界建议25&#xff1a;不要让四舍五入亏了一方 建议26&#xff1a;提防包装…

Nodejs 第五十五章(socket.io)

传统的 HTTP 是一种单向请求-响应协议&#xff0c;客户端发送请求后&#xff0c;服务器才会响应并返回相应的数据。在传统的 HTTP 中&#xff0c;客户端需要主动发送请求才能获取服务器上的资源&#xff0c;而且每次请求都需要重新建立连接&#xff0c;这种方式在实时通信和持续…

细粒度IP定位参文27(HGNN):Identifying user geolocation(2022年)

[27] F. Zhou, T. Wang, T. Zhong, and G. Trajcevski, “Identifying user geolocation with hierarchical graph neural networks and explainable fusion,” Inf. Fusion, vol. 81, pp. 1–13, 2022. (用层次图、神经网络和可解释的融合来识别用户的地理定位) 论文地址:…

设计模式一 ---单例设计模式(动力节点,JavaSE基础)

设计模式 1.什么是设计模式&#xff1f; 2.设计模式的分类 单例设计模式就是GoF模式中的一种。 3.GoF设计模式的分类&#xff1a; 单例设计模式&#xff1a; 顾名思义&#xff1a;单个实例的设计模式&#xff01;

C#调用Halcon出现尝试读取或写入受保护的内存,这通常指示其他内存已损坏。System.AccessViolationException

一、现象 在C#中调用Halcon&#xff0c;出现异常提示&#xff1a;尝试读取或写入受保护的内存,这通常指示其他内存已损坏。System.AccessViolationException 二、原因 多个线程同时访问Halcon中的某个公共变量&#xff0c;导致程序报错 三、测试 3.1 Halcon代码 其中tsp_width…

【Android】 ClassLoader 知识点提炼

1.Java中的 ClassLoader 1.1 、ClassLoader的类型 Java 中的类加载器主要有两种类型&#xff0c;即系统类加载器和自定义类加载器。其中系统类 加载器包括3种&#xff0c;分别是 Bootstrap ClassLoader、Extensions ClassLoader 和 Application ClassLoader。 1.1.1.Bootstra…

开放原子开源大赛—基于OpenHarmony的团结引擎应用开发赛正式启动!

“基于OpenHarmony的团结引擎应用开发赛”是开放原子全球开源大赛下开设的新兴及应用赛的赛题之一&#xff0c;本次赛题旨在鼓励更多开发者基于OpenHarmony 4.x版本&#xff0c;使用Unity中国团结引擎创造出精彩的游戏与3D应用。 大赛分为“创新游戏”与“创新3D 化应用”两大赛…

全景解析 Partisia Blockchain:以用户为中心的全新数字经济网络

在区块链世界中&#xff0c;以比特币、以太坊网络为代表的主流区块链奠定了该领域早期的基础&#xff0c;并让去中心化、点对点、公开透明以及不可逆成为了该领域固有的意识形态。事实上&#xff0c;过于透明正在成为区块链规模性采用的一大障碍&#xff0c;我们看到 90% 以上的…