PyTorch 实现图像版多头注意力(Multi-Head Attention)和自注意力(Self-Attention)

news2025/4/8 19:23:43

本文提供一个适用于图像输入的多头注意力机制(Multi-Head Attention)PyTorch 实现,适用于 ViT、MAE 等视觉 Transformer 中的注意力计算。


模块说明

  • 输入支持图像格式 (B, C, H, W)
  • 内部转换为序列 (B, N, C),其中 N = H * W
  • 多头注意力计算:查询(Q)、键(K)、值(V)使用线性层投影
  • 结果 reshape 回原图维度 (B, C, H, W)

多头注意力机制代码(适用于图像输入)

import torch
import torch.nn as nn

class ImageMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(ImageMultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Q, K, V 的线性映射
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # 输出映射层
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = self.head_dim ** 0.5

    def forward(self, x):
        # 输入 x: (B, C, H, W),需要 reshape 为 (B, N, C)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)  # (B, N, C)

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # 拆成多头 (B, num_heads, N, head_dim)
        Q = Q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 注意力分数计算
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, V)

        # 合并多头
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, H * W, self.embed_dim)

        # 输出映射
        out = self.out_proj(attn_out)

        # 恢复回原图维度 (B, C, H, W)
        out = out.permute(0, 2, 1).view(B, C, H, W)
        return out

# 测试示例
# 假设输入是一张 14x14 的特征图(类似 patch embedding 后)
img = torch.randn(4, 64, 14, 14)  # (B, C, H, W)

mha = ImageMultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(img)

print(out.shape)  # 输出应为 (4, 64, 14, 14)


PyTorch 实现自注意力机制(Self-Attention)

本节补充自注意力机制(Self-Attention)的核心代码实现,适用于 ViT 等模型中 patch token 的注意力操作。

自注意力机制代码(Self-Attention)

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        # 输入 x: (B, N, C)
        B, N, C = x.shape

        # 一次性生成 Q, K, V
        qkv = self.qkv_proj(x)  # (B, N, 3C)
        Q, K, V = torch.chunk(qkv, chunks=3, dim=-1)  # 各自为 (B, N, C)

        # 计算注意力分数
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, N, N)
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # 得到注意力加权输出
        attn_out = torch.matmul(attn_probs, V)  # (B, N, C)

        # 映射回原维度
        out = self.out_proj(attn_out)  # (B, N, C)
        return out
        
#  测试示例
# 假设输入为 196 个 patch,每个 patch 的嵌入维度为 64
x = torch.randn(2, 196, 64)  # (B, N, C)

attn = SelfAttention(embed_dim=64)
out = attn(x)

print(out.shape)  # 输出应为 (2, 196, 64)

📎 拓展说明
• 本实现为单头自注意力机制
• 可用于 NLP 中的序列特征或 ViT 图像 patch 序列
• 若需改为多头注意力,只需将 embed_dim 拆成 num_heads × head_dim 并分别计算后合并


PyTorch 实现图像输入的自注意力机制(Self-Attention)

本节介绍一种适用于图像输入 (B, C, H, W) 的自注意力机制实现,适合卷积神经网络与 Transformer 的融合模块,如 Self-Attention ConvNet、BAM、CBAM、ViT 前层等。

自注意力机制(图像维度)代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(ImageSelfAttention, self).__init__()
        self.in_channels = in_channels
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv   = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # 可学习缩放因子

    def forward(self, x):
        # 输入 x: (B, C, H, W)
        B, C, H, W = x.size()

        # 生成 Q, K, V
        proj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C//8)
        proj_key   = self.key_conv(x).view(B, -1, H * W)                      # (B, C//8, N)
        proj_value = self.value_conv(x).view(B, -1, H * W)                    # (B, C, N)

        # 注意力矩阵:Q * K^T
        energy = torch.bmm(proj_query, proj_key)         # (B, N, N)
        attention = F.softmax(energy, dim=-1)             # (B, N, N)

        # 加权求和 V
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(B, C, H, W)

        # 残差连接 + 缩放因子
        out = self.gamma * out + x
        return out
        
#测试用例
x = torch.randn(2, 64, 32, 32)  # 输入一张图像:B=2, C=64, H=W=32
self_attn = ImageSelfAttention(in_channels=64)
out = self_attn(x)

print(out.shape)  # 输出形状应为 (2, 64, 32, 32)

• 本模块基于图像 (B, C, H, W) 进行自注意力计算
• 使用卷积进行 Q/K/V 提取,保持局部感知力
• gamma 是可学习缩放因子,用于残差连接控制注意力贡献度


自注意力中**缩放因子(scale factor)的处理,在序列维度(如 ViT)和图片维度(如 Self-Attention Conv)**中有点不一样。下面我们来详细解释一下原因,并对两种写法做一个统一和对比分析

两种缩放因子的区别
  1. 序列维度的缩放因子
scale = head_dim ** 0.5  # 或者 embed_dim ** 0.5
attn = (Q @ K.T) / scale

• 来源:Transformer 原始论文(Attention is All You Need)
• 原因:在高维向量内积中,为了避免 dot product 的结果数值过大导致梯度不稳定,需要除以 sqrt(d_k)
• 使用场景:多头注意力机制,输入是 (B, N, C),应用在 NLP、ViT 等序列结构

  1. 图片维度(C, H, W)的注意力机制中没有缩放,或者使用 softmax 平衡
attn = softmax(Q @ K.T)   # 无 scale,或者手动调节

• 来源:Non-local Net、Self-Attention Conv、BAM 等 CNN + Attention 融合方法
• 原因:Q 和 K 都通过 1x1 conv 压缩成 C//8 或更小的维度,内积的值本身不会太大;同时图像 attention 主要用 softmax 控制权重范围
• 缩放因子的控制通常用 γ(gamma)作为残差通道缩放,不是 QK 内部的数值缩放


💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2329705.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker快速安装MongoDB并配置主从同步

目录 一、创建相关目录及授权 二、下载并运行MongoDB容器 三、配置主从复制 四、客户端远程连接 五、验证主从同步 六、停止和恢复复制集 七、常用命令 一、创建相关目录及授权 创建主节点mongodb数据及日志目录并授权 mkdir -p /usr/local/mongodb/mongodb1/data mkdir …

Golang系列 - 内存对齐

Golang系列-内存对齐 常见类型header的size大小内存对齐空结构体类型参考 摘要: 本文将围绕内存对齐展开, 包括字符串、数组、切片等类型header的size大小、内存对齐、空结构体类型的对齐等等内容. 关键词: Golang, 内存对齐, 字符串, 数组, 切片 常见类型header的size大小 首…

网络原理 - HTTP/HTTPS

1. HTTP 1.1 HTTP是什么? HTTP (全称为 “超文本传输协议”) 是⼀种应用非常广泛的应用层协议. HTTP发展史: HTTP 诞生于1991年. 目前已经发展为最主流使用的⼀种应用层协议 最新的 HTTP 3 版本也正在完善中, 目前 Google / Facebook 等公司的产品已经…

OCC Shape 操作

#pragma once #include <iostream> #include <string> #include <filesystem> #include <TopoDS_Shape.hxx> #include <string>class GeometryIO { public:// 加载几何模型&#xff1a;支持 .brep, .step/.stp, .iges/.igsstatic TopoDS_Shape L…

深度学习入门(四):误差反向传播法

文章目录 前言链式法则什么是链式法则链式法则和计算图 反向传播加法节点的反向传播乘法节点的反向传播苹果的例子 简单层的实现乘法层的实现加法层的实现 激活函数层的实现ReLu层Sigmoid层 Affine层/SoftMax层的实现Affine层Softmax层 误差反向传播的实现参考资料 前言 上一篇…

Linux:页表详解(虚拟地址到物理地址转换过程)

文章目录 前言一、分页式存储管理1.1 虚拟地址和页表的由来1.2 物理内存管理与页表的数据结构 二、 多级页表2.1 页表项2.2 多级页表的组成 总结 前言 在我们之前的学习中&#xff0c;我们对于页表的认识仅限于虚拟地址到物理地址转换的桥梁&#xff0c;然而对于具体的转换实现…

PostgreSQL 一文从安装到入门掌握基本应用开发能力!

本篇文章主要讲解 PostgreSQL 的安装及入门的基础开发能力,包括增删改查,建库建表等操作的说明。navcat 的日常管理方法等相关知识。 日期:2025年4月6日 作者:任聪聪 一、 PostgreSQL的介绍 特点:开源、免费、高性能、关系数据库、可靠性、稳定性。 官网地址:https://w…

WEB安全--内网渗透--LMNTLM基础

一、前言 LM Hash和NTLM Hash是Windows系统中的两种加密算法&#xff0c;不过LM Hash加密算法存在缺陷&#xff0c;在Windows Vista 和 Windows Server 2008开始&#xff0c;默认情况下只存储NTLM Hash&#xff0c;LM Hash将不再存在。所以我们会着重分析NTLM Hash。 在我们内…

8.用户管理专栏主页面开发

用户管理专栏主页面开发 写在前面用户权限控制用户列表接口设计主页面开发前端account/Index.vuelangs/zh.jsstore.js 后端Paginator概述基本用法代码示例属性与方法 urls.pyviews.py 运行效果 总结 欢迎加入Gerapy二次开发教程专栏&#xff01; 本专栏专为新手开发者精心策划了…

室内指路机器人是否支持与第三方软件对接?

嘿&#xff0c;你知道吗&#xff1f;叁仟室内指路机器人可有个超厉害的技能&#xff0c;那就是能和第三方软件 “手牵手” 哦&#xff0c;接下来就带你一探究竟&#xff01; 从技术魔法角度看哈&#xff1a;好多室内指路机器人都像拥有超能力的小魔法师&#xff0c;采用开放式…

从代码上深入学习GraphRag

网上关于该算法的解析都停留在大概流程上&#xff0c;但是具体解析细节未知&#xff0c;由于代码是PipeLine形式因此阅读起来比较麻烦&#xff0c;本文希望通过阅读项目代码来解析其算法的具体实现细节&#xff0c;特别是如何利用大模型来完成图谱生成和检索增强的实现细节。 …

【Redis】通用命令

使用者通过redis-cli客户端和redis服务器交互&#xff0c;涉及到很多的redis命令&#xff0c;redis的命令非常多&#xff0c;我们需要多练习常用的命令&#xff0c;以及学会使用redis的文档。 一、get和set命令&#xff08;最核心的命令&#xff09; Redis中最核心的两个命令&…

微前端随笔

✨ single-spa&#xff1a; js-entry 通过es-module 或 umd 动态插入 js 脚本 &#xff0c;在主应用中发送请求&#xff0c;来获取子应用的包&#xff0c; 该子应用的包 singleSpa.registerApplication({name: app1,app: () > import(http://localhost:8080/app1.js),active…

C++中的浅拷贝和深拷贝

浅拷贝只是将变量的值赋予给另外一个变量&#xff0c;在遇到指针类型时&#xff0c;浅拷贝只会把当前指针的值&#xff0c;也就是该指针指向的地址赋予给另外一个指针&#xff0c;二者指向相同的地址&#xff1b; 深拷贝在遇到指针类型时&#xff0c;会先将当前指针指向地址包…

车载诊断架构 --- 整车重启先后顺序带来的思考

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 周末洗了一个澡,换了一身衣服,出了门却不知道去哪儿,不知道去找谁,漫无目的走着,大概这就是成年人最深的孤独吧! 旧人不知我近况,新人不知我过…

【C++11(下)】—— 我与C++的不解之缘(三十二)

前言 随着 C11 的引入&#xff0c;现代 C 语言在语法层面上变得更加灵活、简洁。其中最受欢迎的新特性之一就是 lambda 表达式&#xff08;Lambda Expression&#xff09;&#xff0c;它让我们可以在函数内部直接定义匿名函数。配合 std::function 包装器 使用&#xff0c;可以…

Windows 10/11系统优化工具

家庭或工作电脑使用时间久了&#xff0c;会出现各种各样问题&#xff0c;今天给大家推荐一款专为Windows 10/11系统设计的全能优化工具&#xff0c;该软件集成了超过40项专业级实用程序&#xff0c;可针对系统性能进行深度优化、精准调校、全面清理、加速响应及故障修复。通过系…

浅谈在HTTP中GET与POST的区别

从 HTTP 报文来看&#xff1a; GET请求方式将请求信息放在 URL 后面&#xff0c;请求信息和 URL 之间以 &#xff1f;隔开&#xff0c;请求信息的格式为键值对&#xff0c;这种请求方式将请求信息直接暴露在 URL 中&#xff0c;安全性比较低。另外从报文结构上来看&#xff0c…

LightRAG实战:轻松构建知识图谱,破解传统RAG多跳推理难题

作者&#xff1a;后端小肥肠 &#x1f34a; 有疑问可私信或评论区联系我。 &#x1f951; 创作不易未经允许严禁转载。 姊妹篇&#xff1a; 2025防失业预警&#xff1a;不会用DeepSeek-RAG建知识库的人正在被淘汰_deepseek-embedding-CSDN博客 从PDF到精准答案&#xff1a;Coze…

C++多线程编码二

1.lock和try_lock lock是一个函数模板&#xff0c;可以支持多个锁对象同时锁定同一个&#xff0c;如果其中一个锁对象没有锁住&#xff0c;lock函数会把已经锁定的对象解锁并进入阻塞&#xff0c;直到多个锁锁定一个对象。 try_lock也是一个函数模板&#xff0c;尝试对多个锁…