基于多尺度动态卷积的图像分类

news2025/1/24 17:50:16


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

概述

效果可视化

模型原理解读

动态卷积

多尺度特征融合网络

自适应损失函数

模型整体结构

数据集简介

实验结果

实现过程

参考文献


 本文所有资源均可在该地址处获取。

概述

在计算机视觉领域,图像分类是非常重要的任务之一。近年来,深度学习的兴起极大提升了图像分类的精度和效率。本文将介绍一种基于动态卷积网络(Dynamic Convolutional Networks)、多尺度特征融合网络(Multi-scale Feature Fusion Networks)和自适应损失函数(Adaptive Loss Functions)的智能图像分类模型,采用了PyTorch框架进行实现,并通过PyQt构建了简洁的用户图像分类界面。该模型能够处理多分类任务,并且提供了良好的可扩展性和轻量化设计,使其适用于多种不同的图像分类场景。

效果可视化

模型原理解读

动态卷积

传统卷积网络通常使用固定的卷积核,而动态卷积则是通过引入多个可学习的卷积核动态选择不同的卷积核进行操作。这样可以在不同的输入图像上实现不同的卷积操作,从而提高模型的表达能力。通过加入Attention模块,能对输入图像的不同特征进行加权处理,进一步增强了网络对特征的自适应能力。

常规的卷积层使用单个静态卷积核,应用于所有输入样本。而动态卷积层则通过注意力机制动态加权n个卷积核的线性组合,使得卷积操作依赖于输入样本。动态卷积操作可以定义为:
 


其中动态卷积的线性组合可以用这个图表示:


在 ODConv 中,对于卷积核 WiWi​:

  1. αsiαsi​​ 为 k ×× k 空间位置的每个卷积参数(每个滤波器)分配不同的注意力标量;下图a
  2. αciαci​​ 为每个卷积滤波器 WimWim​ 的 cincin​ 个通道分配不同的注意力标量;下图b
  3. αfiαfi​​ 为 coutcout​ 个卷积滤波器分配不同的注意力标量;下图c
  4. αwiαwi​​ 为整个卷积核分配一个注意力标量。下图d

在下图中,展示了将这四种类型的注意力逐步乘以 nn 个卷积核的过程。原则上,这四种类型的注意力是相互补充的,通过按位置、通道、滤波器和卷积核的顺序逐步将它们乘以卷积核 WiWi​,使卷积操作在所有空间位置、所有输入通道、所有卷积核中都不同,针对输入 xx 捕获丰富的上下文信息,从而提供性能保证。


原则上来讲,这四种类型的注意力是互补的,通过渐进式对卷积沿位置、通道、滤波器以及核等维度乘以不同的注意力将使得卷积操作对于输入存在各个维度的差异性,提供更好的性能以捕获丰富上下文信息。因此,ODCOnv可以大幅提升卷积的特征提取能力;更重要的是,采用更少卷积核的ODConv可以取得更优的性能。
代码实现

class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
                                   requires_grad=True)
        self._initialize_weights()
 
        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common
 
    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')
 
    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)
 
    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output
 
    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output
 
    def forward(self, x):
        return self._forward_impl(x)

多尺度特征融合网络

多尺度特征是指从图像中提取不同尺度、不同分辨率下的特征。这些特征可以捕捉图像中的局部细节信息(如纹理、边缘等)和全局结构信息(如物体形状和轮廓)。传统的卷积神经网络(CNN)一般通过逐层下采样提取深层特征,但在这个过程中,高层的语义信息虽然丰富,却丢失了低层的细节信息。多尺度特征融合通过结合不同层次的特征,弥补了这一不足。


如上图所示,在本文的网络设计中,多尺度特征融合通过以下几个步骤实现:

特征提取模块:模型通过不同的卷积核(例如3x3、5x5、7x7)对输入图像进行多层次的卷积操作,提取出不同尺度的特征。

特征拼接与加权融合:在融合阶段,来自不同卷积层的特征图会进行拼接或加权求和,确保网络能够根据不同的任务需求自适应地调整特征权重。例如,在分类任务中,局部信息可能对小物体的识别更有帮助,而全局信息则适用于大物体的分类。
代码实现

class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MultiScaleFeatureFusion, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv7x7 = nn.Conv2d(in_channels, out_channels, kernel_size=7, padding=3)

    def forward(self, x):
        out1 = self.conv1x1(x)
        out2 = self.conv3x3(x)
        out3 = self.conv5x5(x)
        out4 = self.conv7x7(x)
        return out1 + out2 + out3 + out4  # 多尺度特征融合

自适应损失函数

在深度学习的图像分类任务中,损失函数的选择直接影响模型的训练效果。本文所设计的网络引入了自适应损失函数(Adaptive Loss Functions),这是提升分类性能的重要创新之一。传统的损失函数通常具有固定的形式和权重,不能根据数据分布和训练阶段的不同自动调整。而自适应损失函数通过动态调整损失权重和形式,能够更有效地优化模型,提升其对复杂问题的学习能力。
代码实现

class AdaptiveLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, balance_factor=0.999):
        super(AdaptiveLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.balance_factor = balance_factor
    
    def forward(self, logits, targets):
        # 计算交叉熵损失
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

模型整体结构

本文使用的模型整体结构如下图所示:

数据集简介

德国交通标志识别基准(GTSRB)包含43类交通标志,分为39,209个训练图像12,630个测试图像。图像具有不同的光线条件和丰富的背景。如下图所示:

实验结果

在经过动态卷积和多尺度特征提取以及自适应损失函数后在验证集上能够取得0.944的准确率。


其loss曲线和准确率曲线如下图所示:


并且本文与其他文章结果进行了比较:

模型准确率差异
ASSC[1]82.8%+11.6%
DAN[2]91.1%+3.3%
SRDA[3]93.6%+0.8%
OURS94.4%-

混淆矩阵结果

实现过程

版本

PyQt5                     5.15.11
seaborn                   0.13.2
torch                     2.4.0
PyQt5-Qt5                 5.15.2
numpy                     1.26.4
pandas                    1.5.0

  1. 首先对模型进行训练,保存最佳模型
python main.py

  1. 加载最佳模型进行可视化预测
python predict_gui.py

参考文献

[1] Haeusser, Philip, et al. “Associative domain adaptation.” Proceedings of the IEEE international conference on computer vision. 2017.
[2] Long, Mingsheng, et al. “Learning transferable features with deep adaptation networks.” International conference on machine learning. PMLR, 2015.
[3] Cai, Guanyu, et al. “Learning smooth representation for unsupervised domain adaptation.” IEEE Transactions on Neural Networks and Learning Systems 34.8 (2021): 4181-4195.
[4] Li, Chao, Aojun Zhou, and Anbang Yao. “Omni-dimensional dynamic convolution.” arXiv preprint arXiv:2209.07947 (2022).

​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2263174.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Linux] 信号保存与处理

🪐🪐🪐欢迎来到程序员餐厅💫💫💫 主厨:邪王真眼 主厨的主页:Chef‘s blog 所属专栏:青果大战linux 总有光环在陨落,总有新星在闪烁 信号的保存 下面的概…

计算机网络-GRE Over IPSec实验

一、概述 前情回顾:上次基于IPsec VPN的主模式进行了基础实验,但是很多高级特性没有涉及,如ike v2、不同传输模式、DPD检测、路由方式引入路由、野蛮模式等等,以后继续学习吧。 前面我们已经学习了GRE可以基于隧道口实现分支互联&…

进网许可认证、交换路由设备检测项目更新25年1月起

实施时间 2025年1月1日起实施 涉及设备范围 核心路由器、边缘路由器、以太网交换机、三层交换机、宽带网络接入服务器(BNAS) 新增检测依据 GBT41266-2022网络关键设备安全检测方法交换机设备 GBT41267-2022网络关键设备安全技术要求交换机设备 GB/…

用C#(.NET8)开发一个NTP(SNTP)服务

完整源码,附工程下载,工程其实也就下面两个代码。 想在不能上网的服务器局域网中部署一个时间服务NTP,当然系统自带该服务,可以开启,本文只是分享一下该协议报文和能跑的源码。网上作为服务的源码不太常见,…

Connection lease request time out 问题分析

Connection lease request time out 问题分析 问题背景 使用apache的HttpClient,我们知道可以通过setConnectionRequestTimeout()配置从连接池获取链接的超时时间,而Connection lease request time out正是从连接池获取链接超时的报错,这通常…

【课程论文系列实战】:随机对照实验驱动的电商落地页优化

数据与代码见文末 摘要 随机对照试验(Randomized Controlled Trial,RCT)被认为是因果推断的“金标准”方法。通过随机分配实验参与者至不同组别,确保了组间可比性,RCT能够有效地消除样本选择偏差和混杂变量问题。本文…

UML 建模实验

文章目录 实验一 用例图一、安装并熟悉软件EnterpriseArchitect16二、用例图建模 实验二 类图、包图、对象图类图第一题第二题 包图对象图第一题第二题 实验三 顺序图、通信图顺序图银行系统学生指纹考勤系统饮料自动销售系统“买到饮料”“饮料已售完”“无法找零”完整版 通信…

高质量翻译如何影响软件用户体验 (UX)

在软件开发领域,用户体验 (UX) 是决定产品成败的关键因素之一。一个流畅、吸引人且直观的用户体验可以决定一款软件的成功与否。在影响优秀用户体验的众多因素中,高质量翻译尤为重要,尤其是在当今全球化的市场环境中。确保软件为不同语言和文…

ArcGIS Pro 3.4新功能2:Spatial Analyst新特性,密度、距离、水文、太阳能、表面、区域分析

Spatial Analyst 扩展模块在 ArcGIS Pro 3.4 中引入了新功能和增强功能。此版本为您提供了用于表面和区域分析的新工具以及改进的密度和距离分析功能,多种用于水文分析的工具性能的提高,一些新的太阳能分析功能。 目录 1.密度分析 2.距离分析 3.水文…

Linux C 程序 【05】异步写文件

1.开发背景 Linux 系统提供了各种外设的控制方式,其中包括文件的读写,存储文件的介质可以是 SSD 固态硬盘或者是 EMMC 等。 其中常用的写文件方式是同步写操作,但是如果是写大文件会对 CPU 造成比较大的负荷,采用异步写的方式比较…

凯酷全科技抖音电商服务的卓越践行者

在数字经济蓬勃发展的今天,电子商务已成为企业增长的新引擎。随着短视频平台的崛起,抖音作为全球领先的短视频社交平台,不仅改变了人们的娱乐方式,也为品牌和商家提供了全新的营销渠道。厦门凯酷全科技有限公司(以下简…

精准提升:从94.5%到99.4%——目标检测调优全纪录

🚀 目标检测模型调优过程记录 在进行目标检测模型的训练过程中,我们面对了许多挑战与迭代。从初始模型的训练结果到最终的调优优化,每一步的实验和调整都有其独特的思路和收获。本文记录了我在优化目标检测模型的过程中进行的几次尝试&#…

STM8单片机学习笔记·GPIO的片上外设寄存器

目录 前言 IC基本定义 三极管基础知识 单片机引脚电路作用 STM8GPIO工作模式 GPIO外设寄存器 寄存器含义用法 CR1:Control Register 1 CR2:Control Register 2 ODR:Output Data Register IDR:Input Data Register 赋值…

国标GB28181平台EasyGBS在安防视频监控中的信号传输(电源/视频/音频)特性及差异

在现代安防视频监控系统中,国标GB28181协议作为公共安全视频监控联网系统的国家标准,该协议不仅规范了视频监控系统的信息传输、交换和控制技术要求,还为不同厂商设备之间的互联互通提供了统一的框架。EasyGBS平台基于GB28181协议&#xff0c…

如何使用checkBox组件实现复选框

文章目录 概念介绍使用方法示例代码我们在上一章回中介绍了DatePickerDialog Widget相关的内容,本章回中将介绍Checkbox Widget.闲话休提,让我们一起Talk Flutter吧。 概念介绍 我们在这里说的Checkbox也是叫复选框,没有选中时是一个正方形边框,边框内容是空白的,选中时会…

基于“2+1 链动模式商城小程序”的微商服务营销策略探究

摘要:本文探讨在竞争激烈的市场经济与移动互联网时代背景下,微商面临的机遇与挑战。着重分析“21 链动模式商城小程序”如何助力微商改变思路,通过重视服务、提升服务质量,以服务营销放大利润,实现从传统微商模式向更具…

1-1 STM32-0.96寸OLED显示与控制

1.0 模块原理图 2.0 0.96OLED简介 资料下载:https://jiangxiekeji.com/download.html 程序介绍:https://jiangxiekeji.com/tutorial/oled.html SSD1306是一款OLED/PLED点阵显示屏的控制器,可以嵌入在屏幕中,用于执行接收数据、显…

在Visual Studio 2022中配置C++计算机视觉库Opencv

本文主要介绍下载OpenCV库以及在Visual Studio 2022中配置、编译C计算机视觉库OpenCv的方法 1.Opencv库安装 ​ 首先,我们需要安装OpenCV库,作为一个开源库,我们可以直接在其官网下载Releases - OpenCV,如果官网下载过慢&#x…

QT:QDEBUG输出重定向和命令行参数QCommandLineParser

qInstallMessageHandler函数简介 QtMessageHandler qInstallMessageHandler(QtMessageHandler handler) qInstallMessageHandler 是 Qt 框架中的一个函数,用于安装一个全局的消息处理函数,以替代默认的消息输出机制。这个函数允许开发者自定义 Qt 应用…

网站灰度发布?Tomcat的8005、8009、8080三个端口的作用什么是CDNLVS、Nginx和Haproxy的优缺点服务器无法开机时

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默, 忍不住分享一下给大家。点击跳转到网站 学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把…