深度学习案例:DenseNet + SE-Net

news2024/12/26 23:51:09

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 回顾DenseNet算法

DenseNet(Densely Connected Convolutional Networks)是一种深度卷积神经网络架构,提出的核心思想是通过在每一层与前面所有层进行直接连接,极大地增强了信息和梯度的流动。传统的卷积神经网络(CNN)结构中,每一层的输入仅来自前一层,而DenseNet通过让每一层的输入包含所有前面层的输出,形成了更密集的连接。这样的设计能够减少梯度消失的问题,促进特征复用,提高模型的表现力和学习效率。

DenseNet的优势主要体现在两个方面。首先,由于密集连接的特点,它在同等参数量下比传统的卷积网络能够学习到更丰富的特征,提升了网络的性能。其次,由于每层都接收前面层的特征图,DenseNet有效缓解了深度神经网络中训练难度较大的问题,特别是在处理深层网络时,可以显著提高梯度的传递效率,减少了对大规模数据集的需求。通过这些优点,DenseNet在图像分类、目标检测等任务中表现出色。

通道注意力机制上文提及,不再叙述。以下是DenseNet+SE-Net代码 

'''
SE模块实现
'''
import tensorflow as tf
from keras.models import Model
from keras import layers
from keras import backend

class Squeeze_excitation_layer(tf.keras.Model):
    def __init__(self, filter_sq):
        super().__init__()
        self.filter_sq = filter_sq
        self.avepool = tf.keras.layers.GlobalAveragePooling2D()

    def build(self, input_shape):
        self.dense1 = tf.keras.layers.Dense(self.filter_sq, activation='relu')
        self.dense2 = tf.keras.layers.Dense(input_shape[-1], activation='sigmoid')

    def call(self, inputs):
        squeeze = self.avepool(inputs)
        excitation = self.dense1(squeeze)
        excitation = self.dense2(excitation)
        excitation = tf.keras.layers.Reshape((1, 1, inputs.shape[-1]))(excitation)
        scale = inputs * excitation
        return scale



def dense_block(x,blocks,name):
    for i in range(blocks):
        x = conv_block(x,32,name=name+'_block'+str(i+1))
    return x

def conv_block(x,growth_rate,name):
    bn_axis = 3
    x1 = layers.BatchNormalization(axis=bn_axis,
                                   epsilon=1.001e-5,
                                   name=name+'_0_bn')(x)
    x1 = layers.Activation('relu',name=name+'_0_relu')(x1)
    x1 = layers.Conv2D(4*growth_rate,1,use_bias=False,name=name+'_1_conv')(x1)

    x1 = layers.BatchNormalization(axis=bn_axis,
                                   epsilon=1.001e-5,
                                   name=name + '_1_bn')(x1)
    x1 = layers.Activation('relu', name=name + '_1_relu')(x1)
    x1 = layers.Conv2D(growth_rate, 3, padding='same',use_bias=False, name=name + '_2_conv')(x1)
    x = layers.Concatenate(axis=bn_axis,name=name+'_concat')([x,x1])
    return x

def transition_block(x,reduction,name):
    bn_axis = 3
    x = layers.BatchNormalization(axis=bn_axis,epsilon=1.001e-5,name=name+'_bn')(x)
    x = layers.Activation('relu',name=name+'_relu')(x)
    x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction),1,use_bias=False,name=name+'_conv')(x)
    x = layers.AveragePooling2D(2,strides=2,name=name+'_pool')(x)
    return x

def DenseNet(blocks,input_shape=None,classes=1000,**kwargs):
    img_input = layers.Input(shape=input_shape)

    bn_axis = 3

    # 224,224,3 -> 112,112,64
    x = layers.ZeroPadding2D(padding=((3,3),(3,3)))(img_input)
    x = layers.Conv2D(64,7,strides=2,use_bias=False,name='conv1/conv')(x)
    x = layers.BatchNormalization(axis=bn_axis,epsilon=1.001e-5,name='conv1/bn')(x)
    x = layers.Activation('relu',name='conv1/relu')(x)

    # 112,112,64 -> 56,56,64
    x = layers.ZeroPadding2D(padding=((1,1),(1,1)))(x)
    x = layers.MaxPooling2D(3,strides=2,name='pool1')(x)

    # 56,56,64 -> 56,56,64+32*block[0]
    # DenseNet121 56,56,64 -> 56,56,64+32*6 == 56,56,256
    x = dense_block(x,blocks[0],name='conv2')

    # 56,56,64+32*block[0] -> 28,28,32+16*block[0]
    # DenseNet121 56,56,256 -> 28,28,32+16*6 == 28,28,128
    x = transition_block(x,0.5,name='pool2')

    # 28,28,32+16*block[0] -> 28,28,32+16*block[0]+32*block[1]
    # DenseNet121 28,28,128 -> 28,28,128+32*12 == 28,28,512
    x = dense_block(x,blocks[1],name='conv3')

    # DenseNet121 28,28,512 -> 14,14,256
    x = transition_block(x,0.5,name='pool3')

    # DenseNet121 14,14,256 -> 14,14,256+32*block[2] == 14,14,1024
    x = dense_block(x,blocks[2],name='conv4')

    # DenseNet121 14,14,1024 -> 7,7,512
    x = transition_block(x,0.5,name='pool4')

    # DenseNet121 7,7,512 -> 7,7,256+32*block[3] == 7,7,1024
    x = dense_block(x,blocks[3],name='conv5')

    # 加SE注意力机制
    x = Squeeze_excitation_layer(16)(x)

    x = layers.BatchNormalization(axis=bn_axis,epsilon=1.001e-5,name='bn')(x)
    x = layers.Activation('relu',name='relu')(x)

    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(classes,activation='softmax',name='fc1000')(x)

    inputs = img_input

    if blocks == [6,12,24,16]:
        model = Model(inputs,x,name='densenet121')
    elif blocks == [6,12,32,32]:
        model = Model(inputs,x,name='densenet169')
    elif blocks == [6,12,48,32]:
        model = Model(inputs,x,name='densenet201')
    else:
        model = Model(inputs,x,name='densenet')
    return model

def DenseNet121(input_shape=[224,224,3],classes=3,**kwargs):
    return DenseNet([6,12,24,16],input_shape,classes,**kwargs)

def DenseNet169(input_shape=[224,224,3],classes=3,**kwargs):
    return DenseNet([6,12,32,32],input_shape,classes,**kwargs)

def DenseNet201(input_shape=[224,224,3],classes=3,**kwargs):
    return DenseNet([6,12,48,32],input_shape,classes,**kwargs)

from tensorflow.keras.optimizers import Adam

# 实例化模型,指定输入形状和类别数
model = DenseNet201(input_shape=[224,224,3], classes=2)
model.summary()
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

epochs = 25

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
)

# 获取实际训练轮数
actual_epochs = len(history.history['accuracy'])

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(actual_epochs)

plt.figure(figsize=(12, 4))

# 绘制准确率
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# 绘制损失
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

总结:DenseNet与SE-Net(Squeeze-and-Excitation Networks)结合后,能够进一步增强模型的表现力和效率。DenseNet通过密集连接每一层,促进了特征的复用和梯度的流动,而SE-Net通过引入通道注意力机制,能够自动学习每个特征通道的重要性,调整通道的权重。将这两者结合起来,DenseNet负责加强特征之间的关联性和信息流动,而SE-Net则提升了特征通道的自适应能力,使得网络能够在不同任务中更加精准地利用最有用的特征。这样的结合使得模型在保持高效的同时,能够更加聚焦于有价值的特征,从而提升了性能,尤其在处理复杂的视觉任务时,表现尤为出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2257578.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【java学习笔记】Set接口实现类-LinkedHashSet

一、LinkedHashSet的全面说明 (就是把数组不同位置的链表当成一个节点然后相连)

【大模型系列篇】LLaMA-Factory大模型微调实践 - 从零开始

前一次我们使用了NVIDIA TensorRT-LLM 大模型推理框架对智谱chatglm3-6b模型格式进行了转换和量化压缩,并成功部署了推理服务,有兴趣的同学可以翻阅《NVIDIA TensorRT-LLM 大模型推理框架实践》,今天我们来实践如何通过LLaMA-Factory对大模型…

【C++】LeetCode:LCR 078. 合并 K 个升序链表

题干: 给定一个链表数组,每个链表都已经按升序排列。 请将所有链表合并到一个升序链表中,返回合并后的链表。 解法:优先队列 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *ne…

数据结构和算法-04二叉树-04

广度优先的实现力扣中常见的二叉树相关问题及基本解决方案 tips: 在解决问题时,先确保问题能解决,再去考虑效率,这是解题的关键,切不可为追求效率而变成了技巧性解答。 广度优先 广度优先(层序遍历)遍历的方式是按层次…

DMA代码部分

第一个程序的接线图 OLED ShowHexNum(2,1,(uint32_t)&ADC1->DR,8); 这样可以看AD的DR寄存器的的地址(固定的)了 可以跑一下然后和手册对比 先查ADC1的地址 再在外设的总表里面, 查一下DR相对于上面地址的偏移量 所以其地址为4001 244C 研究一下外设寄存器的地址是怎么…

spdlog高性能日志系统

spdlog高性能日志系统 spdlog 是一个快速、简单、功能丰富的 C 日志库,专为现代 C 开发设计。它支持多种日志后端(如控制台、文件、syslog 等),并提供灵活的格式化和线程安全的日志输出。 1. 特点 极高的性能:大量的编…

FPGA在线升级 -- Multiboot

简介 本章节主要描述关于如何从Golden Image转换到Multiboot Image程序。 升级方案 Golden Image转换到Multiboot Image的方法主要又两种 1、使用ICAPE2 原语; 2、在XDC文件中加入升级约束命令; 以上两种方案都可以实现在线升级,第一种升级…

守护进程化

目录 一、进程组 二、会话 (1)什么是会话 (2)如何创建一个会话 三、守护进程 一、进程组 之前我们学习过进程,其实每一个进程除了有一个进程 ID(PID)之外 还属于一个进程组。进程组是一个或者多个进程的集合&…

QML插件扩展

https://note.youdao.com/ynoteshare/index.html?id294f86c78fb006f1b1b78cc430a20d74&typenote&_time1706510764806

RabbitMQ七种工作模式之 RPC通信模式, 发布确认模式

文章目录 六. RPC(RPC通信模式)客户端服务端 七. Publisher Confirms(发布确认模式)1. Publishing Messages Individually(单独确认)2. Publishing Messages in Batches(批量确认)3. Handling Publisher Confirms Asynchronously(异步确认) 六. RPC(RPC通信模式) 客⼾端发送消息…

ArcGIS字符串补零与去零

我们有时候需要 对属性表中字符串的补零与去零操作 我们下面直接视频教学 下面看视频教学 ArcGIS字符串去零与补零 推荐学习 ArcGIS全系列实战视频教程——9个单一课程组合 ArcGIS10.X入门实战视频教程(GIS思维) ArcGIS之模型构建器(Mod…

前端面试如何出彩

1、原型链和作用域链说不太清,主要表现在寄生组合继承和extends继承的区别和new做了什么。2、推荐我的两篇文章:若川:面试官问:能否模拟实现JS的new操作符、若川:面试官问:JS的继承 3、数组构造函数上有哪些…

大模型应用编排工具Dify之构建专属FQA应用

1.前言 ​ 通过 dify可以基于开源大模型的能力,并结合业务知识库、工具API和自定义代码等构建特定场景、行业的专属大模型应用。本文通过 dify工作室的聊天助手-工作流编排构建了一个基于历史工作日志回答问题的助手,相比原始的大模型答复,通…

前端node环境安装:nvm安装详细教程(安装nvm、node、npm、cnpm、yarn及环境变量配置)

需求:在做前端开发的时候,有的时候 这个项目需要 node 14 那个项目需要 node 16,我们也不能卸载 安装 。这岂不是很麻烦。这个时候 就需要 一个工具 来管理我们的 node 版本和 npm 版本。 下面就分享一个 nvm 工具 用来管理 node 版本。 这个…

c基础加堆练习题

1】思维导图: 2】在堆区空间连续申请5个int类型大小空间,用来存放从终端输入的5个学生成绩,然后显示5个学生成绩,再将学生成绩升序排序,排序后,再次显示学生成绩。显示和排序分别用函数完成 要求&#xff…

嵌入式Linux 设备树 GPIO详解 示例分析 三星 NXP RK

GPIO设备树用于在Linux内核中定义与GPIO相关的硬件资源,它使操作系统可以识别、配置和使用GPIO引脚。设备树中通常会指定GPIO控制器的基地址、GPIO引脚的中断配置、时钟和其他相关信息。 目录 RK相关案例代码 NXP相关案例代码 三星相关案例代码 在设备树中&…

【日记】不想随礼欸(926 字)

正文 今天忙了一天。感觉从早上就开始在救火。客户经理迎接检查,要补资料,找我们问这样要那样,我自己的事情几乎完全开展不了。虽说也没什么大事就是了。 晚上行长还让我重装系统…… 难绷。看来这个爹味新行长懂得还挺多。 中午趁着不多的休…

Spring 源码学习(七)——注解后处理器

通过之前对注解式配置的解析(Spring 源码学习(三)—— 注解式配置解析_spring源码学习-CSDN博客)可以发现其使用 AnnotationConfigUtils 类的 registerAnnotationConfigProcessors 静态方法对象注解后处理器对象进行注册&#xff…

如何避免缓存击穿?超融合常驻缓存和多存储池方案对比

作者:SmartX 解决方案专家 钟锦锌 很多运维人员都知道,混合存储介质配置可能会带来“缓存击穿”的问题,尤其是大数据分析、数据仓库等需要频繁访问“冷数据”的应用场景,缓存击穿可能会更频繁地出现,影响业务运行。除…

Scala的正则表达式二

验证用户名是否合法 规则 1.长度在6-12之间 2.不能数字开头 3.只能包含数字,大小写字母,下划线def main(args: Array[String]): Unit {val name1 "1admin"//不合法,是数字开头val name2 "admin123"//合法val name3 &quo…