人工智能(pytorch)搭建模型24-SKAttention注意力机制模型的搭建与应用场景

news2024/11/24 2:11:23

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型24-SKAttention注意力机制模型的搭建与应用场景,本文将介绍关于SKAttention注意力机制模型的搭建,SKAttention机制具有灵活性和通用性,可应用于计算机视觉、视频分析、自然语言处理、医学影像分析和机器人视觉等领域。其主要优势在于能够根据任务和输入数据动态调整卷积核,从而提高模型的泛化能力和性能。在实际应用中,SKAttention机制可针对具体问题和数据集进行优化,以达到最佳效果。

一、SKAttention注意力机制的概述

SKAttention是一种用于深度学习中的注意力机制,特别是在卷积神经网络(CNN)中。它通过动态选择不同大小的卷积核来提高网络对多尺度特征的捕捉能力。SKAttention的设计灵感来源于视觉皮层神经元,这些神经元能够根据刺激自适应地调整其感受野大小。在CNN中实现这种机制可以帮助网络更好地捕捉复杂图像空间的多尺度特征,同时减少计算资源的浪费。
SKAttention的核心是“选择性核(Selective Kernel)”单元,它允许多个具有不同内核大小的分支在信息指导下使用SoftMax进行融合。这些分支中的每个卷积核都会对输入图像进行处理,产生不同尺寸的特征图。然后,通过融合操作将这些不同尺寸的特征图结合起来,生成用于选择权重的全局和综合表示。最后,根据这些权重对不同大小内核的特征图进行聚合,从而得到最终的输出特征。
SKAttention的主要优势在于它可以更有效地捕捉图像空间的多尺度特征,提高模型在处理不同尺度目标时的性能。此外,SKAttention还可以聚合深度特征,使模型更容易理解,同时也允许更好的可解释性。
SKAttention模块可以灵活地集成到各种深度学习模型中,特别是在目标检测领域,如YOLOv5和YOLOv7等模型,已经成功地集成了SKAttention来提升检测效果。在集成时,SKAttention可以作为即插即用的注意力模块添加到网络的任何合适位置。
SKAttention的实现涉及到多个卷积层、全连接层和softmax激活函数。在模型训练过程中,通过反向传播和梯度下降方法不断更新网络参数,优化模型性能。

二、SKAttention注意力机制的数学原理

SKAttention这个机制可以表示为两个主要步骤:特征融合(Feature Fusion)和注意力权重生成(Attention Weight Generation)。

  1. 特征融合
    假设我们有三种不同大小的卷积核(如3x3, 5x5, 7x7),分别应用于输入特征图 X X X 上,得到三个不同的特征图 F 1 , F 2 , F 3 F_1, F_2, F_3 F1,F2,F3。这些特征图通过一个融合层(例如,使用1x1卷积)来减少通道数并融合信息,得到融合后的特征图 F = [ f 1 , f 2 , f 3 ] F = [f_1, f_2, f_3] F=[f1,f2,f3],其中 f 1 , f 2 , f 3 f_1, f_2, f_3 f1,f2,f3 分别是 F 1 , F 2 , F 3 F_1, F_2, F_3 F1,F2,F3 经过融合层后的结果。
  2. 注意力权重生成
    在融合特征 F F F 上应用两个全连接层(或一个简单的神经网络)来生成注意力权重 α \alpha α。第一个全连接层将特征维度减少到 C r \frac{C}{r} rC,其中 C C C 是通道数, r r r 是减少比例。第二个全连接层将维度恢复到原始的通道数 C C C,并通过softmax函数生成归一化的注意力权重。
    数学上,这个过程可以表示为:
    F = Conv 1 x 1 ( F 1 , F 2 , F 3 ) F = \text{Conv}_{1x1}(F_1, F_2, F_3) F=Conv1x1(F1,F2,F3)
    α = softmax ( FC C r ( FC C ( F ) ) ) \alpha = \text{softmax}(\text{FC}_{\frac{C}{r}}(\text{FC}_{C}(F))) α=softmax(FCrC(FCC(F)))
    其中, Conv 1 x 1 \text{Conv}_{1x1} Conv1x1 表示1x1卷积, FC C r \text{FC}_{\frac{C}{r}} FCrC FC C \text{FC}_{C} FCC 分别表示减少维度和恢复维度的全连接层, softmax \text{softmax} softmax 是softmax激活函数。
    最终,通过注意力权重 α \alpha α 和原始特征 F 1 , F 2 , F 3 F_1, F_2, F_3 F1,F2,F3 的加权和来得到最终的输出特征图 O O O
    O = ∑ i = 1 3 α i ⋅ F i O = \sum_{i=1}^{3} \alpha_i \cdot F_i O=i=13αiFi
    这里, α i \alpha_i αi 是对应于 F i F_i Fi 的注意力权重。

在这里插入图片描述

三、SKAttention注意力机制应用场景

SKAttention(Selective Kernel Attention)机制是一种在深度学习领域中用于提高卷积神经网络性能的技术。它通过在网络中自动选择最合适的卷积核大小来增强模型对不同尺度特征的提取能力。SKAttention机制通常被应用于以下几个方面:
1.计算机视觉:SKAttention机制最初是为了改善图像分类任务而设计的。它可以被集成到各种卷积神经网络架构中,如ResNet、MobileNet等,以提高模型对图像特征的识别能力。在图像识别、目标检测和图像分割等任务中,SKAttention机制都有显著的效果。
2. 视频分析:在视频内容分析中,SKAttention机制可以帮助模型更好地理解视频中的时空特征,用于行为识别、运动检测等任务。
3. 自然语言处理:虽然SKAttention机制最初是为计算机视觉设计的,但其核心思想也可以被借鉴到自然语言处理领域,用于改进序列模型,比如在文本分类、机器翻译等任务中提取关键信息。
4. 医学影像分析:在医学影像领域,SKAttention机制可以帮助模型更加精确地识别和分析影像中的病变区域,用于疾病诊断、组织分割等。
5. 机器人视觉:在机器人视觉系统中,SKAttention机制可以提高机器人对周围环境的感知能力,用于路径规划、目标追踪等任务。
SKAttention机制的核心优势在于其灵活性,它能够根据不同的任务和输入数据动态地调整卷积核,从而提高模型的泛化能力和性能。在实际应用中,SKAttention机制可以根据具体问题和数据集的特点进行相应的调整和优化,以达到最佳的效果。

在这里插入图片描述

四、pytorch搭建SKAttention注意力机制

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict

class SKAttention(nn.Module):

    def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
        super().__init__()
        self.d=max(L,channel//reduction)
        self.convs=nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
                    ('bn',nn.BatchNorm2d(channel)),
                    ('relu',nn.ReLU())
                ]))
            )
        self.fc=nn.Linear(channel,self.d)
        self.fcs=nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d,channel))
        self.softmax=nn.Softmax(dim=0)



    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs=[]
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats=torch.stack(conv_outs,0)#k,bs,channel,h,w

        ### fuse
        U=sum(conv_outs) #bs,c,h,w

        ### reduction channel
        S=U.mean(-1).mean(-1) #bs,c
        Z=self.fc(S) #bs,d

        ### calculate attention weight
        weights=[]
        for fc in self.fcs:
            weight=fc(Z)
            weights.append(weight.view(bs,c,1,1)) #bs,channel
        attention_weughts=torch.stack(weights,0)#k,bs,channel,1,1
        attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1

        ### fuse
        V=(attention_weughts*feats).sum(0)
        return V


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    se = SKAttention(channel=512,reduction=8)
    output=se(input)
    print(output.shape)

五、总结

SKAttention是一种深度学习中的注意力机制,用于自动选择最合适的卷积核大小,以提升卷积神经网络对多尺度特征的提取能力。该机制通过动态调整卷积核,增强模型对不同尺度特征的识别,从而提高模型的泛化能力和性能。SKAttention机制具有灵活性和通用性,适用于计算机视觉、视频分析、自然语言处理、医学影像分析和机器人视觉等多个领域。想了解更多SKAttention注意力机制的应用,请持续关注微学AI。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1445044.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电视上如何下载软件

电视上如何下载软件,告诉大家一个简单的方法,可以用DT浏览器下载软件,然后会自动安装这个软件,如有技术问题,可以免费解答

4、解构三个重要的Pipeline(SD-Inpainting, ControlNet, AnimateDiff) [代码级手把手解析diffusers库]

上一篇我们解析了所有Pipeline的基类DiffusionPipeline。后续各种各样的pipeline都继承了DiffusionPipeline的模型加载保存等功能,然后再配合各个组件实现各种的结构即可。 事实上,一个Pipeline通常包含了如下模块(from_pretrained函数根据model_index.json文件new了一个Pipe…

disql备份还原

disql备份还原 前言 本文档根据官方文档,进行整理。 一、概述 在 disql 工具中使用 BACKUP 语句你可以备份整个数据库。通常情况下,在数据库实例配置归档后输入以下语句即可备份数据库: BACKUP DATABASE BACKUPSET db_bak_01;语句执行完…

spring boot(2.4.x 开始)和spring cloud项目中配置文件application和bootstrap加载顺序

在前面的文章基础上 https://blog.csdn.net/zlpzlpzyd/article/details/136060312 spring boot 2.4.x 版本之前通过 ConfigFileApplicationListener 加载配置 https://github.com/spring-projects/spring-boot/blob/v2.3.12.RELEASE/spring-boot-project/spring-boot/src/mai…

使用耳机壳UV树脂制作私模定制耳塞有哪些选择呢?

私模定制耳塞人士的选择可以从以下几个方面考虑: 专业经验:选择有丰富经验的私模定制耳塞人士,能够更好地理解用户需求,提供更专业的建议和服务。可以通过查看其作品和客户评价来了解其经验和口碑。材料质量:选择使用…

第五篇:MySQL常见数据类型

MySQL中的数据类型有很多,主要分为三类:数值类型、字符串类型、日期时间类型 三个表格都在此网盘中,需要者可移步自取,如果觉得有帮助希望点个赞~ MySQL常见数据类型表 数值类型 (注:decimal类型举例,如1…

【闲谈】开源软件的崛起与影响

随着信息技术的快速发展,开源软件已经成为软件开发的趋势,并产生了深远的影响。开源软件的低成本、可协作性和透明度等特点,使得越来越多的企业和个人选择使用开源软件,促进了软件行业的繁荣。然而,在使用开源软件的过…

STM32 STD/HAL库驱动W25Q64模块读写字库数据+OLED0.96显示例程

STM32 STD/HAL库驱动W25Q64 模块读写字库数据OLED0.96显示例程 🎬原创作者对W25Q64保存汉字字库演示: W25Q64保存汉字字库 🎞测试字体显示效果: 📑功能实现说明 利用W25Q64保存汉字字库,OLED显示汉字的时…

vue electron 应用在windows系统上以管理员权限打开应用

打开package.json文件,在build下的win增加配置 "requestedExecutionLevel": "requireAdministrator",

商汤科技「日日新4.0」正式发布,多维度升级大模型体系,能力比肩GPT-4!

文 | BFT机器人 近日,商汤科技正式发布「日日新SenseNova 4.0」,宣告大模型体系多维度全面升级。这款模型具备更全面的知识覆盖、更可靠的推理能力,以及更优越的长文本理解和数字推理能力。同时,它还支持跨模态交互,为…

给定具体日期 返回给定日期是星期几 calendar.weekday(year,month,day)

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 给定具体日期 返回给定日期是星期几 calendar.weekday(year,month,day) [太阳]选择题 如果2024年2月12日是星期一,请问最后一个print语句的运行结果是? import calenda…

【Spring学习】Spring Data Redis:RedisTemplate、Repository、Cache注解

1,spring-data-redis官网 1)特点 提供了对不同Redis客户端的整合(Lettuce和Jedis)提供了RedisTemplate统一API来操作Redis支持Redis的发布订阅模型支持Redis哨兵和Redis集群支持基于Lettuce的响应式编程支持基于JDK、JSON、字符…

【教3妹学编程-算法题】执行操作后的最大分割数量

2哥 : 3妹,今年过年收到压岁钱了没呢。 3妹:切,我都多大了啊,肯定没收了啊 2哥 : 俺也一样,不仅没收到,小侄子小外甥都得给,还倒贴好几千 3妹:哈哈哈哈,2叔叔&#xff0c…

JAVA学习笔记9

1.Java API 文档 1.java类的组织形式 2.字符类型(char) 1.基本介绍 ​ *字符类型可以表示单个字符,字符类型是char,char是两个字节(可以存放汉字),多个字符我们用字符串String ​ eg:char c1 ‘a’; ​ char c2…

AJAX——常用请求方法

1 请求方法 请求方法:对服务器资源,要执行的操作 2 数据提交 场景:当数据需要在服务器上保存 3 axios请求配置 url:请求的URL网址 method:请求的方法,GET可以省略(不区分大小写) …

牛客网SQL进阶114:更新记录

官网链接: 更新记录(二)_牛客题霸_牛客网现有一张试卷作答记录表exam_record,其中包含多年来的用户作答试卷记录,结构如下表。题目来自【牛客题霸】https://www.nowcoder.com/practice/0c2e81c6b62e4a0f848fa7693291d…

Excel下载接口

💗wei_shuo的个人主页 💫wei_shuo的学习社区 🌐Hello World ! Excel下载接口 需求分析 页面表格的数据下载,保存到Excel表格搜索后的数据点击下载,下载的数据需要是搜索后的数据 Controller HTTP 响应对象:…

【HTTP】localhost和127.0.0.1的区别是什么?

目录 localhost是什么呢? 从域名到程序 localhost和127.0.0.1的区别是什么? 域名的等级划分 多网站共用一个IP和端口 私有IP地址 IPv6 今天在网上逛的时候看到一个问题,没想到大家讨论的很热烈,就是标题中这个: …

python常用的深度学习框架

目录 一:介绍 二:使用 Python中有几个非常受欢迎的深度学习框架,它们提供了构建和训练神经网络所需的各种工具和库。以下是一些最常用的Python深度学习框架: 一:介绍 TensorFlow:由Google开发的TensorF…

LeetCode Python -8.字符串转整数

文章目录 题目答案运行结果 题目 请你来实现一个 myAtoi(string s) 函数,使其能将字符串转换成一个 32 位有符号整数(类似 C/C 中的 atoi 函数)。 函数 myAtoi(string s) 的算法如下: 读入字符串并丢弃无用的前导空格检查下一个…