深度学习训练营_第J5周_DenseNet+ SE-Net实战

news2024/11/17 3:32:55
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制

本周进行SE模块在DenseNet上的改进实验,之后将改进思路迁移到YOLOv5模型上测试
首先是学习SE模块
SE模块:Squeeze-and-Excitation Module
其中:

  1. Squeeze操作即将一个feature map的w,h使用平均池化压缩到1x1,而channel不变
  2. Excitation即激活层操作,用于将输入数据映射为对应channel的权重(使用sigmoid,输出值在0~1)
  3. 图中的Scale即将每个权重值乘上原本的feature map的对应层,得到新的feature map
    在这里插入图片描述
    代码借鉴了网上的写法,使用add_module()可以便于动态建立模型
class SELayer(nn.Module):
    def __init__(self, ch_in, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(ch_in, ch_in // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch_in // reduction, ch_in, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
        super(_DenseLayer, self).__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate,
                               kernel_size=1, stride=1, bias=False)
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.se1 = SELayer(growth_rate, reduction=16)
        self.drop_rate = drop_rate
        self.efficient = efficient

    def forward(self, *prev_features):
        concated_features = torch.cat(prev_features, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))
        new_features = self.se1(self.conv2(self.relu2(self.norm2(bottleneck_output))))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features,
                              kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():  # 遍历上面add_module生成的模型
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=24, compression=0.5, bn_size=4,
                 drop_rate=0,
                 num_classes=10, small_inputs=True, efficient=False):

        super(DenseNet, self).__init__()

        # First convolution
        if small_inputs:
            self.features = nn.Sequential(
                nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False))
        else:
            self.features = nn.Sequential(
                nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(num_init_features),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
            )

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)
            # self.features.add_module('SE_Block%d' % (i + 1),SE_Block(num_features, reduction=16))

        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

在DenseNet上测试:
数据集刚好存了之前天气识别的图片,所以直接用了这个,训练结果:
在这里插入图片描述
下面是在YOLOv5上的实验,SE模块加在哪里就不写了,最后相较于原版yolov5(mAP=0.73),mAP有0.05的提升,其中还含有概率成分,不过条件不允许,就没办法多次训练取均值了,
虽然效果一般,但也是改了这么几次模型,第一次有一点点提升。
在这里插入图片描述
将SE模块根据自己的想法修改了一下,有了1%的提升,看来注意力机制在视觉上的表现是不错的,后续会进行其他注意力模块的引入测试
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/410379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nuxt - 超详细 Element 组件库主题颜色进行 “统一全局“ 替换,将默认的蓝色主题色更换为其他自定义颜色(保姆级教程,简易且标准全局替换主题色)

前言 如果需要纯 Vue 版本,请访问:这篇文章。 网上的文章可以用乱七八糟来形容了,各种奇葩的引入、安装各种东西,本文提供简洁且符合官方标准的解决方案。 Element UI 默认主题色是蓝色,很可能与我们设计稿不一致(比如设计稿是绿色主题), 这时候问题就出现了,难不成每…

Concept HDL学习资料汇总

本篇博文目录:一.concept HDL相关概念知识1. Concept HDL2.Concept HDL与Cadence的区别3.Concept HDL与Cadence CIS之间的转换问题二.Cadence软件安装1.Cadence 16.6安装2.Cadence 17.x安装三.concept HDL视频学习资料四.concept HDL博文学习资料五.concept HDL书籍/文档资料一…

【报错】Aanaconda环境下配置pytorch时报错 (Solved)

Aanaconda环境下配置pytorch时报错: 在命令行输入conda install pytorch torchvision torchaudio cpuonly -c pytorch安装pytorch时产生报错,报错信息如下: EnvironmentLocationNotFound: Not a conda environment: C:\Users\绀句細浜篭.conda\envs\py…

vite+vue3基础

vite创建项目指令 npm create vitelatest . 项目创建选择vue。 创建完成之后安装依赖 npm i || cnpm ivite项目下 引入文件需要配置路径 //如果文件路径 .///如果文件层 较深 配置 “/***”“../../”配置vite中的 路径(js版本支持node全局变量 ts版本不支持 types…

4款【新概念APP】对比+免费下载

4款【新概念APP】对比免费下载4款【新概念APP】对比免费下载新概念英语咖(体积小、无广告、全免费、不能倍速播放)新概念英语全册(免费,但强制广告,否则不能播放音频。可以倍速)新概念英语全四册&#xff0…

SpringBoot入门简介

SpringBoot简介 1.什么是SpringBoot SpringBoot基于Spring4.0设计,不仅继承了Spring框架原有的优秀特性,而且还通过简化配置来进一步简化了Spring应用的整个搭建和开发过程。另外SpringBoot通过集成大量的框架使得依赖包的版本冲突,以及引用…

Android中的全量更新、增量更新以及热更新

在客户端开发过程中,我们可能会遇到这样一种需求:点击某个按钮弹出一个弹窗,提示我们可以更新到apk的某个版本,或者我们可以通过服务端接口进行强制更新。在这种需求中,我们是不需要通过应用商店来更新我们的apk的&…

【Hadoop/Java】基于HDFS的Java Web网络云盘

【Hadoop/Java】基于HDFS的Java Web网络云盘 本人BNUZ大学生萌新,水平不足,还请各位多多指教! 实验目的 熟悉HDFS Java API的使用;能使用HDFS Java API编写应用程序 实验要求 云盘系统通过互联网为企业和个人提供信息的储存、…

前后端部署+nginx配置

文章目录概要1、脚手架安装2、项目打包部署3、配置nginxEND概要 内容主要包括部署前端项目,nginx安装配置,以及后端项目的打包 1、脚手架安装 vue init webpack 项目运行(默认端口8080) npm run dev 如果前后端分离项目&…

使用Vue+el-form+form-validate实现管理端登录接口联调前准备工作实战

前言 这是《Vue + SpringBoot前后端分离项目实战》专栏的第7篇博客,感谢你能从成千上万篇博客中打开这一篇,和我一起学习前端开发实战知识,让我们一起开始吧。 目录 前言 一、上节回顾和本节介绍 1. 上节回顾

【可视化开发】echarts配置项——修改tooltip默认样式

修改tooltip默认样式 在可视化开发中我们通常会遇到修改tooltip样式问题,下面分享给大家代码片段和最终呈现效果 tooltip: {//鼠标悬浮框的提示文字trigger: "axis",axisPointer: {// 坐标轴指示器配置项。type: "none", // line 直线指示器 …

Vue3.0ElementPlus<input输入框自动获取焦点>

文章目录前言一、input-focus事件?二、使用步骤1.给input 设置ref 属性2.引入ref和nextTick3.在dialog打开事件中触发前言 记录一下自己最近开发vue3.0的小小问题~~ 最近在做项目时,dialog弹框事件需定位input焦点,方便用户可直接输入。原理…

【vue3】组合式API之setup()介绍与reactive()函数的使用

>😉博主:初映CY的前说(前端领域) ,📒本文核心:setup()概念、 reactive()的使用 【前言】vue3作为vue2的升级版,有着很多的新特性,其中就包括了组合式API,也就是是 Composition API。学习组合…

vue + gojs 实现拖拽 流程图

一、流程图效果 最近一段时间在研究go.js,它是一款前端开发画流程图的一个插件,也是一个难点,要说为什么是难点,首先,它是依赖画布canvas知识开发。其次,要依赖于内部API开发需求,开发项目需求的时候就要花…

js逆向爬取某音乐网站某歌手的歌曲

js逆向爬取某音乐网站某歌手的歌曲一、分析网站1、案例介绍2、寻找列表页Ajax入口(1)页面展示图。(2)寻找部分歌曲信息Ajax的token。(3)寻找歌曲链接(4)获取歌曲名称和id信息3、寻找…

vue-plugin-hiprint vue hiprint vue使用hiprint打印控件VUE HiPrint HiPrint简单使用

vue-plugin-hiprint vue hiprint vue使用hiprint打印控件VUE HiPrint HiPrint简单使用安装相关依赖安装 vue-plugin-hiprintJQuery安装 打印客户端引入依赖打印 html 内容 - 简单使用根据模版打印 - 简单使用以下内容 和上面demo 没关系 !!!&…

使用videjs+vue2+elementui自定义播放器控件

一、安装项目所需依赖 videojs依赖: npm install --save-dev video.js elementui依赖(这个图方便就不按需引入了): npm i element-ui -S 二、main.js修改 增加以下几行: import videojs from video.js import e…

成功解决:下载的谷歌浏览器,打开却是“2345浏览器”,方法亲测有效

今天打开谷歌浏览器使用,浏览器界面显示的2345浏览器,难道谷歌把2345收购了?应该不能,上网查找问题原因才发现,原来的谷歌首页是被劫持了。(如果迫切解决问题,直接拉到底找方法) 试了…

前端必备的谷歌浏览器JSON可视化插件:JSON-Handle

功能简介: 日常开发过程中,对接后台返回的数据接口时,数据格式常常是各种json格式字符串,在netWork里面查看十分不便,需要在网上找一个json格式化的工具再查看,然后再取数据字段,然后绑定到页面上,十分不便,推荐这么一款前端开发的浏览器插件工具给大家使用。 返回数…

【React-Hooks进阶】useState回调函数的参数 / useEffect发送网络请求/ useRef / useContext

前言 博主主页👉🏻蜡笔雏田学代码 专栏链接👉🏻React专栏 上篇文章初步学习了Hooks的基础知识 今天来深入学习Hooks的一些扩展知识 感兴趣的小伙伴一起来看看吧~🤞 文章目录useState -回调函数的参数使用场景语法语法规…