使用 TensorFlow 创建 DenseNet 121

news2024/12/29 9:20:30

一、说明

        本篇示意DenseNet如何在tensorflow上实现,DenseNet与ResNet有类似的地方,都有层与层的“短路”方式,但两者对层的短路后处理有所不同,本文遵照原始论文的技术路线,完整复原了DenseNet的全部网络。

图1:DenseNet中的各种块和层(来源:原始DenseNet论文)

      

二、DenseNet综述

        DenseNet(密集卷积网络)是一种架构,专注于使深度学习网络更深入,但同时通过在层之间使用更短的连接来提高它们的训练效率。DenseNet 是一个卷积神经网络,其中每一层都连接到网络中更深的所有其他层,即第一层连接到第 2、3、4 层等,第二层连接到第 3、4、5 层等。这样做是为了在网络各层之间实现最大的信息流。为了保持前馈特性,每一层从前面的所有层获取输入,并将自己的特征图传递给它将要到达的所有层。与 Resnets 不同,它不是通过求和来组合特征,而是通过连接它们来组合特征。因此,“ith”层具有“i”输入,并且由其所有先前卷积块的特征图组成。它自己的特征图被传递到所有下一个“I-i”层。这在网络中引入了“(I(I+1)))/2”连接,而不是像传统深度学习架构中那样只是“I”连接。因此,与传统的卷积神经网络相比,它需要的参数更少,因为不需要学习不重要的特征图。

        DenseNet由两个重要的块组成,而不是基本的卷积层和池化层。它们是密集块和过渡层。

        接下来,我们看看所有这些块和层的外观,以及如何在 python 中实现它们。

图2:DenseNet121框架(来源:DenseNet原始论文,由作者编辑)

        DenseNet从基本的卷积和池化层开始。然后有一个密集块,然后是一个过渡层,另一个密集块后跟一个过渡层,另一个密集块后跟一个过渡层,最后是一个密集块,然后是一个分类层。

        第一个卷积块有 64 个大小为 7x7 的过滤器,步幅为 2。接下来是最大池化为 3x3 且步幅为 2 的 MaxPooling 层。这两行可以在 python 中用以下代码表示。

input = Input (input_shape)
x = Conv2D(64, 7, strides = 2, padding = 'same')(input)
x = MaxPool2D(3, strides = 2, padding = 'same')(x)

2.1 定义卷积块 —

        输入后的每个卷积块具有以下顺序:批处理归一化,然后是 ReLU 激活,然后是实际的 Conv2D 层。为了实现这一点,我们可以编写以下函数。

#batch norm + relu + conv
def bn_rl_conv(x,filters,kernel=1,strides=1):
        
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters, kernel, strides=strides,padding = 'same')(x)
    return x

图3.密集块(来源:DenseNet论文-作者编辑)

2.2 定义密集块 —

        如图 3 所示,每个密集块都有两个卷积,具有 1x1 和 3x3 大小的内核。在密集块 1 中,重复 6 次,在密集块 2 中重复 12 次,在密集块 3 中重复 24 次,最后在密集块 4 中重复 16 次。

在密集块中,每个 1x1 卷积都有 4 倍的滤波器数量。所以我们使用 4*过滤器,但 3x3 过滤器只存在一次。此外,我们必须将输入与输出张量连接起来。

每个块分别运行 6、12、24、16 次重复,使用 'for 循环'。

def dense_block(x, repetition):
        
   for _ in range(repetition):
        y = bn_rl_conv(x, 4*filters)
        y = bn_rl_conv(y, filters, 3)
        x = concatenate([y,x])
   return x

图4:过渡层(来源:DenseNet论文,作者编辑)

2.3 定义过渡层 

        — 在过渡层中,我们将通道数减少到现有通道的一半。有一个 1x1 卷积层和一个 2x2 平均池化层,步幅为 2。bn_rl_conv,函数中已经设置了 1x1 的内核大小,因此我们不需要明确地再次定义它。

        在过渡层中,我们必须将通道删除到现有通道的一半。我们有输入张量x,我们想找到有多少个通道,我们需要得到其中的一半。因此,我们可以使用 Keras 后端 (K) 获取张量 x 并返回一个维度为 x 的元组。而且,我们只需要该形状的最后一个数字,即过滤器的数量。所以我们加上 [-1]。最后,我们可以将这个数量的过滤器除以 2 以获得所需的结果。

def transition_layer(x):
        
    x = bn_rl_conv(x, K.int_shape(x)[-1] //2 )
    x = AvgPool2D(2, strides = 2, padding = 'same')(x)
    return x

        因此,我们完成了定义密集块和过渡层的工作。现在我们需要将密集块和过渡层堆叠在一起。所以我们写了一个 for 循环来运行 6,12,24,16 次重复。因此,循环运行 4 次,每次使用 6、12、24 或 16 中的值之一。这样就完成了 4 个密集块和过渡层。

for repetition in [6,12,24,16]:
        
    d = dense_block(x, repetition)
    x = transition_layer(d)

        最后,是GlobalAveragePooling,然后是最终的输出层。正如我们在上面的代码块中看到的,密集块由“d”定义,而在最后一层,在密集块 4 之后,没有过渡层 4,而是直接进入分类层。因此,“d”是应用GlobalAveragePooling的连接,而不是“x”。另一种选择是从上面的代码中删除“for”循环,并将层一个接一个地堆叠,而不使用最终的过渡层。

x = GlobalAveragePooling2D()(d)
output = Dense(n_classes, activation = 'softmax')(x)

现在我们已经将所有块放在一起,让我们将它们合并以查看整个DenseNet架构。

三、完整的 DenseNet 121 架构 

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Dense
from tensorflow.keras.layers import AvgPool2D, GlobalAveragePooling2D, MaxPool2D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import ReLU, concatenate
import tensorflow.keras.backend as K
# Creating Densenet121
def densenet(input_shape, n_classes, filters = 32):
    
    #batch norm + relu + conv
    def bn_rl_conv(x,filters,kernel=1,strides=1):
        
        x = BatchNormalization()(x)
        x = ReLU()(x)
        x = Conv2D(filters, kernel, strides=strides,padding = 'same')(x)
        return x
    
    def dense_block(x, repetition):
        
        for _ in range(repetition):
            y = bn_rl_conv(x, 4*filters)
            y = bn_rl_conv(y, filters, 3)
            x = concatenate([y,x])
        return x
        
    def transition_layer(x):
        
        x = bn_rl_conv(x, K.int_shape(x)[-1] //2 )
        x = AvgPool2D(2, strides = 2, padding = 'same')(x)
        return x
    
    input = Input (input_shape)
    x = Conv2D(64, 7, strides = 2, padding = 'same')(input)
    x = MaxPool2D(3, strides = 2, padding = 'same')(x)
    
    for repetition in [6,12,24,16]:
        
        d = dense_block(x, repetition)
        x = transition_layer(d)
    x = GlobalAveragePooling2D()(d)
    output = Dense(n_classes, activation = 'softmax')(x)
    
    model = Model(input, output)
    return model
input_shape = 224, 224, 3
n_classes = 3
model = densenet(input_shape,n_classes)
model.summary()

输出:(假设 3 个最终类 — 模型摘要的最后几行)

四、 查看体系结构关系图 

        可以使用以下代码。

from tensorflow.python.keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
import pydot
import graphviz

SVG(model_to_dot(
    model, show_shapes=True, show_layer_names=True, rankdir='TB',
    expand_nested=False, dpi=60, subgraph=False
).create(prog='dot',format='svg'))

        输出 — 图表的前几个块

        这就是我们如何实现DenseNet 121架构。

五、引用 

  1. 黄高、刘壮、劳伦斯·范德马滕和基利安·温伯格,密集连接的卷积网络,arXiv 1608.06993 (2016)

    阿琼·萨卡尔

       2 密网论文链接:https://arxiv.org/pdf/1608.06993.pdf 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1067379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

评价指标篇——IOU(交并比)

什么是IoU(Intersection over Union) IoU是一种测量在特定数据集中检测相应物体准确度的一个标准。 即是产生的候选框(candidate bound)与原标记框(ground truth bound)的交叠率 即它们的交集与并集的比值。最理想情况是完全重叠…

CVE-2023-5129:libwebp开源库10分漏洞

谷歌为libwebp漏洞分配新的CVE编号,CVSS评分10分。 Libwebp是一个用于处理WebP格式图像编解码的开源库。9月6日,苹果公司安全工程和架构(SEAR)部门和加拿大多伦多大学研究人员在libwebp库中发现了一个0 day漏洞,随后&…

Linux SSH连接远程服务器(免密登录、scp和sftp传输文件)

1 SSH简介 SSH(Secure Shell,安全外壳)是一种网络安全协议,通过加密和认证机制实现安全的访问和文件传输等业务。传统远程登录和文件传输方式,例如Telnet、FTP,使用明文传输数据,存在很多的安全…

水果种植与果园监管“智慧化”,AI技术打造智慧果园视频综合解决方案

一、方案背景 我国是水果生产大国,果园种植面积大、产量高。由于果园的位置大都相对偏远、面积较大,值守的工作人员无法顾及到园区每个角落,因此人为偷盗、野生生物偷吃等事件时有发生,并且受极端天气如狂风、雷暴、骤雨等影响&a…

NOSQL Redis 数据持久化 RDB、AOF(二) 恢复

redis 执行flushall 或 flushdb 也会产生dump.rdb文件,但里面是空的。 注意:千万执行,不然rdb文件会被覆盖的。 dump.rdb 文件如何恢复数据 讲备份文件 dump.rdb 移动到redis安装目录并启动服务即可。 dump.rdb 自动触发 和手动触发 自…

Android 更新图标

什么是Android Multidex热更新 • Worktile社区 在不重启app的情况下热更新 &#xff0c;在所有新文件下载完成后&#xff0c;提示用户&#xff0c;是否重启 在不频繁新增图标的情况下可以使用 <adaptive-icon>在AndroidManifest.xml中设置app别名&#xff0c;以实现…

PCB走线的传输延时有多少

信号在PCB上的传输速度虽然很快&#xff0c;但也是存在延时的&#xff0c;一般经验值是6mil/ps。 也就是在PCB上&#xff0c;当信号线走线长度为6mil的时候&#xff0c;信号从驱动端到达接收端需要经过1ps。 信号在PCB上的传输速率&#xff1a; 其中C为信号在真空中的传播速率…

2023年【煤炭生产经营单位(安全生产管理人员)】证考试及煤炭生产经营单位(安全生产管理人员)模拟考试题库

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 煤炭生产经营单位&#xff08;安全生产管理人员&#xff09;证考试是安全生产模拟考试一点通生成的&#xff0c;煤炭生产经营单位&#xff08;安全生产管理人员&#xff09;证模拟考试题库是根据煤炭生产经营单位&…

记一次问题排查

1785年&#xff0c;卡文迪许在实验中发现&#xff0c;把不含水蒸气、二氧化碳的空气除去氧气和氮气后&#xff0c;仍有很少量的残余气体存在。这种现象在当时并没有引起化学家的重视。 一百多年后&#xff0c;英国物理学家瑞利测定氮气的密度时&#xff0c;发现从空气里分离出来…

练[BJDCTF2020]EasySearch

[BJDCTF2020]EasySearch 文章目录 [BJDCTF2020]EasySearch掌握知识解题思路关键paylaod 掌握知识 ​ 目录扫描&#xff0c;index.php.swp文件泄露&#xff0c;代码审计&#xff0c;MD5区块爆破&#xff0c;请求响应包的隐藏信息&#xff0c;.shtml文件RCE漏洞利用 解题思路 …

cpp primer plus笔记01-注意事项

cpp尽量以int main()写函数头而不是以main()或者int main(void)或者void main()写。 cpp尽量上图用第4行的注释而不是用第5行注释。 尽量不要引用命名空间比如:using namespace std; 函数体内引用的命名空间会随着函数生命周期结束而失效&#xff0c;放置在全局引用的命名空…

【LeetCode: 901. 股票价格跨度 | 单调栈】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

【代码随想录】LC 27. 移除元素

文章目录 前言一、题目1、原题链接2、题目描述 二、解题报告1、思路分析2、时间复杂度3、代码详解 三、知识风暴 前言 本专栏文章为《代码随想录》书籍的刷题题解以及读书笔记&#xff0c;如有侵权&#xff0c;立即删除。 一、题目 1、原题链接 27. 移除元素 2、题目描述 二、…

文创行业如何利用软文出圈?媒介盒子告诉你

经济快速发展与社会进步&#xff0c;带来的是人们消费观念的转型&#xff0c;人们的精神需求与文化自信不断增强&#xff0c;随着文化产业和旅游业的不断升级&#xff0c;文创产品凭借独特的概念、创新的形象&#xff0c;吸引许多消费者。那么新时代的文创行业应该如何强势出圈…

华为数通方向HCIP-DataCom H12-831题库(单选题:201-220)

第201题 DHCP Snooping是一种DHCP安全特性,这项技术可以防御以下哪些攻击? A、DHCP Server仿冒者攻击 B、针对DHCP客户端的畸形报文泛洪攻击 C、仿冒DHCP报文攻击 D、DHCP Server的拒绝服务攻击 答案:ABD 解析: 第202题 两台PE之间通过MP-BGP传播VPNv4路由,以下哪些场景…

csa从初阶到大牛(网络配置)

添加新的网络连接ens170&#xff0c;并设置静态IP地址 sudo vi /etc/sysconfig/network-scripts/ifcfg-ens170TYPEEthernet BOOTPROTOstatic DEFROUTEyes NAMEens170 DEVICEens170 ONBOOTyes IPADDR192.168.1.100 NETMASK255.255.255.02.向ens170网络连接添加2个ip地址 sudo …

2023年中国铁路通信系统发展历程、市场规模及行业发展趋势分析[图]

铁路通信技术是铁路行业发展的重要推动力&#xff0c;它不仅可以提高客运、货运、维修、安全等业务的运行效率&#xff0c;还能够有效提高铁路行业的客户服务水平&#xff0c;实现质量和效率的双重提升。我国通信技术在铁路领域的应用经历了三个阶段&#xff0c;分别是模拟通信…

VulnHub JANGOW

提示&#xff08;主机ip分配问题&#xff09; 因为直接在VulnHub上下载的盒子&#xff0c;在VMware上打开&#xff0c;默认是不分配主机的 所以我们可以在VirtualBox上打开 一、信息收集 发现开放了21和80端口&#xff0c;查看一下80端口 80端口&#xff1a; 检查页面后发现…

小白必看!上位机控制单片机原理

嗨&#xff0c;大家好&#xff01;今天&#xff0c;我们要探讨一个有趣的话题——"以上位机控制单片机"。不要担心&#xff0c;我们会用最简单的方式来解释这个概念。 首先&#xff0c;你可以把以上位机想象成一台超级聪明的电脑&#xff0c;就像你用来上网、玩游戏、…

参与现场问题解决总结(Kafka、Hbase)

一. 背景 Kafka和Hbase在现场应用广泛&#xff0c;现场问题也较多&#xff0c;本季度通过对现场问题就行跟踪和总结&#xff0c;同时结合一些调研&#xff0c;尝试提高难点问题的解决效率&#xff0c;从而提高客户和现场满意度。非难点问题&#xff08;历史遇到过问题&#xf…