深度学习Day-19:DenseNet算法实战与解析

news2025/1/23 8:07:04

 🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客
 🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  • 根据 Pytorch 代码,编写出 TensorFlow 代码
  • 研究 DenseNet 与 ResNetV 的区别
  • 改进思路是否可以迁移到其他地方

一、前言

        在计算机视觉领域,卷积神经网络(CNN)已经成为最主流的方法,比如GoogLenet,VGG-16,Incepetion等模型。CNN史上的一个里程碑事件是ResNet模型的出现,ResNet可以训练出更深的CNN模型,从而实现更高的准确度。ResNet模型的核心是通过建立前面层与后面层之间的“短路连接”进而训练出更深的CNN网络。

        DenseNet,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。

二、设计理念

        相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。

        图1为ResNet网络的残差连接机制,作为对比,图2为DenseNet的密集连接机制。可以看到,ResNet是每个层与前面的某层(一般是2~4层)短路连接在一起,连接方式是通过元素相加。而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起(即元素叠加),并作为下一层的输入。

        对于一个 L 层的网络,DenseNet共包含 \frac{L(L+1))}{2} 个连接,相比ResNet,这是一种密集连接。而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。

 图1:ResNet网络的短路连接机制(其中+代表的是元素级相加操作)

 图2:DenseNet网络的密集连接机制(其中c代表的是channel级连接操作)

        而对于DesNet,则是采用跨通道concat的形式来连接,会连接前面所有层作为输入,输入和输出的公式是Xl​=Hl​(X0​,X1​,...,Xl−1​)。这里要注意所有的层的输入都来源于前面所有层在channel维度的concat,用一张动图体会一下:

图3 DenseNet的前向过程 

 三、网络结构

图4 DenseNet的网络结构 

        CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。为了解决这个问题,DenseNet网络采用DenseBlock+Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition层是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低。图5给出了DenseNet的网路结构,它共包含4个DenseBlock,各个DenseBlock之间通过Transition层连接在一起。

图5 采用DenseBlock+Transition的DenseNet网络 

         在DenseBlock中,各个层的特征图大小一致,可以在channel维度上连接。DenseBlock中的非线性组合函数 H ( ⋅ )的是BN + ReLU + 3x3 Conv的结构,如图所示。另外值得注意的一点是,与ResNet不同,所有DenseBlock中各个层卷积之后均输入k个特征图,即得到的特征图的channel数为k,或者说采用k个卷积核。k在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的k,就可以得到较佳的性能。

图6 DenseBlock中的非线性转换结构 

        由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是原有的结构中增加1x1 Conv,如图7所示,即 BN + ReLU + 1x1 Conv + BN + ReLU + 3x3 Conv,称为DenseNet-B结构。其中1x1 Conv得到4k个特征图,它起到的作用是降低特征数量,从而提升计算的效率。

图7 使用bottleneck层的DenseBlock结构 

        对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1Conv+2x2AvgPooling。另外,Transition层可以起到压缩模型的作用。

四、Tensorflow实现

1.代码

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow import keras

data_dir = "./data/bird_photos"
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,  # 分割数据集
    subset="training",  # 数据集类型
    seed=123,
    image_size=(224, 224),
    batch_size=32)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,  # 分割数据集
    subset="validation",  # 数据集类型
    seed=123,
    image_size=(224, 224),
    batch_size=32)


class BottleNeck(keras.Model):
    def __init__(self, growth_rate, bn_size=4, dropout=0.3):
        super().__init__()
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation("relu"),
        self.conv1 = layers.Conv2D(filters=bn_size * growth_rate, kernel_size=(1, 1),
                                   strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(filters=growth_rate, kernel_size=(3, 3),
                                   strides=1, padding='same')
        self.dropout = layers.Dropout(rate=dropout)

        self.listLayers = [
            self.bn1,
            self.relu,
            self.conv1,
            self.bn2,
            self.relu,
            self.conv2,
            self.dropout
        ]

    def call(self, x):
        tem = x
        for layer in self.listLayers.layers:
            x = layer(x)
        return layers.concatenate([tem, x], axis=-1)


class Transition(tf.keras.Model):
    def __init__(self, growth_rate):
        super().__init__()
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv1 = layers.Conv2D(filters=growth_rate, kernel_size=(1, 1),
                                   strides=1, activation='relu', padding='same')
        self.pooling = layers.AveragePooling2D(pool_size=(2, 2), strides=2, padding='same')

        self.listLayers = [
            self.bn1,
            self.relu,
            self.conv1,
            self.pooling
        ]

    def call(self, x):
        for layer in self.listLayers.layers:
            x = layer(x)
        return x


class DenseBlock(tf.keras.Model):
    def __init__(self, num_layer, growth_rate, bn_size=4, dropout=0.3, efficient=False):
        super().__init__()
        self.efficient = efficient
        self.listLayers = []
        if self.efficient:
            _x = tf.recompute_grad(BottleNeck(growth_rate, bn_size=bn_size, dropout=dropout))
        else:
            _x = BottleNeck(growth_rate, bn_size=bn_size, dropout=dropout)
        for _ in range(num_layer):
            self.listLayers.append(BottleNeck(growth_rate, bn_size=bn_size, dropout=dropout))

    def call(self, x):
        for layer in self.listLayers.layers:
            x = layer(x)
        return x


class DenseNet(tf.keras.Model):
    def __init__(self, num_init_feature, growth_rate, block_config, num_classes,
                 bn_size=4, dropout=0.3, compression_rate=0.5, efficient=False):
        super().__init__()
        self.num_channels = num_init_feature
        self.conv = layers.Conv2D(filters=num_init_feature, kernel_size=7,
                                  strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.max_pool = layers.MaxPool2D(pool_size=3, strides=2, padding='same')

        self.dense_block_layers = []
        for i in block_config[:-1]:
            self.dense_block_layers.append(DenseBlock(num_layer=i, growth_rate=growth_rate,
                                                      bn_size=bn_size, dropout=dropout, efficient=efficient))
            self.num_channels = compression_rate * (self.num_channels + growth_rate * i)
            self.dense_block_layers.append(Transition(self.num_channels))

        self.dense_block_layers.append(DenseBlock(num_layer=block_config[-1], growth_rate=growth_rate,
                                                  bn_size=bn_size, dropout=dropout, efficient=efficient))

        self.avgpool = layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax)

    def call(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.max_pool(x)

        for layer in self.dense_block_layers.layers:
            x = layer(x)

        x = self.avgpool(x)
        return self.fc(x)


model = DenseNet(num_init_feature=64,
                 growth_rate=32,
                 block_config=[6, 12, 24, 16],
                 compression_rate=0.5,
                 num_classes=4,
                 dropout=0.0,
                 efficient=True)

x = tf.random.normal((1, 224, 224, 3))
for layer in model.layers:
    x = layer(x)
    print(layer.name, 'output shape:\t', x.shape)

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
opt = tf.keras.optimizers.Adam(learning_rate=0.002, decay=0.01)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

2.输出

Found 565 files belonging to 4 classes.
Using 113 files for validation.
conv2d output shape:	 (1, 112, 112, 64)
batch_normalization output shape:	 (1, 112, 112, 64)
activation output shape:	 (1, 112, 112, 64)
max_pooling2d output shape:	 (1, 56, 56, 64)
dense_block output shape:	 (1, 56, 56, 256)
transition output shape:	 (1, 28, 28, 128)
dense_block_1 output shape:	 (1, 28, 28, 512)
transition_1 output shape:	 (1, 14, 14, 256)
dense_block_2 output shape:	 (1, 14, 14, 1024)
transition_2 output shape:	 (1, 7, 7, 512)
dense_block_3 output shape:	 (1, 7, 7, 1024)
global_average_pooling2d output shape:	 (1, 1024)
dense output shape:	 (1, 4)
Epoch 1/10
15/15 [==============================] - 83s 5s/step - loss: 2.5051 - accuracy: 0.4508 - val_loss: 37955.5703 - val_accuracy: 0.3186
Epoch 2/10
15/15 [==============================] - 75s 5s/step - loss: 1.1275 - accuracy: 0.6922 - val_loss: 3854.2537 - val_accuracy: 0.2301
Epoch 3/10
15/15 [==============================] - 78s 5s/step - loss: 0.6559 - accuracy: 0.7780 - val_loss: 794.8064 - val_accuracy: 0.3186
Epoch 4/10
15/15 [==============================] - 84s 6s/step - loss: 0.5599 - accuracy: 0.7926 - val_loss: 94.6405 - val_accuracy: 0.2655
Epoch 5/10
15/15 [==============================] - 74s 5s/step - loss: 0.7278 - accuracy: 0.7682 - val_loss: 45.8066 - val_accuracy: 0.3186
Epoch 6/10
15/15 [==============================] - 74s 5s/step - loss: 0.3567 - accuracy: 0.8904 - val_loss: 16.1634 - val_accuracy: 0.3186
Epoch 7/10
15/15 [==============================] - 74s 5s/step - loss: 0.2239 - accuracy: 0.9287 - val_loss: 10.3661 - val_accuracy: 0.3186
Epoch 8/10
15/15 [==============================] - 75s 5s/step - loss: 0.1488 - accuracy: 0.9406 - val_loss: 1.8957 - val_accuracy: 0.5044
Epoch 9/10
15/15 [==============================] - 75s 5s/step - loss: 0.1024 - accuracy: 0.9630 - val_loss: 1.1245 - val_accuracy: 0.6018
Epoch 10/10
15/15 [==============================] - 76s 5s/step - loss: 0.0563 - accuracy: 0.9895 - val_loss: 1.2219 - val_accuracy: 0.5133

        具体代码细节在之前的文章中已有涉及,故不再做具体解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1789183.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

概率论与数理统计,重要知识点——全部公式总结

二、一维随机变量及其分布 五个分布参考另外一篇文章 四、随机变量的数字特征 大数定理以及中心极限定理 六、数理统计

Python量化交易学习——Part4:基于基本面的单因子选股策略

技术分析与基本面分析是股票价格分析最基础也是最经典的两个部分。技术分析是针对交易曲线及成交量等指标进行分析,基本面分析是基于公司的基本素质进行分析。 一般来说选股要先选行业,在选个股,之后根据技术分析选择买卖节点,因此针对行业及个股的基本面分析是选股的基础。…

python数据分析案例-研究生成绩分析

一、简介 在本次研究中,我们对2018年硕士生考试成绩数据进行了深入的统计分析。这项分析旨在探索不同因素如性别、生源背景、基因型以及出生月份等对学生成绩的潜在影响。我们使用了一系列的统计方法,包括描述性统计分析、相关性分析、分组分析以及方差…

【Java数据结构】二叉树详解(二)

🔒文章目录: 1.❤️❤️前言~🥳🎉🎉🎉 2. 二叉树的模拟——正文 2.1获取树中节点的个数 2.2获取叶子节点的个数 2.3获取第K层节点的个数 2.4获取二叉树的高度 2.5 检测值为value的元素是否存在 …

WPF Treeview控件开虚拟化后定位节点

不开虚拟化&#xff0c;可以用下面的方法直接定位 <Window x:Class"WpfApplication2.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"Title"Main…

Qt OPC UA通信

介绍 OPC UA全称Open Platform Unified Architecture&#xff0c;开放平台统一架构&#xff0c;是工业自动化领域通用的数据交换协议&#xff0c;它有两套主要的通信机制&#xff1a;1.客户端-服务器通信&#xff1b;2.发布订阅。Qt对OPC UA通信标准也提供了支持&#xff0c;目…

JDBC学习笔记(三)高级篇

一、JDBC 优化及工具类封装 1.1 现有问题 1.2 JDBC 工具类封装 V1.0 resources/db.properties配置文件&#xff1a; driverClassNamecom.mysql.cj.jdbc.Driver urljdbc:mysql:///atguigu usernameroot password123456 initialSize10 maxActive20 工具类代码&#xff1a; p…

代码随想录算法训练营第二十八天|93.复原IP地址 ,78.子集 ,90.子集II

93. 复原 IP 地址 - 力扣&#xff08;LeetCode&#xff09; class Solution {ArrayList<String> results new ArrayList<>();public List<String> restoreIpAddresses(String s) {if(s.length() > 12){return new ArrayList<>();}char[] ipChars …

f4pga环境搭建教程

f4pga环境搭建教程 背景介绍 FOSS Flows For FPGA (F4PGA) project&#xff0c;是一套开源的FPGA工具链&#xff0c;号称the GCC of FPGAs&#xff0c;作用是将写的硬件描述语言&#xff08;verilog或VHDL&#xff09;转化为可以在FPGA上运行的可执行文件&#xff08;bit文件…

Python实现PPT表格的编写包含新建修改插图(收藏备用)

自动创建一个ppt文件并创建好表格 代码要用到pptx库 pip install python-pptx 创建含有表格的ppt文件代码&#xff1a; from pptx import Presentation from pptx.util import Inches# 创建一个PPT对象 ppt Presentation()# 添加一个幻灯片 slide ppt.slides.add_slide(p…

原美团项目管理专业通道执行主席边国华受邀为第十三届中国PMO大会演讲嘉宾

全国PMO专业人士年度盛会 峰项标&#xff08;北京&#xff09;管理咨询有限公司常务副总裁、原美团项目管理专业通道执行主席边国华先生受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾&#xff0c;演讲议题为“从组织级项目管理能力的评价角度看企业实践”。大会将于6月2…

Python读取字节数组

读取和处理bytearray中的值 # 输出&#xff1a;Combined 16-bit value: 1234 python-can发送和接收CAN报文 import can # 创建一个CAN总线对象&#xff08;这取决于你的硬件和驱动程序&#xff09; bus can.interface.Bus(channelcan0, bustypesocketcan) # 定义一个CAN…

django 内置 JSON 字段 使用场景

Django 内置的 JSON 字段&#xff08;JSONField&#xff09;是在 Django 3.1 版本中引入的&#xff0c;用于处理 JSON 格式的数据。JSONField 允许在数据库表中存储和查询 JSON 数据&#xff0c;并且在与 Python 代码交互时自动转换为合适的 Python 数据类型。以下是一些常见的…

成都欣丰洪泰文化传媒有限公司好不好?

在数字经济的浪潮中&#xff0c;电商行业以其独特的魅力和无限的发展潜力&#xff0c;吸引了越来越多的企业和个人投身其中。作为电商服务领域的佼佼者&#xff0c;成都欣丰洪泰文化传媒有限公司凭借专业的团队、优质的服务和创新的理念&#xff0c;不断引领电商新风尚&#xf…

INT202 例题

算法复杂度 O(n)&#xff1a;表示算法的渐进上界。如果一个算法的运行时间是O(n)&#xff0c;那么它的运行时间最多与输入规模n成正比。换句话说&#xff0c;当输入规模n增加时&#xff0c;算法的运行时间不会超过某个常数倍的n。比如&#xff0c;如果一个算法的时间复杂度是O(…

AndroidStudio使用高德地图API获取手机定位

一、高德地图API申请 首先去高德注册开发者账号 下面这两个选项&#xff0c;也是我们项目成功的关键 1.1怎么获取SHA1指纹密码 ①使用AS自带的签名文件 你的用户文件下面会有一个.android文件夹,进入文件夹,在这个路径下打开cmd 如果.android下面没有签名文件参考创建文章 …

【管理咨询宝藏124】通过BLM打通前端业务与财务的双轨制设计方案

本报告首发于公号“管理咨询宝藏”&#xff0c;如需阅读完整版报告内容&#xff0c;请查阅公号“管理咨询宝藏”。 【管理咨询宝藏124】通过BLM打通前端业务与财务的双轨制设计方案 【格式】PDF版本 【关键词】BLM、组织架构设计、流程优化 【核心观点】 - 运用“拉通业务财务…

【原创】springboot+mysql大学生综合素质测评管理系统设计与实现

个人主页&#xff1a;程序猿小小杨 个人简介&#xff1a;从事开发多年&#xff0c;Java、Php、Python、前端开发均有涉猎 博客内容&#xff1a;Java项目实战、项目演示、技术分享 文末有作者名片&#xff0c;希望和大家一起共同进步&#xff0c;你只管努力&#xff0c;剩下的交…

【python】成功解决“ValueError: Expected 2D array, got 1D array instead”错误的全面指南

成功解决“ValueError: Expected 2D array, got 1D array instead”错误的全面指南 一、引言 在Python的数据分析和机器学习领域&#xff0c;尤其是使用NumPy、Pandas、scikit-learn等库时&#xff0c;经常会遇到各种类型错误。其中&#xff0c;“ValueError: Expected 2D arr…

MP-SPDZ的学习与运用

目录 MP-SPDZ 的介绍主要功能典型应用场景 MP-SPDZ 的安装实验环境准备环境安装MP-SPDZ 下载和编译 MP-SPDZ 的使用测试程序第三方求和三方计算测试冒泡排序比较运算函数语法详解——Sint语法详解——Array基于AES电路实现OPRFORAM隐私集合求交实现两台虚拟机之间进行MPC简单实…