pytorch代码实现注意力机制之MHSA

news2025/1/19 17:00:05

MHSA注意力机制

MHSA是多头自注意力机制(Multi-Head Self-Altention),是自然语言处理领域中用于语言模型中的一种特殊机制。它能够让模型在预测下一个词的时候,更好地关注句子中不同位置的词,以适应不同的语言场景。MHSA的核心思想是将一个线性变换分成多个头,每个头执行自注意力操作,并将所有头的输出拼接在一起作为最终的表示。在自注意力操作中,每个头都计算出一个注意力矩阵,该矩阵在整个序列中对不同位置的词进行加权求和,以得到每个位置的表示。MHSA的应用已被证明在许多自然语言处理任务中具有很好的效果。

论文地址:MHSA注意力机制原论文

MHSA结构图

代码实现:

import torch
import torch.nn as nn

class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
        super(MHSA, self).__init__()

        self.heads = heads
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.pos = pos_emb
        if self.pos:
            self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]),
                                             requires_grad=True)
            self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]),
                                             requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)  # 1,C,h*w,h*w
        c1, c2, c3, c4 = content_content.size()
        if self.pos:
            content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute(
                0, 1, 3, 2)  # 1,4,1024,64

            content_position = torch.matmul(content_position, q)  # ([1, 4, 1024, 256])
            content_position = content_position if (
                    content_content.shape == content_position.shape) else content_position[:, :, :c3, ]
            assert (content_content.shape == content_position.shape)
            energy = content_content + content_position
        else:
            energy = content_content
        attention = self.softmax(energy)
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))  # 1,4,256,64
        out = out.view(n_batch, C, width, height)
        return out

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    mhsa = MHSA(n_dims=512)
    output = mhsa(input)
    print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/982754.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学生台灯选什么光源好?2023热门护眼台灯推荐

现在的台灯可以说是孩子必不可少的一个学习灯具了,几乎每家每户都会备着一台。不过台的好坏也有区别,相对而言,以前所用的白炽灯、日光灯、节能灯等等传统台灯已经是不适合孩子使用的了,目前而言最好的是LED灯。下面小编为大家推荐…

ROS速成2——机器人运动控制

1. 2.实现思路 创建软件包 定义发布者对象,名字叫vel_pub, 让advertise发布一个类型为geometry_msgs的Twist,话题名称是cmd_vel 声明一个 Twist类型的消息包,名字叫vel_msg,用来承载要发送的速度值 开启while循环,不停使用vel_pub对象发送…

亚马逊美国站直接插入式夜间照明灯具认证标准要求UL1786检测报告办理周期

亚马逊为什么要求电子产品UL检测报告? 美国是一个对安全要求非常严格的国家,美国本土的所有电子产品生产企业早在很多年前就要求有相关安规检测。 而随着亚马逊在全球商业的战略地位不断提高,境外的电子设备通过亚马逊不断涌入美国市场&…

【Python】【Fintech】用Python和蒙特卡洛法预测投资组合未来收益

【背景】 想利用蒙特卡洛方法和yahoo,stooq等财经网站上的数据快速预测特定portfolio的收益。 【分析】 整个程序的功能包括 读取json中的portfolio组合创建蒙特卡洛模拟预测收益的算法创建从财经网站获得特定投资组合数据,并根据2的算法获得该Index或Portfolio收益预测结…

一套成熟的实验室信息管理系统(云LIS源码)ASP.NET CORE

一套成熟的实验室信息管理系统,集前处理、检验、报告、质控、统计分析、两癌等模块为一体的网络管理系统。它的开发和应用将加快检验科管理的统一化、网络化、标准化的进程。 LIS把检验、检疫、放免、细菌微生物及科研使用的各类分析仪器,通过计算机联…

正中优配:政策预期叠加资金面压制 债市回调至“降息”前

地产方针利好和资金面边沿收紧的压制之下,债券商场出现了回调。 到9月6日收盘,10年期国债收益率上行2.4个基点报2.665%,已回到降息之前的点位。 资金面也在收敛,到6日收盘,DR001加权均匀利率报1.51%,较前…

数学建模竞赛常用代码总结-PythonMatlab

数学建模过程中有许多可复用的基础代码,在此对 python 以及 MATLAB 中常用代码进行简单总结,该总结会进行实时更新。 一、文件读取 python (pandas) 文件后缀名(扩展名)并不是必须的,其作用主要一方面是提示系统是用…

ROS地图/像素坐标描点调试【Python源码实现】

文章目录 ROS python 地图描点调试工具1. Rviz描点1.1 需求描述1.2 visualization Marker1.3 工程实践 2. 静态地图图片描点2.1 需求描述2.2 工程实践 ROS python 地图描点调试工具 1. Rviz描点 1.1 需求描述 在ROS开发中,有时会加载图片文件转为地图载入move_ba…

算法——组合程序算法解析

组合就是从m个元素的数组中求n个元素的所有组合&#xff0c;代码如下&#xff1a; #include <iostream> #include <vector> using namespace std; // 递归求解组合 void combinations(vector<int>& nums, vector<int>& combination, int star…

RK3568开发笔记(七):在宿主机ubuntu上搭建Qt交叉编译开发环境,编译一个Demo,目标板运行Demo测试

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/132733901 红胖子网络科技博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬…

【C++】—— 单例模式详解

前言&#xff1a; 本期&#xff0c;我将要讲解的是有关C中常见的设计模式之单例模式的相关知识&#xff01;&#xff01; 目录 &#xff08;一&#xff09;设计模式的六⼤原则 &#xff08;二&#xff09;设计模式的分类 &#xff08;三&#xff09;单例模式 1、定义 2、…

MySQL的故事——创建高性能的索引

创建高性能的索引 文章目录 创建高性能的索引一、索引基础二、索引的优点三、高性能的索引策略 一、索引基础 要理解MySQL中索引是如何工作的&#xff0c;最简单的方法就是去看看一本书的“索引 ”部分&#xff1a;如果在一本书中找到某个特定主题&#xff0c;一般会先看书的“…

Linux修复损坏的文件系统

如何判断文件系统是否损坏 当文件系统受损时&#xff0c;将会出现一些明显的迹象。例如&#xff0c;文件或文件夹无法访问、文件大小异常、系统启动慢或无法启动等。此外&#xff0c;系统也可能发出一些错误信息&#xff0c;如"Input/output error"、"Filesyst…

怎么观察敌人的具体情况

怎么观察敌人的具体情况&#xff1f; 【安志强趣讲《孙子兵法》第32讲】 【原文】 杖而立者&#xff0c;饥也&#xff1b;汲而先饮者&#xff0c;渴也&#xff1b;见利而不进者&#xff0c;劳也&#xff1b;鸟集者&#xff0c;虚也&#xff1b;夜呼者&#xff0c;恐也&#xff…

Nginx参数配置详细说明【全局、http块、server块、events块】【已亲测】

Nginx重点参数配置说明 本文包含Nginx参数配置说明全局块、http块、server块、events块共计30多个参数配置与解释&#xff0c;其中常见参数包含配置错误出现的错误日志&#xff0c;能让你更快的解决问题。 该文的所有参数大部分经过单独测试&#xff0c;错误都是自己收集出来的…

【opencv】多版本安装

安装opencv3.2.0以及对应的付费模块 一、安装多版本OpenCV如何切换 按照如下步骤安装的OpenCV&#xff0c;在CMakeLists.txt文件中&#xff0c;直接指定opencv的版本就可以找到相应版本的OpenCV&#xff0c;为了验证可以在CMakeLists.txt文件中使用如下指令输出版本验证&…

26.篮球练习

题目 Description 小徐酷爱打篮球&#xff0c;在小学期的前两周半都在练习篮球。 今天&#xff0c;小徐想要练习如何突破。练习场地可由如下所示的网格图表示&#xff0c;图中的位置可用坐标表示。 其中A点(0,0)为小徐的起始位置&#xff0c;B点(n,m)为小徐想要到达的位置。…

漏洞分析|Adobe ColdFusion WDDX 序列化漏洞利用

0x01 概述 在上一篇有关 Adobe ColdFusion 序列化漏洞&#xff08;CVE-2023-29300&#xff09;的文章中&#xff0c;我们对已公开的 JNDI 利用链&#xff08;CVE-2023-38204&#xff09;进行了复现。JNDI 利用链受目标出网的限制&#xff0c;在不出网的情况下无法很好地利用。…

二叉树的递归遍历和非递归遍历

目录 一.二叉树的递归遍历 1.先序遍历二叉树 2.中序遍历二叉树 3.后序遍历二叉树 二.非递归遍历(栈) 1.先序遍历 2.中序遍历 3.后序遍历 一.二叉树的递归遍历 定义二叉树 #其中TElemType可以是int或者是char,根据要求自定 typedef struct BiNode{TElemType data;stru…

核心实验11合集_hybrid接口特殊用法_ENSP

项目场景一&#xff1a; 核心实验11合集_hybrid接口特殊用法_ENSP 前期用户少&#xff0c;只有一个vlan段&#xff0c;如今需要划分不同vlan&#xff0c;使用hybrid接口实现。&#xff08;不可更改ip地址&#xff09; 实搭拓扑图&#xff1a; 具体操作&#xff1a; sw1: [sw1…