动手学深度学习——稠密连接网络DenseNet(原理解释+代码详解)

news2024/10/5 22:24:31

稠密连接网络DenseNet

      • 1. 从ResNet到DenseNet
      • 2. 稠密块体
      • 3. 过渡层
      • 4. DenseNet模型
      • 5. 训练模型

CIFAR 和 SVHN 数据集上的错误率 (%)。DenseNet 比 ResNet 使用更少的参数,同时实现了更低的错误率。在没有数据增强的情况下,DenseNet 的性能大幅提高。
在这里插入图片描述

1. 从ResNet到DenseNet

稠密连接网络在某种程度上是ResNet的逻辑扩展。
回想一下任意函数的泰勒展开式,它把这个函数分解成越来越高阶的项。在x接近0时,
在这里插入图片描述
ResNet将函数展开为
在这里插入图片描述
ResNet将 f 分解为两部分:一个简单的线性项和一个复杂的非线性项。
那么再向前拓展一步,如果我们想将 f 拓展成超过两部分的信息呢? 一种方案便是DenseNet,使用连结。
在这里插入图片描述
执行从x到其展开式的映射
在这里插入图片描述
最后,将这些展开式结合到多层感知机中,再次减少特征的数量。
在这里插入图片描述
稠密网络主要由2部分构成:

  • 稠密块(dense block):定义如何连接输入和输出
  • 过渡层(transition layer):控制通道数量,使其不会太复杂。

2. 稠密块体

稠密块每一层都将所有前面的特征图作为输入
在这里插入图片描述
DenseNet使用了ResNet改良版的“批量规范化、激活和卷积”架构

# 泰勒公式
"""
稠密网络:
        1、稠密块:定义如果连接输入和输出
        2、过渡层:后者控制通道数
"""
import torch
from torch import nn
from d2l import torch as d2l


def conv_block(input_channels, num_channels):
    return nn.Sequential(
        nn.BatchNorm2d(input_channels), nn.ReLU(),
        nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1))

一个稠密块由多个卷积块组成,每个卷积块使用相同数量的输出通道。在前向传播中,将每个卷积块的输入和输出在通道维上连结。

# 稠密块
class DenseBlock(nn.Module):
    def __init__(self, num_convs, input_channels, num_channels):
        super(DenseBlock, self).__init__()
        layer = []
        for i in range(num_convs):
            # 稠密连接:执行从x到其展开式的映射,即每个卷积块的输入和输出在通道维度上连接
            layer.append(conv_block(i * num_channels + input_channels, num_channels))
        self.net = nn.Sequential(*layer)
    
    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            # 连接通道维度上每个卷积块的输入和输出
            X = torch.cat((X, Y), dim=1)
        return X

定义一个有2个输出通道数为10的DenseBlock。 使用通道数为3的输入时,我们会得到通道数为3+2x10=23的输出。

卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

# 定义有2个输出通道为10的DenseBlock
# 使用通道数为3的输入,会得到3+2x10=23
# 卷积块的通道数控制了输出通道数相对于输入通道数的增长程度,因此被称为增长率
blk = DenseBlock(2, 3, 10)
X = torch.randn(4, 3, 8, 8)
Y = blk(X)
Y.shape

在这里插入图片描述

3. 过渡层

每个稠密块都会带来通道数的增加,使用过多则会过于复杂化模型。

过渡层可以用来控制模型复杂度,它通过1x1卷积层来减小通道数,并使用步幅为2的平均汇聚层减半高和宽,从而进一步降低模型复杂度。

# 过渡层:用来控制模型复杂度
# 1x1卷积层减少通道数,使用步幅为2的平均汇聚层减半高度和宽度
def transition_block(input_channels, num_channels):
    return nn.Sequential(
        nn.BatchNorm2d(input_channels), nn.ReLU(),
        nn.Conv2d(input_channels, num_channels, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2))

对于上一个例子,将通道数从23变为10

# 对于上一个例子,将通道数从23变为10
blk = transition_block(23, 10)
blk(Y).shape

在这里插入图片描述

4. DenseNet模型

DenseNet首先使用同ResNet一样的单卷积层和最大汇聚层

# 构建DenseNet
# 首先,使用和ResNet一样的卷积层和最大汇聚层
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.AvgPool2d(kernel_size=3, stride=2, padding=1))

DenseNet使用4个稠密块,稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

在每个模块之间,ResNet通过步幅为2的残差块减小高和宽,DenseNet则使用过渡层来减半高和宽,并减半通道数。

"""
DenseNet使用4个稠密块:
1、稠密块里的卷积层通道数(即增长率)设为32,每个稠密块将增加4x32=128个通道
2、每个模块之间,DenseNet使用过渡层减半高度和宽度,并减半通道数
"""
# num_channels为当前的通道数
# growth_rate为增长率,即输出通道相对于输入通道的增长程度
num_channels, growth_rate = 64, 32
# 4个稠密块,每个稠密块4个卷积层,每个稠密块增加4x32=128个通道
num_convs_in_dense_blocks = [4, 4, 4, 4]
blks = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    # 上一个稠密块的输出通道数
    num_channels += num_convs * growth_rate
    # 在稠密块之间添加一个过渡层,使通道数减半
    if i != len(num_convs_in_dense_blocks) - 1:
        blk.append(transition_block(num_channels, num_channels // 2))
        num_channels // 2

最后接上全局汇聚层和全连接层来输出结果。

# 最后连接全局汇聚层和全连接层来输出结果
net = nn.Sequential(
    b1, *blks,
    nn.BatchNorm2d(num_channels), nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(num_channels, 10))

5. 训练模型

定义精度评估函数

"""
    定义精度评估函数:
    1、将数据集复制到显存中
    2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    # 判断net是否属于torch.nn.Module类
    if isinstance(net, nn.Module):
        net.eval()
        
        # 如果不在参数选定的设备,将其传输到设备中
        if not device:
            device = next(iter(net.parameters())).device
    
    # Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            # 将X, y复制到设备中
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            
            # 计算正确预测的数量,总预测的数量,并存储到metric中
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

定义GPU训练函数

"""
    定义GPU训练函数:
    1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;
    2、使用Xavier随机初始化模型参数;
    3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    # 定义初始化参数,对线性层和卷积层生效
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    
    # 在设备device上进行训练
    print('training on', device)
    net.to(device)
    
    # 优化器:随机梯度下降
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    
    # 损失函数:交叉熵损失函数
    loss = nn.CrossEntropyLoss()
    
    # Animator为绘图函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    # 调用Timer函数统计时间
    timer, num_batches = d2l.Timer(), len(train_iter)
    
    for epoch in range(num_epochs):
        
        # Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量
        metric = d2l.Accumulator(3)
        net.train()
        
        # enumerate() 函数用于将一个可遍历的数据对象
        for i, (X, y) in enumerate(train_iter):
            timer.start() # 进行计时
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将特征和标签转移到device
            y_hat = net(X)
            l = loss(y_hat, y) # 交叉熵损失
            l.backward() # 进行梯度传递返回
            optimizer.step()
            with torch.no_grad():
                # 统计损失、预测正确数和样本数
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 计时结束
            train_l = metric[0] / metric[2] # 计算损失
            train_acc = metric[1] / metric[2] # 计算精度
            
            # 进行绘图
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
                
        # 测试精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) 
        animator.add(epoch + 1, (None, None, test_acc))
        
    # 输出损失值、训练精度、测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
          f'test acc {test_acc:.3f}')
    
    # 设备的计算能力
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'
          f'on {str(device)}')

在这里插入图片描述

由于模型较深,这里将输入高度和宽度从224降为96

# 训练模型
# 由于模型较深,这里将输入高度和宽度从224降为96
lr, num_epochs, batch_size = 0.1, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1168253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL数据库之表的增删查改

目录 表的操作1.创建表创建表案例 2.查看表结构3.修改表4.删除表 表的操作 1.创建表 语法: CREATE TABLE table_name (field1 datatype,field2 datatype,field3 datatype ) character set 字符集 collate 校验规则 engine 存储引擎;说明: field 表示列…

pinia简单使用

新命令-创建vue3项目 vue create 方式使用脚手架创建项目,vue cli处理, vue3后新的脚手架工具create-vue 使用npm init vuelatest 命令创建即可。 在pinia中,将使用的组合式函数识别为状态管理内容 自动将ref 识别为stste,computed 相当于 ge…

Anaconda安装与配置

1.打开Anaconda官网,选择对应版本,下载到对应目录即可 或者进入: Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 2.双击打开.exe文件,然后点击next ; 3.点击agree 4.点击just me,然后next; 5.在Choose Install L…

百科创建系列天花板!一文看懂百度百科如何创建,百度百科怎么编辑才能通过!(百科创建必看)

在互联网时代,拥有一个权威的个人或企业信息展示平台显得尤为重要。百度百科作为全球最大的中文百科全书,已经成为了许多人和企业展示自己的重要途径。那么,如何创建一个百度百科词条呢? 分媒互动将为大家详细介绍百度百科创建的…

Locust:可能是一款最被低估的压测工具

01、Locust介绍 开源性能测试工具https://www.locust.io/,基于Python的性能压测工具,使用Python代码来定义用户行为,模拟百万计的并发用户访问。每个测试用户的行为由您定义,并且通过Web UI实时监控聚集过程。 压力发生器作为性…

第三章 python数据类型

系列文章目录 第一章 Python 基础知识 第二章 python 字符串处理 第三章 python 数据类型 第四章 python 运算符与流程控制 第五章 python 文件操作 第六章 python 函数 第七章 python 常用内建函数 第八章 python 类(面向对象编程) 第九章 python 异常处理 第十章 python 自定…

Git https方式拉的代码IDEA推送代码报错

报错信息 fatal: could not read Username for ‘https://codehub-cn-south-1.devcloud.huaweicloud.com’: No such file or directory 18:18:39.885: [recovery_pattern] git -c credential.helper -c core.quotepathfalse -c log.showSignaturefalse push --progress --porc…

python第一课 变量

1.离线的情况下首选txt文档 2.有道云笔记 3.思维导图 xmind mindmaster 4.博客 5.wps流程图 # 变量的命名规则 1.变量名只能由数字字母下划线组成 2.变量名不能以数字开头 3.变量名不能与关键字重名 快捷键 撤销:Ctrl/Command Z 新建:Ctrl/Com…

【CesiumJS】(1)Hello world

介绍 Cesium 起源于2011年,初衷是航空软件公司(Analytical Graphics, Inc.)的一个团队要制作世界上最准确、性能最高且具有时间动态性的虚拟地球。取名"Cesium"是因为元素铯Cesium让原子钟非常准确(1967年,人们依据铯原子的振动而对…

气膜场馆的降噪方法

在现代社会,噪音已经成为我们生活中难以避免的问题,而气膜场馆也不例外。传统的气膜场馆常常因其特殊结构而面临噪音扩散和回声问题,影响了人们的体验和活动效果。然而,随着科技的进步,多功能声学综合馆应运而生&#…

413 (Payload Too Large) 2023最新版解决方法

文章目录 出现问题解决方法 出现问题 博主在用vue脚手架开发的时候,在上传文件的接口中碰到 这样一个错误,查遍所有csdn,都没有找到解决方法,通过一些方式,终于解决了。 解决方法 1.打开Vue项目的根目录。 2.在根目…

what?腾讯云3年轻量2核4G5M服务器566.6元哪去了?

what?腾讯云3年轻量2核4G5M服务器566.6元哪去了?腾讯云双11优惠活动3年轻量2核4G5M服务器从566.6元涨价到756元三年,3年轻量2核2G4M服务器从366.6元恢复到540元三年,大家抓紧吧,三年轻量已经库存已经不多了&#xff0c…

EthernetIP主站转EtherCAT协议网关采集电力变压器的 Ethernet IP 数据

怎么通过捷米JM-EIPM-ECT网关把ABB电力变压器的 Ethernet IP 数据,连接到欧姆龙PLC上,通过plc去监控电力设备的数据呢,下面是介绍简单的连接方法,采集Ethernet IP从站数据和EtherCAT协议 1 ,捷米JM-EIPM-ECT网关连接Et…

基于51单片机电子秤-proteus仿真-源程序

一、系统方案 本设计采用52单片机作为主控器,液晶1602显示,HX711模块,按键设置单价,计算总价,超量程报警,蜂鸣器报警。 二、硬件设计 原理图如下: 三、单片机软件设计 1、首先是系统初始化 I…

基于openresty waf二次开发多次匹配到的ip再做拉黑

我们想在openresty waf的基础上做二次开发,比如再精确一些。比如我们先匹配到了select的url我们先打分10分,匹配到cc 1000/s我们再给这个ip打10分…直到100分我们就拉黑这个ip。 [openresty waf][1] #cat reids_w.lua require lib local redis require…

Zookeeper安装及配置

Zookeeper官网:Apache ZooKeeper 一般作为服务注册中心 无论在Windows下还是Linux下,Zookeeper的安装步骤是一样的,用的包也是同一个包 Window下安装及配置Zookeeper 下载后解压 linux安装 window及Linux安装及配置zookeeper_访问windos上的zookeeper-CSDN博客

用Python写了13个小游戏,上班摸鱼我能玩一天

分享13个Python小游戏,本内容来源于网络。 用Python写个魂斗罗,另附30个Python小游戏源码​segmentfault.com/a/1190000041782623 1、吃金币 源码分享: import os import cfg import sys import pygame import random from modules import …

【框架篇】统一用户登录权限验证

✅作者简介:大家好,我是小杨 📃个人主页:「小杨」的csdn博客 🐳希望大家多多支持🥰一起进步呀! 统一用户登录权限验证 1,自定义拦截器 对于统一用户登录权限验证的问题&#xff0c…

照片拼图软件 CollageIt Pro mac中文版功能特色

CollageIt Pro mac是一款拼图软件,CollageIt Pro for mac不仅支持多种模式的拼贴风格,还能够完美满足您对自己图片的美化需要,以一种全新的方式来呈现您收藏的图片,并且只需短短的几秒,便可以轻松实现将一组照片编程一…

黑色木工板覆膜板:耐用防水的建筑模板选择

黑色木工板覆膜板是一种耐用的建筑模板材料,以其优异的防水性能和稳定性成为建筑行业的理想选择。本文将重点介绍黑色木工板覆膜板的特点以及其在建筑模板领域的应用。 黑色木工板覆膜板采用杨木芯,并在表面覆盖一层黑色防水膜。杨木芯的选择使得木工板具…