现代卷积神经网络之稠密连接网络(DenseNet),并对CFIAR10训练

news2025/1/20 3:40:31

专栏:神经网络复现目录


本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet)。

文章目录

  • 稠密连接网络(DenseNet)
  • 从ResNet到DenseNet
  • 稠密体块(密集块)
    • 定义
    • 实现
  • 过渡层
    • 定义
    • 实现
  • 模型设计
    • 网络结构
    • 实现
  • 实战:训练CIFAR10分类
    • 数据集
    • 损失函数和优化器
    • 训练
    • 可视化


稠密连接网络(DenseNet)

DenseNet(密集连接网络)是一种深度学习架构,由黄等人在2017年提出。DenseNet是一种卷积神经网络(CNN),它使用密集连接层来改善网络中信息的流动。

传统的CNN中,每一层的输出只传递给下一层。然而,在DenseNet中,每一层都与后续所有层相连。这意味着任何一层的输入都是来自网络中所有前面层的特征图的串联。

DenseNet的好处包括:

  1. 减少梯度消失问题:因为每一层都直接访问所有后续层的梯度,所以梯度信号可以更轻松地通过网络进行传播和保留。

  2. 参数效率:与相似性能的传统CNN相比,DenseNet具有更少的参数,因为它可以重用特征图。

  3. 提高准确性:DenseNet在各种计算机视觉任务上取得了最先进的性能,包括图像分类、目标检测和语义分割。

总的来说,DenseNet是图像识别任务的强有力工具,并已被证明在提高CNN的准确性的同时减少所需的参数数量方面非常有效。

从ResNet到DenseNet

ResNet将函数展开为
在这里插入图片描述
也就是说,ResNet将f分解为两部分:一个简单的线性项和一个复杂的非线性项。 那么再向前拓展一步,如果我们想将f拓展成超过两部分的信息呢? 一种方案便是DenseNet。

ResNet(左)与 DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结。
在这里插入图片描述

ResNet和DenseNet的关键区别在于,DenseNet输出是连接(用图中的[ , ]表示)而不是如ResNet的简单相加。 因此,在应用越来越复杂的函数序列后,我们执行x从到其展开式的映射:
在这里插入图片描述
最后,将这些展开式结合到多层感知机中,再次减少特征的数量。 实现起来非常简单:我们不需要添加术语,而是将它们连接起来。 DenseNet这个名字由变量之间的“稠密连接”而得来,最后一层与之前的所有层紧密相连。稠密连接如图所示。
在这里插入图片描述
在这里插入图片描述

稠密网络主要由2部分构成:稠密块(dense block)和过渡层(transition layer)。 前者定义如何连接输入和输出,而后者则控制通道数量,使其不会太复杂。

稠密体块(密集块)

定义

密集块(Dense Block)是DenseNet网络结构的核心部分,它是由多个密集连接层(Dense Layer)组成的模块,用于提取图像中的特征信息。在每个密集块中,所有前面层的输出都会与当前层的输入进行连接,并通过一个非线性变换进行处理。具体地,假设当前是第 l l l个密集块,第 i i i层的输出为 x i x_i xi,则第 i i i层的计算公式为:

x i = H i ( [ x 0 , x 1 , ⋯   , x i − 1 ] ) x_i = H_i([x_0, x_1, \cdots, x_{i-1}]) xi=Hi([x0,x1,,xi1])

其中, H i ( ⋅ ) H_i(\cdot) Hi()表示第 i i i层的非线性变换, [ x 0 , x 1 , ⋯   , x i − 1 ] [x_0, x_1, \cdots, x_{i-1}] [x0,x1,,xi1]表示前 i i i层的输出的连接。

密集块的优点是可以促进信息的流动和梯度的传递,从而提高网络的性能和稳定性。另外,密集块还可以增加网络的深度和宽度,使得网络能够提取更多的特征信息。

在DenseNet中,每个密集块包含多个密集连接层,其中每个密集连接层都包含一个 3 × 3 3\times3 3×3的卷积层、一个BN层和ReLU激活函数。这些层共享相同的输入和输出,因此它们的输入和输出的通道数相同。而且,DenseNet中的每个密集块都会接上一个过渡层(Transition Block)用于控制网络的大小。

实现

class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(DenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(4 * growth_rate)
        self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(growth_rate)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)
        out = torch.cat([x, out], 1)
        return out

过渡层

定义

对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition层可以起到压缩模型的作用。假定Transition的上接DenseBlock得到的特征图channels数为m mm ,Transition层可以产生 θ m θmθm个特征(通过卷积层),其中 θ ∈ ( 0 , 1 ] θ∈(0,1]θ∈(0,1] 是压缩系数(compression rate)。当 θ = 1 θ=1θ=1 时,特征个数经过Transition层没有变化,即无压缩,而当压缩系数小于1时,这种结构称为DenseNet-C,文中使用 θ = 0.5 θ = 0.5θ=0.5。对于使用bottleneck层的DenseBlock结构和压缩系数小于1的Transition组合结构称为DenseNet-BC。

实现

class Transition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Transition, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = F.relu(out)
        out = F.avg_pool2d(out, 2)
        return out

模型设计

网络结构

DenseNet是一种基于密集连接的卷积神经网络(CNN),其主要特点是在网络中引入了密集连接层,从而改善了信息的流动和梯度的传递。下面是DenseNet的网络结构:

1.输入层:输入层接收输入数据,并将其送入第一个卷积层中。

2.卷积层:DenseNet中的卷积层通常采用 3 × 3 3\times3 3×3的卷积核,并采用padding来保持特征图的大小不变。在每个卷积层后面,都会接上BN层和ReLU激活函数。

3.密集块(Dense Block):密集块是DenseNet的核心,它由多个密集连接层组成。在每个密集块中,所有前面层的输出都会与当前层的输入进行连接,并通过一个非线性变换进行处理。

4.过渡层(Transition Block):为了避免网络过深导致梯度消失和计算资源过度消耗,DenseNet中采用了过渡层来控制网络的大小。在每个密集块之间,都会接上一个过渡层,它包含一个 1 × 1 1\times1 1×1的卷积层、BN层和平均池化层,其中平均池化的步幅为2,用于减少特征图的大小。

5.全局池化层和全连接层:最后,DenseNet使用全局平均池化层将特征图降维为一个向量,然后通过一个全连接层进行分类。

综上所述,DenseNet的网络结构具有密集连接和过渡层两个核心特点,它能够有效地利用前面层的信息,提高网络的性能和稳定性。同时,DenseNet也非常适合处理较小的数据集,如CIFAR-10和CIFAR-100。

在这里插入图片描述

实现

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000):
        super(DenseNet, self).__init__()

        # 初始卷积层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # DenseBlock 和 Transition 的组合
        self.dense1 = self._make_dense_block(growth_rate, block_config[0])
        self.trans1 = self._make_transition(256, growth_rate * block_config[0] // 2)

        self.dense2 = self._make_dense_block(growth_rate, block_config[1])
        self.trans2 = self._make_transition(512, growth_rate * block_config[1] // 2)

        self.dense3 = self._make_dense_block(growth_rate, block_config[2])
        self.trans3 = self._make_transition(1024, growth_rate * block_config[2] // 2)

        self.dense4 = self._make_dense_block(growth_rate, block_config[3])

        # 全局平均池化层和分类层
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(growth_rate * block_config[3], num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.pool1(out)

        out = self.dense1(out)
        out = self.trans1(out)

        out = self.dense2(out)
        out = self.trans2(out)

        out = self.dense3(out)
        out = self.trans3(out)

        out = self.dense4(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
    	out = self.fc(out)

    return out

def _make_dense_block(self, growth_rate, num_layers):
    layers = []
    for i in range(num_layers):
        layers.append(DenseBlock(growth_rate * i + 64, growth_rate))
    return nn.Sequential(*layers)

def _make_transition(self, in_channels, out_channels):
    return Transition(in_channels, out_channels)

实战:训练CIFAR10分类

数据集

# 导入数据集
from torchvision import datasets
import torch
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'trunk')
cifar_train = datasets.CIFAR10(root="/data",train=True, download=True, transform=transform)
cifar_test = datasets.CIFAR10(root="/data",train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar_test, batch_size=16, shuffle=False)

损失函数和优化器

# 定义损失函数和优化器
net=DenseNet(num_classes=10);
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
epoch = 5
net = net.to(device)
total_step = len(train_loader)
train_all_loss = []
val_all_loss = []

训练

import numpy as np

for i in range(epoch):
    net.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0

    for iter, (images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = net(images)
        loss = criterion(outputs,labels)
        train_total_correct += (outputs.argmax(1) == labels).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))
    net.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = net(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100

    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num,4))
    val_all_loss.append(np.round(test_total_loss / test_total_num,4))

可视化

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.title("Train Loss and Test Loss Curve")
plt.xlabel('plot_epoch')
plt.ylabel('loss')
plt.plot(train_all_loss)
plt.plot(val_all_loss)
plt.legend(['train loss', 'test loss'])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/397870.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pikachu靶场CSRF之TOKEN绕过

简介 Pikachu靶场中的CSRF漏洞环节里面有一关CSRF TOKEN,这个关卡和其余关卡稍微有点不一样,因为表单里面存在一个刷新就会变化的token,那么这个token是否能绕过呢?接下来我们来仔细分析分析 实战过程 简单尝试 先利用任意一个…

CNCF x Alibaba云原生技术公开课 第三章 kubernetes核心概念

1、Kubernetes概念 核心功能 服务的发现与负载的均衡容器的自动装箱,我们也会把它叫做 scheduling,就是“调度”,把一个容器放到一个集群的某一个机器上Kubernetes 会帮助我们去做存储的编排,让存储的声明周期与容器的生命周期能…

SpringCloud-高级篇(一)

目录: (1)初识Sentinel-雪崩问题的解决方案 (2)服务保护Sentinel和Hystrix对比 (3)Sentinel初始-安转控制台 (4)整合微服务和Sentinel 微服务高级篇 (1&…

unity开发知识点小结04

混合动画 在动画器控制器中创建从新混合树,也就是创建混合动画 然后进入混合动画,选择混合类型为1D(表示传递参数只有一个),并且为此混合状态添加两个动画,并且设定混合状态参数为何值得时候启用相应动画…

Python中函数的分类、创建和调用,你真的懂了吗

文章目录前言一、函数分类二、创建函数三、调用函数前言 在前面的博客中,所有编写的代码都是从上到下依次执行的,如果某段代码需要多次使用,那么需要将该段代码复制多次,这种做法势必会影响开发效率,在实际项目开发中是…

特权级那些事儿-实模式下分段机制首次出现的原因

前言: 操作系统的特权级模块在整个操作系统的学习中应该算的上是最难啃的了,提到特权级就要绕不开保护模式下的分段机制;如果想要彻底弄明白就要对比实模式下的分段机制有什么缺陷。这就衍生出很多问题如:什么是实模式&#xff1f…

Nacos 注册中心核心能力以及现实原理解析

Nacos注册中心主要分两方面解析:动态服务发现和Nacos实现动态服务发现的原理; 动态服务发现 服务发现是指使用一个注册中心来记录分布式系统中的全部服务的信息,以便其他服务能够快速的找到这些已注册的服务。 在单体应用中,DNS…

MINE: Towards Continuous Depth MPI with NeRF for Novel View Synthesis

MINE: Towards Continuous Depth MPI with NeRF for Novel View Synthesis:利用NeRF实现新视图合成的连续深度MPI 摘要:在论文中,提出了MINE,通过从单个图像进行密集3D重建来执行新的视图合成和深度估计。通过引入神经辐射场&…

05-Oracle中的对象(视图,索引,同义词,系列)

本章主要内容: 1.视图管理:视图新增,修改,删除; 2.索引管理:索引目的,创建,修改,删除; 3.同义词管理:同义词的作用,创建&#xff0…

如何通过websoket实现即时通讯+断线重连?

本篇博客只是一个demo,具体应用还要结合项目实际情况,以下是目录结构: 1.首先通过express搭建一个本地服务器 npm install express 2.在serve.js中自定义测试数据 const express require(express); const app express(); const http req…

详细stm32驱动SDRAM的注意事项以及在keil中的使用

SDRAM的主要参数: 容量:SDRAM的容量是指其可以存储的数据量,通常以兆字节(MB)或千兆字节(GB)为单位。 时钟频率:SDRAM的时钟频率指的是其内部时钟的速度,通常以兆赫&…

94. 二叉树的中序遍历

94. 二叉树的中序遍历 给定一个二叉树的根节点 root ,返回 它的 中序 遍历 (左根右)。 首先我们需要了解什么是二叉树的中序遍历:按照访问左子树——根节点——右子树的方式遍历这棵树,而在访问左子树或者右子树的时候我们按照同样的方式遍历…

MQTT协议-订阅主题和订阅确认

MQTT协议-订阅主题和订阅确认 SUBSCRIBE——订阅主题 订阅是客户端向服务端订阅 订阅报文 订阅报文与CONNECT报文类似,都是由固定报头可变报头有效载荷组成 固定报头比较简单,也是由两个字节组成,第一个字节为82,第二个字节是…

像素密度提升33%,Quest Pro动态注视点渲染原理详解

在Connect 2022上,Meta发布了Quest Pro,并首次在VR中引入动态注视点渲染(ETFR)功能,这是一种新型图形优化技术,特点是以用户注视点为中心,动态调节VR屏幕的清晰度(注视点中心最清晰、…

Oracle VM VirtualBox6.1.36导入ova虚拟机文件报错,代码: E_INVALIDARG (0x80070057)

问题 运维人员去客户现场部署应用服务,客户是windows server 服务器(客户不想买新机器),我们程序是在linux系统里运行(其实windows也可以,主要是为了保持各地环境一致方便更新和排查问题)我们使…

吐血整理学习方法,2年多功能测试成功进阶自动化测试,月薪23k+......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 测试进阶方向 测试进…

[Gin]框架底层实现理解(三)

1.engine.Run(port string) 这个就是gin框架的启动语句,看看就好了,下面我们解析一下那个engine.Handler() listenandserve 用于启动http包进行监听,获取链接conn // ListenAndServe listens on the TCP network address addr and then ca…

【SOP 】配电网故障重构方法研究【IEEE33节点】(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Java中wait和sleep区别

文章目录1. Java中wait和sleep区别2. wait和sleep所属方法的不同3. wait的Demo3.1 没有synchronized同步代码块异常3.2 wait使用Demo4. sleep的Demo1. Java中wait和sleep区别 sleep属于Thread类中的static方法;wait属于Object类的方法sleep时线程状态进入TIMED_WAI…

java 如何实现在线日志

如何采集springboot日志至web页面查看 实现方案 基于Filter方式,在日志输出至控制台前,LoggerFitler 拦截日志通过websocket推送至前台页面 实现逻辑: LoggerFilter采集日志添加至LoggerQueue队列, LoggerConsumer 从LoggerQueue中采集推送至前台页面 #mermaid-s…