pytorch注意力机制

news2025/1/22 19:10:26

pytorch注意力机制

最近看了一篇大佬的注意力机制的文章然后自己花了一上午的时间把按照大佬的图把大佬提到的注意力机制都复现了一遍,大佬有一些写的复杂的网络我按照自己的理解写了几个简单的版本接下来就放出我写的代码。顺便从大佬手里盗走一些图片,等我有时间一起进行替换,在此特别鸣谢这位大佬。
链接: 大佬博客论文地址

SENet

在这里插入图片描述
SE是一类最简单的通道注意力机制,主要是使用自适应池化层将[b,c,w,h]的数据变为[b,c,1,1],然后对数据进行维度变换
使数据变为[b,c]然后通过两个全连接层使数据变为[b,c//ratio]->再变回[b,c],然后使用维度变换重新变为[b,c,1,1],然后与输入数据相乘。

import torch

class SE_block(torch.nn.Module):
    def __init__(self,in_channel,ratio):
        super(SE_block, self).__init__()
        self.avepool = torch.nn.AdaptiveAvgPool2d(1)
        self.linear1 = torch.nn.Linear(in_channel,in_channel//ratio)
        self.linear2 = torch.nn.Linear(in_channel//ratio,in_channel)
        self.sigmoid = torch.nn.Sigmoid()
        self.Relu = torch.nn.ReLU()

    def forward(self,input):
        b,c,w,h = input.shape
        x = self.avepool(input)
        x = x.view([b,c])
        x = self.linear1(x)
        x = self.Relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        x = x.view([b,c,1,1])

        return input*x

if __name__ == "__main__":
    input = torch.randn((1,512,224,224))
    model = SE_block(in_channel=512,ratio=8)
    output = model(input)
    print(output.shape)

ECAnet

在这里插入图片描述
ECANet是SENet的改进版本中间使用卷积层来代替全连接层来实现ECA的通道注意力机制

import torch
import math


class ECA_block(torch.nn.Module):
    def __init__(self,in_channel,gama=2, b=1):
        super(ECA_block, self).__init__()
        # 自适应核宽
        kernel_size = int(abs(math.log(in_channel,2)+b)/gama)
        kernel_size = kernel_size if kernel_size%2 else kernel_size + 1
        self.ave_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.sigmoid = torch.nn.Sigmoid()
        self.conv = torch.nn.Conv1d(in_channels=1,out_channels=1,kernel_size=kernel_size,padding=kernel_size//2)

    def forward(self,input):
        b,c,w,h = input.shape
        x = self.ave_pool(input)
        x = x.view([b,1,c])
        x = self.conv(x)
        x = self.sigmoid(x)
        x = x.view([b,c,1,1])
        return input*x

if __name__ == "__main__":
    input = torch.randn((1,512,224,224))
    model = ECA_block(in_channel=512)
    output = model(input)
    print(output.shape)

CMBA

在这里插入图片描述
CMBA注意力机制模块将数据依次通过通道注意力机制和空间注意力机制

import torch

class channel_attention(torch.nn.Module):
    def __init__(self,in_channel,ratio):
        super(channel_attention, self).__init__()
        self.ave_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.max_pool = torch.nn.AdaptiveMaxPool2d(1)
        self.linear1 = torch.nn.Linear(in_channel,in_channel//ratio)
        self.linear2 = torch.nn.Linear(in_channel//ratio,in_channel)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self,input):
        b,c,w,h = input.shape
        ave = self.ave_pool(input)
        max = self.max_pool(input)

        ave = ave.view([b,c])
        max = ave.view([b,c])

        ave = self.relu(self.linear1(ave))
        max = self.relu(self.linear1(max))

        ave = self.sigmoid(self.linear2(ave))
        max = self.sigmoid(self.linear2(max))

        x = self.sigmoid(ave+max).view([b,c,1,1])

        return x*input



class spatial_attention(torch.nn.Module):
    def __init__(self,kernel_size = 7):
        super(spatial_attention, self).__init__()
        self.conv = torch.nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self,input):
        b,c,w,h = input.shape
        max,_ = torch.max(input,dim=1,keepdim=True)
        ave = torch.mean(input,dim=1,keepdim=True)
        x = torch.cat([ave,max],dim=1)
        x = self.conv(x)
        x = self.sigmoid(x)

        return x*input



class CMBA(torch.nn.Module):
    def __init__(self,in_channel,ratio,kernel_size):
        super(CMBA, self).__init__()
        self.channel_attention = channel_attention(in_channel=in_channel,ratio=ratio)
        self.spatial_attention = spatial_attention(kernel_size=kernel_size)

    def forward(self,x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)

        return x

if __name__ == "__main__":
    input = torch.randn((1,512,224,224))
    # model = channel_attention(in_channel=512,ratio=8)
    # model = spatial_attention(kernel_size=7)
    model = CMBA(in_channel=512,ratio=8,kernel_size=7)
    output = model(input)
    print(output.shape)

SKnet

在这里插入图片描述
这是一个给予多个感受野的卷积核的通道注意力机制
目前这个代码是CPU的代码如想使用GUP的SKNet请联系作者

import torch


# 获得layer_num=3个卷积层
class convlayer(torch.nn.Sequential):
    def __init__(self,in_channel,layer_num=3):
        super(convlayer, self).__init__()
        for i in range(layer_num):
            layer = torch.nn.Conv2d(in_channel,in_channel,kernel_size=i*2+3,padding=i+1)
            self.add_module('convlayer%d'%(i),layer)


# 获得layer_num=3个用于反向压缩卷积的线性层
class linearlayer(torch.nn.Sequential):
    def __init__(self,in_channel,out_channel,layer_num=3):
        super(linearlayer, self).__init__()
        for i in range(layer_num):
            layer = torch.nn.Linear(in_channel,out_channel)
            self.add_module('linearlayer%d'%(i),layer)


class SK(torch.nn.Module):
    def __init__(self,in_channel,ratio,layer_num):
        super(SK, self).__init__()
        self.conv = convlayer(in_channel,layer_num)
        self.linear1 = torch.nn.Linear(in_channel,in_channel//ratio)
        self.linear2 = linearlayer(in_channel//ratio,in_channel,layer_num)
        self.softmax = torch.nn.Softmax()
        self.ave = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self,input):
        b,c,w,h = input.shape
        # 用来保存不同感受野的加和
        x = torch.zeros([b,c,w,h])
        # 存储每个感受野的输出
        x_list = []

        # 使用感受野不同的卷积层输出不同的值
        for i in self.conv:
            # 得到对应卷积层的结果
            res = i(input)
            # 保存每个卷积层的输出
            x_list.append(res)
            # 对输出求和
            x += res

        # 进行全局平均池化进行压缩
        x = self.ave(x)
        # 对数据进行维度变化方便进入线性层
        x = x.view([b,c])
        # 将维度变化之后的数据通道第一个线形层
        x = self.linear1(x)

        # 新建一个变量保存输出
        output = torch.zeros([b,c,w,h])

        for j,k in enumerate(self.linear2):
            # 使用第j个全连接层进行数据升维
            s = k(x)
            # 改变数据结构
            s = s.view([b,c,1,1])
            # 进行softmax
            s = self.softmax(s)
            # 将softmax的值与卷积分支的结果相乘然后相加
            output += s*x_list[j]

        return output


if __name__ == "__main__":
    input = torch.randn((1,512,224,224))
    model = SK(512,8,3)
    print(model(input).shape)

SCSE

在这里插入图片描述
本注意力机制是将数据分别通过空间注意力机制和通道注意力机制然后再相加的一种注意力机制

import torch



class sSE(torch.nn.Module):
    def __init__(self,in_channel):
        super(sSE, self).__init__()
        self.conv = torch.nn.Conv2d(in_channel,1,kernel_size=1,bias=False)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self,input):
        x = self.conv(input)
        x = self.sigmoid(x)
        return input*x

class cSE(torch.nn.Module):
    def __init__(self,in_channel):
        super(cSE, self).__init__()
        self.ave = torch.nn.AdaptiveAvgPool2d(1)
        self.conv1 = torch.nn.Conv2d(in_channel,in_channel//2,1,bias=False)
        self.conv2 = torch.nn.Conv2d(in_channel//2, in_channel,1,bias=False)
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()

    def forward(self,input):
        x = self.ave(input)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmoid(x)

        return x*input



class SCSE(torch.nn.Module):
    def __init__(self,in_channel):
        super(SCSE, self).__init__()
        self.cse = cSE(in_channel)
        self.sse = sSE(in_channel)
    def forward(self,x):
        out_cse = self.cse(x)
        out_sse = self.sse(x)
        return out_cse+out_sse


if __name__ == "__main__":
    input = torch.randn((1,512,224,224))
    # model = sSE(in_channel=512)
    # model = cSE(in_channel=512)
    model = SCSE(in_channel=512)
    print(model(input).shape)

NoLocalNet

在这里插入图片描述
本注意力机制是使用三个卷积核然后互相进行矩阵相乘的注意力机制最终将相乘的成功与输入相加

import torch

class NonLocalNet(torch.nn.Module):
    def __init__(self,in_channel):
        super(NonLocalNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channel,in_channel//2,1)
        self.conv2 = torch.nn.Conv2d(in_channel,in_channel//2,1)
        self.conv3 = torch.nn.Conv2d(in_channel,in_channel//2,1)
        self.conv4 = torch.nn.Conv2d(in_channel//2,in_channel,1)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self,input):
        b,c,w,h = input.shape
        c1 = self.conv1(input).view([b,c//2,w*h])
        c2 = self.conv2(input).view([b,c//2,w*h]).permute(0,2,1)
        c3 = self.conv3(input).view([b,c//2,w*h]).permute(0,2,1)
        f = torch.bmm(c2,c1)
        f = self.softmax(f)
        y = torch.bmm(f,c3).permute(0,2,1).view([b,c//2,w,h])
        y = self.conv4(y)

        return y+input


if __name__ == "__main__":
    input = torch.randn((1,24,100,100))
    model = NonLocalNet(in_channel=24)
    print(model(input).shape)

GCnet

在这里插入图片描述
本注意力机制使用了类似SENet的分支结构

import torch

class GC(torch.nn.Module):
    def __init__(self,in_channel,ratio):
        super(GC, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channel,1,kernel_size=1)
        self.conv2 = torch.nn.Conv2d(in_channel,in_channel//ratio,kernel_size=1)
        self.conv3 = torch.nn.Conv2d(in_channel//ratio,in_channel,kernel_size=1)
        self.softmax = torch.nn.Softmax(dim=1)
        self.ln = torch.nn.LayerNorm([in_channel//ratio,1,1])
        self.relu = torch.nn.ReLU()

    def forward(self,input):
        b,c,w,h = input.shape
        x = self.conv1(input).view([b,1,w*h]).permute(0,2,1)
        x = self.softmax(x)
        i = input.view([b,c,w*h])
        x = torch.bmm(i,x).view([b,c,1,1])
        x = self.conv2(x)
        x = self.ln(x)
        x = self.relu(x)
        x = self.conv3(x)

        return x+input



if __name__ == "__main__":
    input = torch.randn((1,24,100,100))
    model = GC(in_channel=24,ratio=8)
    print(model(input).shape)



一维代码请到公众号购买

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/425690.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3使用触摸滑动插件(Swiper)

Vue2使用触摸滑动插件(Swiper) 参考文档: Swiper官方 Swiper Vue Swiper Demos 本文使用的是最新版本:Swiper9.2.2 安装插件:yarn add swiper 本文基于Swiper插件进行封装,主要实现两种形式的轮播图展示…

电脑自带远程桌面和远程控制软件哪个好?

随着科技的不断发展和普及,越来越多的公司和个人开始关注远程控制软件的使用。我们常常需要在不同的地方工作,但工作需要的文件和软件却只能在一个地方使用,这时候远程控制电脑软件就变得尤为重要咯。但是,许多人分不清楚&#xf…

Windows远程连接工具有哪些

Windows远程连接工具,一般称为远程桌面软件,更准确的叫远程访问软件或远程控制软件,可以让你从一台电脑远程控制另一台电脑。远程桌面软件允许您控制连接的计算机,就好像它就在您面前一样。 远程桌面工具可用于技术支持、远程工作…

单TYPE-C口 可支持快充又可传输USB2.0数据方案

虽然现在有不少厂商也采用了Type-C接口,但是只作为一个充电接口,对于跨时代的type-c接口来说,多少有点大材小用, 那么有没有办法,让一个type-c接口既可以充电,又可以接OTG?比如不充电的时候可以…

Python一行命令搭建HTTP服务器并外网访问 - 内网穿透

文章目录1.前言2.本地http服务器搭建2.1.Python的安装和设置2.2.Python服务器设置和测试3.cpolar的安装和注册3.1 Cpolar云端设置3.2 Cpolar本地设置4.公网访问测试5.结语转载自远程内网穿透的文章:【Python】快速简单搭建HTTP服务器并公网访问「cpolar内网穿透」 1…

第二章:HTML CSS 网页开发基础(二)

CSS概述 CSS全称:Cascading Style Sheet,可以对文字进行重叠,定位。主要实现页面美化。 一、CSS规则 CSS样式表中包括了3部分:选择符、属性、属性值 选择符{属性:属性值;} 选择符:也可以称为选择器,所有…

Java实现输入行数打印取缔字符,打印金字塔三角形的两个代码程序

目录 前言 一、实现输入行数,打印取缔字符 1.1运行流程(思想) 1.2代码段 1.3运行截图 二、打印金字塔三角形 1.1运行流程(思想) 1.2代码段 1.3运行截图​​​​​​​ 前言 1.因多重原因,本博文有…

【BlazePose】《BlazePose: On-device Real-time Body Pose tracking》

arXiv-2020 文章目录1 Background and Motivation2 Advantages / Contributions3 Method4 Experiments5 Conclusion(own)1 Background and Motivation 人体关键点存在的难点:a wide variety of poses, numerous degrees of freedom, and occ…

JavaWeb—Maven

目录 1.什么是Maven 2.Maven的作用 3.Maven概述 3.1Maven介绍 3.2 Maven模型 3.3 Maven仓库 1.什么是Maven Maven是Apache旗下的一个开源项目,是一款用于管理和构建java项目的工具。 官网:Maven – Welcome to Apache Mavenhttps://maven.apache.o…

vscode 终端集成bash

windows 版本的 vs code 终端默认是没有集成bash的,虽然也能在vscode 终端可以提交git,但是没有高亮,没有提示,很不方便,这时候就需要我们将bash集成到vs code的终端,就可以愉快的使用git的分支高亮&#x…

阿里云蔡英华:云智一体,让产业全面迈向智能

4月11日,在2023阿里云峰会上,阿里云智能首席商业官蔡英华表示,算力的飞速发展使数字化成为确定,使智能化成为可能。阿里云将以云计算为基石,以AI为引擎,参与到从数字化迈向智能化的划时代变革中。 基于服务…

资深PM赞不绝口的【9种项目管理图】

好用的项目管理工具可以帮助项目经理掌握项目进度,更好的拆分任务,节约时间。 今天给大家安排上,助力大家在项目交付路上更顺畅,早日以高质量交付结果,找到百万年薪工作。 ​项目管理甘特图扫描Q群二维码下载Q群5330…

MySQL--表的使用--0409

目录 1.表的基本操作 1.1 创建表 2. 查看表结构 3.修改表 3.1 新增一列 3.2 修改列属性 3.3 修改名字 3.3.1 修改表名字 3.3.2 修改表内列名字 3.4删除 3.4.1 删除列 3.4.2 删除表 1.表的基本操作 查看自己目前在哪个数据库里面 mysql> select database(); 1.1 创…

SpringBoot整合 EasyES (八)

一直在坑自己家人,对,说的就是你,大A. 上一章简单介绍了SpringBoot整合ES 实现简单项目(七), 如果没有看过,请观看上一章 Mybatis 有增强性的 MybatisPlus, ES 有增强性的吗? 有的, easy-es ​ Easy-Es(简称EE&…

java捕获编译时异常exception和运行时错误error的方法

背景 最近使用jacob的时候,由于编译没问题,运行时报如下,我 查看代码发现是调用jacob文件时,是下面的方法报错, ComThread.Release(); 这个方法编译不报错,是因为doCoUninitialize使用native修饰的&#…

java 通过 spring 官网创建springboot项目

文章java简单一写一个springboot入门案例带大家用idea工具工具创建了一个springboot简单的小案例 但有时 我们idea如果连不上网 就会有点问题 我们可以采用另一种创建方式 但这里的前提肯定就是 你的计算机是要有网的 然后访问 https://spring.io/ 打开spring的官网 在 Project…

去了字节跳动,才知道年薪40W的测试有这么多?

今年大环境不好,内卷的厉害,薪资待遇好的工作机会更是难得。最近脉脉职言区有一条讨论火了: 哪家互联网公司薪资最‘厉害’? 下面的评论多为字节跳动,还炸出了很多年薪40W的测试工程师 我只想问一句,现在的…

数据结构进阶:前缀和与差分

数据结构进阶:前缀和与差分基础前缀和基础差分区间乘积前缀置换经典差分性质题目前缀和变种高次前缀和高维前缀和 (SOSDP)蓝桥杯已经结束,打下来很难受。自己对于算法的掌握还是不够,遂继续开启博客书写,激励自己学习。本系列文章…

FinClip 云开发实践(附小程序demo)

在开发一个小程序时,除了考虑界面功能逻辑外,还需要后端的数据支持,开发者需要提前考虑服务器、存储和数据库等相关需求的支持能力,此外还可能需要花费时间精力在部署应用、和依赖服务的建设上。 ​ 因此,腾讯小程序为…

【Java】类和对象详解

1. 类和对象 1.1 类和对象的理解 客观存在的事物皆为对象 ,所以我们也常常说万物皆对象。 类 类的理解 类是对现实生活中一类具有共同属性和行为的事物的抽象类是对象的数据类型,类是具有相同属性和行为的一组对象的集合简单理解:类就是对…