传知代码-基于多尺度动态卷积的图像分类

news2024/9/23 5:29:40

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

在计算机视觉领域,图像分类是非常重要的任务之一。近年来,深度学习的兴起极大提升了图像分类的精度和效率。本文将介绍一种基于动态卷积网络(Dynamic Convolutional Networks)、多尺度特征融合网络(Multi-scale Feature Fusion Networks)和自适应损失函数(Adaptive Loss Functions)的智能图像分类模型,采用了PyTorch框架进行实现,并通过PyQt构建了简洁的用户图像分类界面。该模型能够处理多分类任务,并且提供了良好的可扩展性和轻量化设计,使其适用于多种不同的图像分类场景。

效果可视化

在这里插入图片描述

模型原理解读

动态卷积

传统卷积网络通常使用固定的卷积核,而动态卷积则是通过引入多个可学习的卷积核动态选择不同的卷积核进行操作。这样可以在不同的输入图像上实现不同的卷积操作,从而提高模型的表达能力。通过加入Attention模块,能对输入图像的不同特征进行加权处理,进一步增强了网络对特征的自适应能力。

常规的卷积层使用单个静态卷积核,应用于所有输入样本。而动态卷积层则通过注意力机制动态加权n个卷积核的线性组合,使得卷积操作依赖于输入样本。动态卷积操作可以定义为:
在这里插入图片描述

在这里插入图片描述

其中动态卷积的线性组合可以用这个图表示:
在这里插入图片描述

在 ODConv 中,对于卷积核 W i W_i Wi:

  1. α s i \alpha_{s_i} αsi 为 k × \times × k 空间位置的每个卷积参数(每个滤波器)分配不同的注意力标量;下图a
  2. α c i \alpha_{c_i} αci 为每个卷积滤波器 W i m W_i^m Wim c in c_{\text{in}} cin 个通道分配不同的注意力标量;下图b
  3. α f i \alpha_{f_i} αfi c out c_{\text{out}} cout 个卷积滤波器分配不同的注意力标量;下图c
  4. α w i \alpha_{w_i} αwi 为整个卷积核分配一个注意力标量。下图d

在下图中,展示了将这四种类型的注意力逐步乘以 n n n 个卷积核的过程。原则上,这四种类型的注意力是相互补充的,通过按位置、通道、滤波器和卷积核的顺序逐步将它们乘以卷积核 W i W_i Wi,使卷积操作在所有空间位置、所有输入通道、所有卷积核中都不同,针对输入 x x x 捕获丰富的上下文信息,从而提供性能保证。
在这里插入图片描述

原则上来讲,这四种类型的注意力是互补的,通过渐进式对卷积沿位置、通道、滤波器以及核等维度乘以不同的注意力将使得卷积操作对于输入存在各个维度的差异性,提供更好的性能以捕获丰富上下文信息。因此,ODCOnv可以大幅提升卷积的特征提取能力;更重要的是,采用更少卷积核的ODConv可以取得更优的性能。
代码实现

class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
                                   requires_grad=True)
        self._initialize_weights()
 
        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common
 
    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')
 
    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)
 
    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output
 
    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output
 
    def forward(self, x):
        return self._forward_impl(x)

多尺度特征融合网络

多尺度特征是指从图像中提取不同尺度、不同分辨率下的特征。这些特征可以捕捉图像中的局部细节信息(如纹理、边缘等)和全局结构信息(如物体形状和轮廓)。传统的卷积神经网络(CNN)一般通过逐层下采样提取深层特征,但在这个过程中,高层的语义信息虽然丰富,却丢失了低层的细节信息。多尺度特征融合通过结合不同层次的特征,弥补了这一不足。

在这里插入图片描述

如上图所示,在本文的网络设计中,多尺度特征融合通过以下几个步骤实现:

特征提取模块:模型通过不同的卷积核(例如3x3、5x5、7x7)对输入图像进行多层次的卷积操作,提取出不同尺度的特征。

特征拼接与加权融合:在融合阶段,来自不同卷积层的特征图会进行拼接或加权求和,确保网络能够根据不同的任务需求自适应地调整特征权重。例如,在分类任务中,局部信息可能对小物体的识别更有帮助,而全局信息则适用于大物体的分类。
代码实现

class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MultiScaleFeatureFusion, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv7x7 = nn.Conv2d(in_channels, out_channels, kernel_size=7, padding=3)

    def forward(self, x):
        out1 = self.conv1x1(x)
        out2 = self.conv3x3(x)
        out3 = self.conv5x5(x)
        out4 = self.conv7x7(x)
        return out1 + out2 + out3 + out4  # 多尺度特征融合

自适应损失函数

在深度学习的图像分类任务中,损失函数的选择直接影响模型的训练效果。本文所设计的网络引入了自适应损失函数(Adaptive Loss Functions),这是提升分类性能的重要创新之一。传统的损失函数通常具有固定的形式和权重,不能根据数据分布和训练阶段的不同自动调整。而自适应损失函数通过动态调整损失权重和形式,能够更有效地优化模型,提升其对复杂问题的学习能力。
代码实现

class AdaptiveLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, balance_factor=0.999):
        super(AdaptiveLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.balance_factor = balance_factor
    
    def forward(self, logits, targets):
        # 计算交叉熵损失
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

模型整体结构

本文使用的模型整体结构如下图所示:
在这里插入图片描述

数据集简介

德国交通标志识别基准(GTSRB)包含43类交通标志,分为39,209个训练图像12,630个测试图像。图像具有不同的光线条件和丰富的背景。如下图所示:
在这里插入图片描述
在这里插入图片描述

实验结果

在经过动态卷积和多尺度特征提取以及自适应损失函数后在验证集上能够取得0.944的准确率。
在这里插入图片描述

其loss曲线和准确率曲线如下图所示:
在这里插入图片描述

并且本文与其他文章结果进行了比较:

模型准确率差异
ASSC[1]82.8%+11.6%
DAN[2]91.1%+3.3%
SRDA[3]93.6%+0.8%
OURS94.4%-

混淆矩阵结果

在这里插入图片描述

实现过程

版本

PyQt5                     5.15.11
seaborn                   0.13.2
torch                     2.4.0
PyQt5-Qt5                 5.15.2
numpy                     1.26.4
pandas                    1.5.0
  1. 首先对模型进行训练,保存最佳模型
python main.py
  1. 加载最佳模型进行可视化预测
python predict_gui.py

参考文献

[1] Haeusser, Philip, et al. “Associative domain adaptation.” Proceedings of the IEEE international conference on computer vision. 2017.
[2] Long, Mingsheng, et al. “Learning transferable features with deep adaptation networks.” International conference on machine learning. PMLR, 2015.
[3] Cai, Guanyu, et al. “Learning smooth representation for unsupervised domain adaptation.” IEEE Transactions on Neural Networks and Learning Systems 34.8 (2021): 4181-4195.
[4] Li, Chao, Aojun Zhou, and Anbang Yao. “Omni-dimensional dynamic convolution.” arXiv preprint arXiv:2209.07947 (2022).

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2156691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器人机构、制造

简单整理一下,在学习了一些运动学和动力学之类的东西,简单的整合了一些常用的机械结构和图片。 1.电机: 市面上的电机有:直流电机,交流电机,舵机,步进电机,电缸,无刷电…

【无人机设计与控制】 基于matlab的蚁群算法优化无人机uav巡检

摘要 本文使用蚁群算法(ACO)优化无人机(UAV)巡检路径。无人机巡检任务要求高效覆盖特定区域,以最小化能源消耗和时间。本研究提出的算法通过仿生蚁群算法优化巡检路径,在全局搜索和局部搜索中平衡探索与开…

【软件工程】成本效益分析

一、成本分析目的 二、成本估算方法 三、成本效益分析方法 课堂小结 例题 选择题

深度之眼(三十)——pytorch(一)--深入浅出pytorch(附安装流程)

文章目录 一、前言一、pytoch二、六个部分三、如何学习四、学习路径(重要)五、安装pytorch5.1 坑15.2 坑2 一、前言 我看了下目录 第一章和第二章都是本科学的数字图像处理。 也就是这一专栏:数字图像实验。 所以就不准备学习前两章了,直接…

一文详解大语言模型Transformer结构

目录 1. 什么是Transformer 2. Transformer结构 2.1 总体结构 2.2 Encoder层结构 2.3 Decoder层结构 2.4 动态流程图 3. Transformer为什么需要进行Multi-head Attention 4. Transformer相比于RNN/LSTM,有什么优势?为什么? 5. 为什么说Transf…

MySQL --数据类型

文章目录 1.数据类型分类2.数值类型2.1 tinyint类型2.2 bit类型2.3小数类型2.31float2.32decimal 3.字符串类型3.1 char3.2varchar3.3 char和varchar比较 4.日期和时间类型5.enum和set 1.数据类型分类 2.数值类型 2.1 tinyint类型 数值越界测试: create table tt1…

C++ Qt 之 QPushButton 好看的样式效果实践

文章目录 1.前序2.效果演示3.代码如下 1.前序 启发于 edge 更新 web 页面,觉得人家做的体验挺好 决定在Qt实现,方便以后使用 2.效果演示 特性介绍: 默认蓝色鼠标移入 渐变色,鼠标变为小手鼠标移出 恢复蓝色,鼠标恢…

计算机毕业设计之:基于uni-app的校园活动信息共享系统设计与实现(三端开发,安卓前端+网站前端+网站后端)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

006——队列

队列: 一种受限的线性表(线性逻辑结构),只允许在一段进行添加操作,在另一端只允许进行删除操作,中间位置不可操作,入队的一端被称为队尾,出队的一端被称为队头,在而我们…

作业报告┭┮﹏┭┮(Android反调试)

一:Android反调试 主要是用来防止IDA进行附加的,主要的方法思路就是,判断自身是否有父进程,判断是否端口被监听,然后通过调用so文件中的线程进行监视,这个线程开启一般JNI_OnLoad中进行开启的。但是这个是…

Java语言程序设计基础篇_编程练习题**18.31 (替换单词)

目录 题目:**18.31 (替换单词) 习题思路 代码示例 运行结果 替换前 替换后 题目:**18.31 (替换单词) 编写一个程序,递归地用一个新单词替换某个目录下的所有文件中出现的某个单词。从命令行如下传递参数: java Exercise18…

C++标准库双向链表 list 中的insert函数实现。

CPrimer中文版(第五版): //运行时错误:迭代器表示要拷贝的范围,不能指向与目的位置相同的容器 slist.insert(slist.begin(),slist.begin(),slist.end()); 如果我们传递给insert一对迭代器,它们不能…

【有啥问啥】深度剖析:大模型AI时代下的推理路径创新应用方法论

深度剖析:大模型AI时代下的推理路径创新应用方法论 随着大规模预训练模型(Large Pretrained Models, LPMs)和生成式人工智能的迅速发展,AI 在多领域的推理能力大幅提升,尤其是在自然语言处理、计算机视觉和自动决策领…

【C++11】异常处理

目录 一、异常的引入 二、C异常的关键字 三、异常的抛出与处理规则 四、异常缺陷的处理 五、自定义异常体系 六、异常规范 七、异常安全 八、异常的优缺点 1.优点 2.缺点 一、异常的引入 传统的C语言处理异常的方式有两种: 1.终止程序:使用as…

[WMCTF2020]Make PHP Great Again 2.01

又是php代码审计,开始吧. 这不用审吧,啊喂. 意思就是我们要利用require_once()函数和传入的file的value去读取flag的内容.,貌似呢require_once()已经被用过一次了,直接读取还不行,看一下下面的知识点. require_once() require…

Qt 注册表操作

一.操作环境 二.注册表查看 1. 搜索注册表打开 2. 注册表查看 例如我想操作 计算机\HKEY_CURRENT_USER\SOFTWARE\winzq\qwert下的内容 三.代码 1. H文件 #ifndef __REGISTER_H__ #define __REGISTER_H__#include <QString> #include <QSettings> #include <Q…

Web 安全(Web Security)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

信息安全工程师(11)网络信息安全科技信息获取

一、信息获取的重要性 在网络安全领域&#xff0c;及时、准确地获取科技信息对于防范和应对网络威胁至关重要。这些信息可以帮助安全团队了解最新的攻击手段、漏洞信息、防护技术等&#xff0c;从而制定有效的安全策略和应对措施。 二、信息获取的来源 网络信息安全科技信息的获…

s3c2440各部分应用

一、按位运算 按位与&&#xff1a;清零&#xff0c;清零位&0&#xff1b; 如&#xff1a;0xFFFF &&#xff08; ~&#xff08;1 << 7&#xff09;&#xff09;, 将第7位清零。 按位或 | &#xff1a;置1&#xff0c;置1位 | 1&#xff1b; 如&…

MySQL(七)——事务

文章目录 事务事务的概念事务的ACID特性事务的语法查看存储引擎查看自动提交参数和设置手动事务操作保存点 隔离级别与并发事务问题隔离级别并发事务问题 事务 事务的概念 事务&#xff08;Transaction&#xff09;是数据库管理系统中执行过程中的一个逻辑单位&#xff0c;由…