深度学习笔记--基于C++手撕self attention机制

news2024/11/16 19:38:26

目录

1--self attention原理

2--C++代码

3--拓展

3-1--mask self attention

3-2--cross attention


1--self attention原理

        直观来讲,就是每个 token 的 Query 去和其它 token(包括自身)的 Key 进行 dot product(点积)来计算权重 weight,weight 一般需要进行 softmax 归一化操作,然后将 weight 与每个 token 的 Value 进行加权。每个 token 通过这种方式来获取和学习其它 token 的特征,并作为下一层的输入。

2--C++代码

基于 egien 库实现:

#include <iostream>
#include <Eigen/Dense>

// 定义 Self-Attention 函数
Eigen::MatrixXd selfAttention(const Eigen::MatrixXd& input) {
    int seq_length = input.rows();
    int hidden_size = input.cols();

    // 初始化权重矩阵 WQ, WK, WV
    Eigen::MatrixXd WQ = Eigen::MatrixXd::Random(hidden_size, hidden_size);
    Eigen::MatrixXd WK = Eigen::MatrixXd::Random(hidden_size, hidden_size);
    Eigen::MatrixXd WV = Eigen::MatrixXd::Random(hidden_size, hidden_size);

    // 计算 Q, K, V 矩阵
    Eigen::MatrixXd Q = input * WQ;
    Eigen::MatrixXd K = input * WK;
    Eigen::MatrixXd V = input * WV;

    // 计算注意力分数
    Eigen::MatrixXd scores = (Q * K.transpose()).eval();

    // 缩放注意力分数
    scores /= sqrt(hidden_size);

    // 计算注意力权重(替代softmax)
    Eigen::MatrixXd attention_weights = scores.unaryExpr([](double x) { return exp(x); });
    Eigen::MatrixXd weights_sum = attention_weights.rowwise().sum(); // 按行求和
    for(int i = 0; i < attention_weights.rows(); i++){ // 归一化
        attention_weights.row(i) /= weights_sum(i);
    }

    // 计算输出
    Eigen::MatrixXd output = attention_weights * V;
    return output;
}

int main(){
    int seq_length = 5;
    int hidden_size = 4;
    Eigen::MatrixXd input = Eigen::MatrixXd::Random(seq_length, hidden_size); // 随机初始化输入矩阵
    Eigen::MatrixXd output = selfAttention(input); // 计算 Self-Attention 输出

    // 打印输入输出
    std::cout << "Input:\n" << input << "\n";
    std::cout << "Output:\n" << output << "\n";
    return 0;
}

3--拓展

        拓展记录一下 transformer 中的 mask self attention 和  cross attention 机制。

3-1--mask self attention

        mask self attention 与 self attention 最大的区别在于:self attention 中每一个 token 可以看到和获取所有 token 的特征,而 mask self attention 的 token 只能看到其前面(左边)的 token 特征,并不能聚合其后面的 token。

self attention:(全局视野)

mask self attention:

        对于上图的 mask self attention,a1 只能聚合本身的特征,a2 可以聚合 a1 和本身的特征,a4 则可以聚合前面全部 token 的特征;

        具体实现,就是下图中的 a2 的 Query 只去和 a1 和本身的 Key 来计算权重 weight,并且 weight 也只和 a1 和 a2 的 Value 进行加权;

3-2--cross attention

        上图中,transformer的 decoder 部分引入了 cross attention,其输入包括三部分,其中两部分来自于 encoder,即 Key 和 Value;第三部分来自于 decoder,即 Query;Query 通过查询来聚合来自 encoder 的特征;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1200921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tektronix(泰克)示波器TBS1102B测试电压

对于 Tektronix TBS1102B 示波器来说&#xff0c;测试电压的步骤基本如下&#xff1a; 连接测量点&#xff1a; 将被测电路的测量点连接到示波器的输入通道。使用正确的探头并确保连接的极性正确。 选择通道&#xff1a; 选择示波器上的通道&#xff0c;你想要测量的电压可能连…

Python BeautifulSoup 库使用教程

文章目录 简介安装 BeautifulSoup 库BeautifulSoup 库的导入BeautifulSoup 库依赖的解析库创建 BeautifulSoup 对象CSS选择器1、通过标签名查找2、通过 CSS 的类名查找3、通过 Tag(标签) 的 id 查找4、通过 是否存在某个属性来查找5、通过 某个标签是否存在某个属性来查找 获取…

【python后端】- 初识Django框架

Django入门 &#x1f604;生命不息&#xff0c;写作不止 &#x1f525; 继续踏上学习之路&#xff0c;学之分享笔记 &#x1f44a; 总有一天我也能像各位大佬一样 &#x1f31d;分享学习心得&#xff0c;欢迎指正&#xff0c;大家一起学习成长&#xff01; 文章目录 Django入门…

Vue3:自定义图标选择器(包含 SVG 图标封装)

文章目录 一、准备工作&#xff08;在 Vue3 中使用 SVG&#xff09;二、封装 SVG三、封装图标选择器四、Demo 效果预览&#xff1a; 一、准备工作&#xff08;在 Vue3 中使用 SVG&#xff09; 本文参考&#xff1a;https://blog.csdn.net/houtengyang/article/details/1290431…

Carla之语义分割及BoundingBox验证模型

参考&#xff1a; Carla系列——4.Cara模拟器添加语义分割相机&#xff08;Semantic segmentation camera&#xff09; Carla自动驾驶仿真五&#xff1a;opencv绘制运动车辆的boudingbox&#xff08;代码详解&#xff09; Carla官网Bounding Boxes Carla官网创建自定义语义标签…

数据结构-堆和二叉树

目录 1.树的概念及结构 1.1 树的相关概念 1.2 树的概念 1.3 树的表示 1.4 树在实际中的应用&#xff08;表示文件系统的目录树结构&#xff09; 2.二叉树的概念及结构 2.1 概念 2.2 特殊的二叉树 2.3 二叉树的存储 3.堆的概念及结构 4.堆的实现 初始化堆 堆的插入…

从0开始python学习-32.pytest.mark()

目录 1. 用户自定义标记 1.1 注册标记​编辑 1.2 给测试用例打标记​编辑 1.3 运行标记的测试用例 1.4 运行多个标记的测试用例 1.5 运行指定标记以外的所有测试用例 2. 内置标签 2.1 skip &#xff1a;无条件跳过&#xff08;可使用在方法&#xff0c;类&#xff0c;模…

6可靠的局域网组建

前面聊的拓扑结构都比较简单&#xff0c;所以能用&#xff0c;但是未必可靠。为了可靠&#xff0c;我们需要做冗余&#xff0c;同时需要做一些其他的配置。 生成树协议STP 假设交换机按照上面的方案连&#xff0c;虽然可以提高网络可靠性&#xff0c;但是因为形成了环路&#…

基于粒子群算法优化概率神经网络PNN的分类预测 - 附代码

基于粒子群算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于粒子群算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于粒子群优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神经网络…

layui 表格(table)合计 取整数

第一步 开启合计行 是否开启合计行区域 table.render({elem: #myTable, url: ../baidui/, page: true, cellMinWidth: 100,totalRow:true,cols: [[ //表头//{ type: checkbox },{ type: checkbox,totalRowText: "合计" },//合计行区域{ field: id, align: center,…

c语言:解决数组中数组缺少单个的元素的问题

题目&#xff1a;数组nums包含从0到n的所以整数&#xff0c;但其中缺了一个。请编写代码找出那个缺失的整数。O(n)时间内完成。 如&#xff0c;输入&#xff1a;【3&#xff0c;0&#xff0c;1】。 输出&#xff1a; 2 三种方法 &#xff1a; 方法1&#xff1a;排序&#xf…

递归和master公式 系统栈 + 计算时间复杂度

前置知识&#xff1a;无 1&#xff09;从思想上理解递归&#xff1a;对于新手来说&#xff0c;递归去画调用图是非常重要的&#xff0c;有利于分析递归 2&#xff09;从实际上理解递归&#xff1a;递归不是玄学&#xff0c;底层是利用系统栈来实现的 3&#xff09;任何递归函…

Autosar UDS开发01(UDS诊断入门概念(UDSOnCan))

目录 回顾接触UDS的过程 UDS基本概念 UDS的作用 UDS的宏观认识 UDS的CAN通讯链路 UDS的报文种类 回顾接触UDS的过程 自21年毕业后&#xff0c;我一直干了2年的Autosar CAN通讯开发。 开发的主要内容简单概括就是&#xff1a;应用报文开发、网管报文开发、休眠唤醒开发&am…

职业迷茫,我该如何做好职业规划

案例25岁男&#xff0c;入职2月&#xff0c;感觉自己在混日子&#xff0c;怕能力没有提升&#xff0c;怕以后薪资也提不起来。完全不知道应该往哪个方向进修&#xff0c;感觉也没有自己特别喜欢的。感觉自己特别容易多想&#xff0c;想多年的以后一事无成的样子。 我觉得这个案…

腾讯云5年服务器CVM和3年轻量应用服务器配置价格表

腾讯云3年轻量和5年云服务器CVM活动入口&#xff0c;3年轻量应用服务器配置可选2核2G4M和2核4G5M带宽&#xff0c;5年CVM云服务器可以选择2核4G和4核8G配置可选&#xff0c;阿腾云atengyun.com分享腾讯云3年轻量应用服务器和5年云服务器CVM活动入口和配置报价&#xff1a; 目录…

3.如何实现 API 全局异常处理?-web组件篇

文章目录 1. 统一异常处理 1. 统一异常处理 在 Spring MVC 中&#xff0c;通过 ControllerAdvice ExceptionHandler 注解&#xff0c;声明将指定类型的异常&#xff0c;转换成对应的 CommonResult 响应。实现的代码&#xff0c;可见 GlobalExceptionHandler类。

【微软技术栈】C#.NET 如何使用本地化的异常消息创建用户定义的异常

本文内容 创建自定义异常创建本地化异常消息 在本文中&#xff0c;你将了解如何通过使用附属程序集的本地化异常消息创建从 Exception 基类继承的用户定义异常。 一、创建自定义异常 .NET 包含许多你可以使用的不同异常。 但是&#xff0c;在某些情况下&#xff0c;如果它们…

springboot苍穹外卖实战:七、店铺营业状态设置与查询+接口文档多端分组优化

店铺营业状态设置与查询 注意&#xff0c;先把测试类最上面的SpringBootTest注解注释掉&#xff0c;否则每次启动项目都会自动帮你测试一遍&#xff0c;导致项目启动变慢。 其次&#xff0c;该项目没有设置相应拦截器对付以下情况&#xff1a;用户使用过程中商家突然暂停营业&…

Django(复习篇)

项目创建 1. 虚拟环境 python -m venv my_env ​ cd my_env activate/deactivate ​ pip install django ​2. 项目和app创建 cd mypros django-admin startproject Pro1 django-admin startapp app1 ​3. settings配置INSTALLED_APPS【app1"】TEMPLATES【 DIRS: [os.pat…

低价寄快递寄件微信小程序 实际商用版 寄快递 低价寄快递小程序(源代码+截图)前后台源码

盈利模式 快递代下CPS就是用户通过线上的渠道&#xff08;快递小程序&#xff09;&#xff0c;线上下单寄快递来赚取差价&#xff0c;例如你的成本价是5元&#xff0c;你在后台比例设置里面设置 首重利润是1元&#xff0c;续重0.5元&#xff0c;用户下1kg的单页面显示的就是6元…