基于Pytorch构建DenseNet网络对cifar-10进行分类

news2024/12/23 22:26:53

DenseNet是指Densely connected convolutional networks(密集卷积网络)。它的优点主要包括有效缓解梯度消失、特征传递更加有效、计算量更小、参数量更小、性能比ResNet更好。它的缺点主要是较大的内存占用。

DenseNet网络与Resnet、GoogleNet类似,都是为了解决深层网络梯度消失问题的网络。

Resnet从深度方向出发,通过建立前面层与后面层之间的“短路连接”或“捷径”,从而能训练出更深的CNN网络。

GoogleNet从宽度方向出发,通过Inception(利用不同大小的卷积核实现不同尺度的感知,最后进行融合来得到图像更好的表征)。

DenseNet从特征入手,通过对前面所有层与后面层的密集连接,来极致利用训练过程中的所有特征,进而达到更好的效果和减少参数。

DenseNet网络

Dense Block:像GoogLeNet网络由Inception模块组成、ResNet网络由残差块(Residual Building Block)组成一样,DenseNet网络由Dense Block组成,论文截图如下所示:每个层从前面的所有层获得额外的输入,并将自己的特征映射传递到后续的所有层,使用级联(Concatenation)方式,每一层都在接受来自前几层的”集体知识(collective knowledge)”。增长率(growth rate)k是每个层的额外通道数。

58c8038c8e5f0cf7dea34eb09bd15c88.png

其实说了那么多我也不大明白原理和数学推理,只需要按照相关代码做就行了

class Bottleneck(nn.Module):
    def __init__(self, input_channel, growth_rate):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(input_channel)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channel, 4 * growth_rate, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1)
    def forward(self, x):
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        out = torch.cat([out, x], 1)
        return out
class Transition(nn.Module):
    def __init__(self, input_channels, out_channels):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(input_channels, out_channels, kernel_size=1)
    def forward(self, x):
        out = self.conv(self.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out
class DenseNet(nn.Module):
    def __init__(self, nblocks, growth_rate, reduction, num_classes):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate
        num_planes = 2 * growth_rate
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 2 * growth_rate, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(2 * growth_rate),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.dense1 = self._make_dense_layers(num_planes, nblocks[0])
        num_planes += nblocks[0] * growth_rate
        out_planes = int(math.floor(num_planes * reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes
        self.dense2 = self._make_dense_layers(num_planes, nblocks[1])
        num_planes += nblocks[1] * growth_rate
        out_planes = int(math.floor(num_planes * reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes
        self.dense3 = self._make_dense_layers(num_planes, nblocks[2])
        num_planes += nblocks[2] * growth_rate
        out_planes = int(math.floor(num_planes * reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes
        self.dense4 = self._make_dense_layers(num_planes, nblocks[3])
        num_planes += nblocks[3] * growth_rate
        self.AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d(1)
        # 全连接层
        self.fc = nn.Sequential(
            nn.Linear(num_planes, 256),
            nn.ReLU(inplace=True),
            # 使一半的神经元不起作用,防止参数量过大导致过拟合
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )
    def _make_dense_layers(self, in_planes, nblock):
        layers = []
        for i in range(nblock):
            layers.append(Bottleneck(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)
    def forward(self, x):
        out = self.basic_conv(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        out = self.AdaptiveAvgPool2d(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
def DenseNet121():
    return DenseNet([6, 12, 24, 16], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet169():
    return DenseNet([6, 12, 32, 32], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet201():
    return DenseNet([6, 12, 48, 32], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet265():
    return DenseNet([6, 12, 64, 48], growth_rate=32, reduction=0.5, num_classes=10)
# 初始化模型
from torchstat import stat
# 定义模型输出模式,GPU和CPU均可
model = DenseNet121().to(DEVICE)

在NVIDIA GeForce GTX 1660 SUPER显卡上训练了100轮,大致上一轮1分钟,这是DenseNet网络训练的损失率和准确率,在验证集也是保持80%的准确率。

fef4e1c3ccec7ba873ae14960d444595.png

DenseNet也是一个系列,包括DenseNet-121、DenseNet-169等等,论文中给出了4种层数的DenseNet,论文截图如下所示:所有网络的增长率k是32,表示每个Dense Block中每层输出的feature map个数。

410bfbcfc3a6141a28efec184547aa49.png

关于图像分类的模型算法,热情也没了,到此也就告一段落了,后续再讨论一些新的话题。

最后欢迎关注公众号:python与大数据分析

47d362ba65d9cc25fac0dea80aa05dc0.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/905024.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何下载英伟达NVIDIA旧版本驱动,旧版本驱动官方网址

https://www.nvidia.cn/Download/Find.aspx?langcn 也可以直接搜索英伟达官网,点击驱动程序,然后点击试用版驱动程序,里面不但有试用版的驱动,还有之前发布的所有驱动

redis乐观锁+启用事务解决超卖

乐观锁用于监视库存(watch),然后接下来就启用事务。 启用事务,将减库存、下单这两个步骤,放到一个事务当中即可解决秒杀问题、防止超卖。 但是!!!乐观锁,会带来" …

Docker 使用归纳总结

mongodb 的 terminal 可执行的命令是基于这个文件夹

【LeetCode】剑指 Offer Ⅱ 第4章:链表(9道题) -- Java Version

题库链接:https://leetcode.cn/problem-list/e8X3pBZi/ 类型题目解决方案双指针剑指 Offer II 021. 删除链表的倒数第 N 个结点双指针 哨兵 ⭐剑指 Offer II 022. 链表中环的入口节点(环形链表)双指针:二次相遇 ⭐剑指 Offer I…

5.7.webrtc线程的启动与运行

那在上一节课中呢?我向你介绍了web rtc的三大线程,包括了信令线程,工作线程以及网络线程。那同时呢,我们知道了web rtc 3大线程创建的位置以及运行的时机。 对吧,那么今天呢?我们再继续深入了解一下&#…

SSM框架的学习与应用(Spring + Spring MVC + MyBatis)-Java EE企业级应用开发学习记录(第一天)Mybatis的学习

SSM框架的学习与应用(Spring Spring MVC MyBatis)-Java EE企业级应用开发学习记录(第一天)Mybatis的学习 一、当前的主流框架介绍(这就是后期我会发出来的框架学习) Spring框架 ​ Spring是一个开源框架,是为了解决企业应用程序开发复杂…

【2023最新爬虫】用python爬取知乎任意问题下的全部回答

老规矩,先上结果: 爬取了前200多页,每页5条数据,共1000多条回答。(程序设置的自动判断结束页,我是手动break的) 共爬到13个字段,包含: 问题id,页码,答主昵称,答主性别,…

Baumer工业相机堡盟工业相机如何通过BGAPISDK设置相机的Bufferlist序列(C#)

Baumer工业相机堡盟工业相机如何通过BGAPISDK设置相机的Bufferlist序列(C#) Baumer工业相机Baumer工业相机的Bufferlist序列功能的技术背景CameraExplorer如何查看相机Bufferlist功能在BGAPI SDK里通过函数设置相机固定帧率 Baumer工业相机通过BGAPI SDK…

文件同步工具rsync

文章目录 作用特性安装命令服务端启动增加安全认证及免密登录 实时推送源服务器配置结合inotify实现实时推送 参数详解 学些过程中遇到的问题 作用 rsync是linux系统下的数据镜像备份工具。使用快速增量备份工具Remote Sync可以远程同步,支持本地复制,或…

05有监督学习——神经网络

线性模型 给定n维输入: x [ x 1 , x 1 , … , x n ] T x {[{x_1},{x_1}, \ldots ,{x_n}]^T} x[x1​,x1​,…,xn​]T 线性模型有一个n维权重和一个标量偏差: w [ w 1 , w 1 , … , w n ] T , b w {[{w_1},{w_1}, \ldots ,{w_n}]^T},b w[w1​,w1​,…,wn​]T,b 输…

Elasticsearch 处理地理信息

1、GeoHash ​ GeoHash是一种地理坐标编码系统,可以将地理位置按照一定的规则转换为字符串,以方便对地理位置信息建立空间索引。首先要明确的是,GeoHash代表的不是一个点而是一个区域。GeoHash具有两个显著的特点:一是通过改变 G…

7-6 统计字符出现次数

分数 20 全屏浏览题目 切换布局 作者 C课程组 单位 浙江大学 本题要求编写程序,统计并输出某给定字符在给定字符串中出现的次数。 输入格式: 输入第一行给出一个以回车结束的字符串(少于80个字符);第二行输入一个…

Android JNI系列详解之JNI、NDK环境搭建和编译工具安装

本文主要介绍JNI、NDK环境变量的搭建,以及CMake工具的安装和ndk-build工具的安装。 一、JNI环境 JNI属于Java中的一部分,所以只需要搭建Java环境就有了JNI的环境,安装Java的环境可以网上查找教程,很多的安装JDK的博客。我电脑是安…

检测输电线上的鸟巢,用SSD结合HSV色彩空间滤波器相结合的检测方法--论文中图还少一张,欠点意思

Detection of Bird Nests on Power Line Patrol Using Single Shot Detector Abstract 电力塔上鸟巢的存在对输电线路的安全稳定构成了威胁。近年来,利用无人机探测输电线路上的鸟巢已成为电力巡检的重要任务之一。图像处理方法从计算机视觉向功率图像识别的迁移日…

MySQL数据库第十四课--------sql优化---------层层递进

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…

Git问题:解决“ssh:connect to host github.com port 22: Connection timed out”

操作系统 Windows11 使用Git IDEA 连接方式:SSH 今天上传代码出现如下报错:ssh:connect to host github.com port 22: Connection timed out 再多尝试几次,依然是这样。 解决 最终发现两个解决方案:(二选一&#xf…

GEEMAP 中如何拉伸图像

图像拉伸是最基础的图像增强显示处理方法,主要用来改善图像显示的对比度,地物提取流程中往往首先要对图像进行拉伸处理。图像拉伸主要有三种方式:线性拉伸、直方图均衡化拉伸和直方图归一化拉伸。 GEE 中使用 .sldStyle() 的方法来进行图像的…

js 的正则表达式(二)

1.正则表达式分类: 正则表达式分为普通字符和元字符。 普通字符: 仅能够描述它们本身,这些字符称作普通字符,例如所有的字母和数字。也就是说普通字符只能够匹配字符串中与它们相同的字符。 元字符: 是一些具有特殊含…

最新ChatGPT网站程序源码+AI系统+详细图文搭建教程/支持GPT4.0/AI绘画/H5端/Prompt知识库

一、前言 SparkAi系统是基于国外很火的ChatGPT进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。 那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧&#xff01…

情报与GPT技术大幅降低鱼叉攻击成本

邮件鱼叉攻击(spear phishing attack)是一种高度定制化的网络诈骗手段,攻击者通常假装是受害人所熟知的公司或组织发送电子邮件,以骗取受害人的个人信息或企业机密。 以往邮件鱼叉攻击需要花费较多的时间去采集情报、深入了解受…