【学习代码】PyTorch代码 练习

news2025/1/20 2:17:27

文章目录

  • Attention
    • SEAttention【显式地建模特征通道之间的相互依赖关系】【easy】
    • CBAM Attention【SEAttention通道注意力 + 空间注意力】【middle】
    • SelfAttention【重点:理解意义 and 实践!】【hard】

Attention

SEAttention【显式地建模特征通道之间的相互依赖关系】【easy】

【论文解读】SENet网络: https://zhuanlan.zhihu.com/p/80123284

在这里插入图片描述

注意点:最后是sigmoid()

代码来源:https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/SEAttention.py

import numpy as np
import torch
from torch import nn
from torch.nn import init

class SEAttention(nn.Module):

    def __init__(self, channel=512,reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid() # 这里是sigmoid
        )


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    se = SEAttention(channel=512,reduction=8)
    output=se(input)
    print(output.shape)

CBAM Attention【SEAttention通道注意力 + 空间注意力】【middle】

图画得不错

在这里插入图片描述

注意点:
1. 最后 也是接sigmoid激活 ,限制值 0-1范围
2. 是 残差连接
3. 逐元素相乘 *,自动 会用 广播机制

代码来源:https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/CBAM.py

import numpy as np
import torch
from torch import nn
from torch.nn import init



class ChannelAttention(nn.Module):
    def __init__(self,channel,reduction=16):
        super().__init__()
        self.maxpool=nn.AdaptiveMaxPool2d(1)
        self.avgpool=nn.AdaptiveAvgPool2d(1)
        self.se=nn.Sequential(
            nn.Conv2d(channel,channel//reduction,1,bias=False),
            nn.ReLU(),
            nn.Conv2d(channel//reduction,channel,1,bias=False)
        )
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x) :
        print(f'ChannelAttention ===============>')
        print(f'x.shape = {x.shape}')
        max_result=self.maxpool(x)
        avg_result=self.avgpool(x)
        max_out=self.se(max_result)
        avg_out=self.se(avg_result)
        output=self.sigmoid(max_out+avg_out)
        print(f'output.shape = {output.shape}') # output.shape = torch.Size([50, 512, 1, 1])
        return output

class SpatialAttention(nn.Module):
    def __init__(self,kernel_size=7):
        super().__init__()
        self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x) :
        print(f'SpatialAttention ===============>')
        print(f'x.shape = {x.shape}')
        max_result,_=torch.max(x,dim=1,keepdim=True)
        avg_result=torch.mean(x,dim=1,keepdim=True)
        result=torch.cat([max_result,avg_result],1)
        output=self.conv(result)
        output=self.sigmoid(output)
        print(f'output.shape = {output.shape}') # output.shape = torch.Size([50, 1, 7, 7])
        return output



class CBAMBlock(nn.Module):

    def __init__(self, channel=512,reduction=16,kernel_size=49):
        super().__init__()
        self.ca=ChannelAttention(channel=channel,reduction=reduction)
        self.sa=SpatialAttention(kernel_size=kernel_size)


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        residual=x
        out=x*self.ca(x) # 逐元素相乘 *,自动 会用 广播机制
        out=out*self.sa(out)
        return out+residual


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    kernel_size=input.shape[2]
    cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)
    output=cbam(input)
    print(output.shape)

    

SelfAttention【重点:理解意义 and 实践!】【hard】

在这里插入图片描述

注意点:

  1. 理解矩阵乘法的意义,torch.matmul 矩阵乘法
  2. 理解各个参数的含义

代码来源:https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/SelfAttention.py

import numpy as np
import torch
from torch import nn
from torch.nn import init



class ScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''

    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout=nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out


if __name__ == '__main__':
    input=torch.randn(50,49,512)
    sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
    output=sa(input,input,input)
    print(output.shape)

    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1330855.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

找到字符串中所有字母异位词--滑动窗口

个人主页:Lei宝啊 愿所有美好如期而遇 本体题目链接https://leetcode.cn/problems/VabMRr/description/ 算法原理 滑动窗口其实就是种双指针,只是这种双指针只向后移动,不会回退,具有单调性,也就是说,…

C语言的分支和循环语句

各位少年,今天和大家分享的是分支语句循环体语句,C语言是结构体的程序设计语言,这里的结构指的是(顺序结构)(选择结构)(循环结构)C语言是能够实现这三种结构的&#xff0…

了解树和学习二叉树

1.树 1.1 概念 树是一种 非线性 的数据结构,它是由 n ( n>0 )个有限结点组成一个具有层次关系的集合。 把它叫做树是因为它看 起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的 。 注意:树形结构中…

html网页编写语言

html是一门语言,所有的网页都是用html这门语言编写出来的。 HTML(HyperText Markup Language):超文本标记语言。 超文本:超越了文本的限制,比普通文本更强大。除了文字信息,还可以定义图片&…

产品原型设计软件 Axure RP 9 mac支持多人写作设计

axure rp 9 mac是一款产品原型设计软件,它可以让你在上面任意构建草图、框线图、流程图以及产品模型,还能够注释一些重要地方,axure rp汉化版可支持同时多人写作设计和版本管理控制,这款交互式原型设计工具可以帮助设计者制作出高…

【YOLOV8预测篇】使用Ultralytics YOLO进行检测、分割、姿态估计和分类实践

目录 一 安装Ultralytics 二 使用预训练的YOLOv8n检测模型 三 使用预训练的YOLOv8n-seg分割模型 四 使用预训练的YOLOv8n-pose姿态模型 五 使用预训练的YOLOv8n-cls分类模型 <

linux运维的面试题一

1.linux启动过程 1加电 2加载主板bios设置 3加载多重操作系统启动管理器grub 4加载内核系统到内存当中 5加载配置文件 6加载内核模块 7完成相应的初始化工作和启动相应的服务 8启动系统进程 9出现登录界面 10开机启动完成 2.安装过操作系统吗&#xff1f;怎么安装? 1.小批量设…

波奇学Linux:进程替换

单进程替换 excel使得能够在文件中运行系统指令 int excel (系统文件地址&#xff0c;系统指令&#xff0c;指令参数&#xff0c;NULL)&#xff1b; 成功时无返回值&#xff0c;失败时返回-1 如图进程成功运行第一个printf后再运行指令&#xff0c;但没有输出第二个printf的内容…

Jenkins 插件管理指南

目录 常用插件 插件安装 已安装插件 installed plugins 常用插件 Docker Plugin&#xff1a; 这个插件让Jenkins能够与Docker容器平台进行集成。它允许在Jenkins构建过程中创建、管理和销毁Docker容器&#xff0c;为需要Docker化的项目提供了极大的便利性。对于需要在容器中…

Leetcode—1491.去掉最低工资和最高工资后的工资平均值【简单】

2023每日刷题&#xff08;六十八&#xff09; Leetcode—1491.去掉最低工资和最高工资后的工资平均值 实现代码 class Solution { public:double average(vector<int>& salary) {double sum 0;int n salary.size();sort(salary.begin(), salary.end());for(int i…

全部没有问题(三)

C语言 吞空格 一共三种办法&#xff1a; fflush(stdin), while(getchar()!\n), scanf("%*c") p.s visual studio已经移除gets函数了&#xff0c;要换编译器/函数&#xff0c;visual C可以的 #include <stdio.h>int main() {char x[100], y[100], z[100];i…

贪吃蛇(十)贪吃蛇吃食物

上节讲到限制蛇身回头&#xff0c;本节要实现吃食物功能 实现思路 在存储上食物方面可以复用蛇的结构体。初始化食物的时候&#xff0c;我们设置食物的坐标&#xff0c;每次调用这个函数的时候&#xff0c;坐标发生一些规律的变化。另外我们需要扫描食物的函数&#xff0c;这…

数据大模型与低代码开发:赋能技术创新的黄金组合

在当今技术领域&#xff0c;数据大模型和低代码开发已经成为两个重要的趋势。数据大模型借助庞大的数据集和强大的计算能力&#xff0c;助力我们从海量数据中挖掘出有价值的洞见和预测能力。与此同时&#xff0c;低代码开发通过简化开发流程和降低编码需求&#xff0c;使得更多…

淘宝商品评论API:电商行业的战略资源与制胜之道

在电商行业&#xff0c;数据是金。而其中&#xff0c;用户评论数据更是无价之宝。它不仅仅反映了商品的质量和卖家的服务态度&#xff0c;更是消费者在决策时的关键参考。正因如此&#xff0c;获得淘宝商品评论API的重要性不言而喻。 一、数据背后的无尽宝藏 淘宝&#xff0c;…

java的XWPFDocument3.17版本学习

maven依赖 <dependency><groupId>org.apache.poi</groupId><artifactId>poi-ooxml</artifactId><version>3.17</version> </dependency> 测试类&#xff1a; import org.apache.poi.openxml4j.exceptions.InvalidFormatExcep…

人工智能_机器学习072_SVM支持向量机_人脸识别模型训练_训练时间过长解决办法_数据降维_LFW人脸数据建模与C参数选择---人工智能工作笔记0112

我们先来看一下之前的代码: import numpy as np 导入数学计算库 from sklearn. svm import SVC 导入支持向量机 线性分类器 import matplotlib.pyplot as plt 加载人脸图片以后,我们用pyplot把人脸图片数据展示一下 from sklearn.model_selection import train_test_split 人脸…

3. 结构型模式 - 组合模式

亦称&#xff1a; 对象树、Object Tree、Composite 意图 组合模式是一种结构型设计模式&#xff0c; 你可以使用它将对象组合成树状结构&#xff0c; 并且能像使用独立对象一样使用它们 问题 如果应用的核心模型能用树状结构表示&#xff0c; 在应用中使用组合模式才有价值。 …

优化模型:MATLAB整数规划

一、整数规划介绍 1.1 整数规划的定义 若规划模型的所有决策变量只能取整数时&#xff0c;称为整数规划。若在线性规划模型中&#xff0c;变量限制为整数&#xff0c;则称为整数线性规划。 1.2 整数规划的分类 整数规划模型大致可分为两类&#xff1a; &#xff08;1&…

机场数据治理系列介绍(1):机场行业数据中台中的数据管理顶层设计方案

目录 1、战略目标与价值 2、组织架构与角色分配 3、数据流程管理 4、数据质量管理 5、数据安全与隐私保护 6、技术平台选择与部署 7、数据服务与API管理 8、数据培训与文化推广 9、持续改进与优化 民航机场行业的数据中台中&#xff0c;有关数据管理顶层设计方案需要结…

C# WPF上位机开发(文件对话框和目录对话框)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 一个上位机软件在处理数据的时候&#xff0c;除了配置文件、数据文件之外&#xff0c;一般还需要使用选择对话框进行文件和目录的选取。如果不这样…