SelfAttention|自注意力机制ms简单实现

news2024/12/27 19:48:17

自注意力机制学习有感

  • 观看b站博主的讲解视频以及跟着他的pytorch代码实现mindspore的自注意力机制:
  • up主讲的很好,推荐入门自注意力机制。
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import context
context.set_context(device_target='Ascend',max_device_memory='1GB') 

class SelfAttention(nn.Cell):
    def __init__(self, dim):
        super(SelfAttention, self).__init__()
        wq_data = [[1.0, 0], [1., 1.]] # wq权重初始化 超参数
        wk_data = [[0., 1.], [1., 1.]] # wk权重初始化 超参数
        wv_data = [[0., 1., 1.], [1., 0., 0.]] # wv权重初始化 超参数
        
        self.q = nn.Dense(in_channels=dim, out_channels=2, has_bias=False)
        self.q.weight.set_data(ms.Tensor(wq_data).T)
        print("wq value:", self.q.weight.value())
        
        self.k = nn.Dense(in_channels = dim, out_channels=2, has_bias=False)
        self.k.weight.set_data(ms.Tensor(wk_data).T)
        print('wk value:', self.k.weight.value())
        
        self.v = nn.Dense(in_channels=dim, out_channels=3, has_bias=False)
        # print(self.v.weight.shape)
        self.v.weight.set_data(ms.Tensor(wv_data).T)
        print('wv value:',self.v.weight.value())
        print("*********************" * 2)
        
    def construct(self, x):
        q = self.q(x)
        print('q value:', q)
        k = self.k(x)
        print('k value:', k)
        v = self.v(x)
        # xx = x.matmul(ms.Tensor([[0., 1., 1.], [1., 0., 0.]]))
        print('v value:', v, '\n')
        print('#################################')
        x = (q @ k.T)/ms.ops.sqrt(ms.tensor(2.))
        x = ms.ops.softmax(x) @ v
        print("result:", x)
        

x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = ms.Tensor(x)
attn = SelfAttention(2)
attn(x)

结果如下:

wq value: [[1. 1.]
 [0. 1.]]
wk value: [[0. 1.]
 [1. 1.]]
wv value: [[0. 1.]
 [1. 0.]
 [1. 0.]]
******************************************
q value: [[2. 1.]
 [1. 0.]
 [3. 1.]
 [2. 2.]]
k value: [[1. 2.]
 [0. 1.]
 [1. 3.]
 [2. 2.]]
v value: [[1. 1. 1.]
 [0. 1. 1.]
 [1. 2. 2.]
 [2. 0. 0.]] 

#################################
result: [[1.5499581  0.71284014 0.71284014]
 [1.3395231  0.7726004  0.7726004 ]
 [1.7247156  0.4475609  0.4475609 ]
 [1.4366053  1.         1.        ]]

** 吐槽mindspore说明文档,对ms.nn.Dense的说明太过简单了,有对新手真不友好(对我) **

  • pytorch的文档:
    在这里插入图片描述
  • mindspore的文档:
    在这里插入图片描述
    pytorch有公式,至少提示A的转置有提示。mindspore没有,导致我这步实现的时候输出的结果不对,还是希望mindspore说明问昂也把公式写清楚点。其实mindspore的Dense和pytorch的Linear的公式实现是一样的。
    附上pytorch的实现:
#@title Default title text 
import torch
import torch_npu
import torch.nn as nn
class Self_Attention(torch.nn.Module):
    def __init__(self, dim):
        super(Self_Attention, self).__init__() #  其中qkv代表构建好训练好的wq,wk,wv的权重参数;
        self.scale = 2 ** -0.5
        self.q = torch.nn.Linear(dim, 2, bias=False) 
        q_list = [[1., 0.],[1., 1.]]
        self.q.weight.data = torch.Tensor(q_list).T
        print('q value:', self.q.weight.data)
        
        self.k = nn.Linear(dim, 2, bias=False)
        
        k_list = [[0., 1.], [1., 1.]]
        self.k.weight.data = torch.Tensor(k_list).T
        print('k value:', self.k.weight.data)
        
        self.v = nn.Linear(dim,3,bias=False)
        v_list = [[0., 1., 1.],[1., 0., 0.]]
        
        # print("origin shape:", self.v.weight.data.shape)
        
        self.v.weight.data = torch.Tensor(v_list).T
        print('init shape:',self.v.weight.data)
        
    def forward(self, x):
        q = self.q(x)  # 通过训练好的参数生成q参数
        print("q:", q)
        
        k = self.k(x)
        print("k:", k)
        
        v = self.v(x)
        print("v shape:", v.shape)
        
        # Att公式
        attn = (q.matmul(k.T)) / torch.sqrt(torch.tensor(2.0))
        print("attn1:", attn)
        
        # attn = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(2.0))
        # print("attn11:", attn)
        # attn = (q @ k.transpose(-2, -1)) * self.scale
        # print("attn2:", attn)
        attn = attn.softmax(dim=-1)
        print("softmax attn:", attn)
        # print(attn.shape) # shape[4,4]
        x = attn @ v
        print(x.shape)  #shape[4,3]
        return x 
x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = torch.Tensor(x)
att = Self_Attention(2)  
att(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1449640.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每日五道java面试题之java基础篇(九)

目录: 第一题 你们项⽬如何排查JVM问题第二题 ⼀个对象从加载到JVM,再到被GC清除,都经历了什么过程?第三题 怎么确定⼀个对象到底是不是垃圾?第四题 JVM有哪些垃圾回收算法?第五题 什么是STW? 第…

Spring Boot 笔记 017 创建接口_新增文章

1.1实体类增加校验注释 1.1.1 自定义校验 1.1.1.1 自定义注解 package com.geji.anno;import com.geji.validation.StateValidation; import jakarta.validation.Constraint; import jakarta.validation.Payload; import jakarta.validation.constraints.NotEmpty;import jav…

(03)Hive的相关概念——分区表、分桶表

目录 一、Hive分区表 1.1 分区表的概念 1.2 分区表的创建 1.3 分区表数据加载及查询 1.3.1 静态分区 1.3.2 动态分区 1.4 分区表的本质及使用 1.5 分区表的注意事项 1.6 多重分区表 二、Hive分桶表 2.1 分桶表的概念 2.2 分桶表的创建 2.3 分桶表的数据加载 2.4 …

kali无线渗透之用wps加密模式破解出wpa模式的密码12

WPS(Wi-Fi Protected Setup,Wi-Fi保护设置)是由Wi-Fi联盟推出的全新Wi-Fi安全防护设定标准。该标准推出的主要原因是为了解决长久以来无线网络加密认证设定的步骤过于繁杂之弊病,使用者往往会因为步骤太过麻烦,以致干脆不做任何加密安全设定&…

飞天使-k8s知识点17-kubernetes实操2-pod探针的使用

文章目录 探针的使用容器探针启动实验1-启动探针的使用-startupprobeLiveness Probes 和 Readiness Probes演示若存在started.html 则进行 探针的使用 kubectl edit deploy -n kube-system corednslivenessprobe 的使用 livenessProbe:failureThreshold: 5httpGet:path: /heal…

谷歌搜索技巧与 ChatGPT 实用指南:提升你的在线生产力

探索谷歌搜索技巧,提升搜索效率 前言 在搜索三巨头百度、必应、谷歌中,谷歌在搜索精确度以及多语言兼容性方面有明显的优势。其次在国内想要使用谷歌搜索你需要会科学上网(这里不说)。 一.排除干扰内容(广告&#xff…

RAG (Retrieval Augmented Generation)简介

1. 背景 目前大模型很多,绝大部分大模型都是通用型大模型,也就是说使用的是标准的数据,比如wikipedia,百度百科,。。。。 中小型企业一般都有自己的知识库,而这些知识库的数据没有在通用型的大模型中被用到…

mysql数据库 mvcc

在看MVCC之前我们先补充些基础内容,首先来看下事务的ACID和数据的总体运行流程 数据库整体的使用流程: ACID流程图 mysql核心日志: 在MySQL数据库中有三个非常重要的日志binlog,undolog,redolog. mvcc概念介绍: MVCC(Multi-Version Concurr…

【MySQL】外键约束的删除和更新总结

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-7niJLSFaPo0wso60 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

MySQL 基础知识(一)之数据库和 SQL 概述

目录 1 数据库相关概念 2 数据库的结构 ​3 SQL 概要 4 SQL 的基本书写规则 1 数据库相关概念 数据库是将大量的数据保存起来,通过计算机加工而成的可以进行高效访问的数据集合数据库管理系统(DBMS)是用来管理数据库的计算机系统&#xf…

HCIA-HarmonyOS设备开发认证V2.0-轻量系统内核基础-互斥锁mux

目录 一、互斥锁基本概念二、互斥锁运行机制三、互斥锁开发流程四、互斥锁使用说明五、互斥锁接口六、代码分析(待续...) 一、互斥锁基本概念 互斥锁又称互斥型信号量,是一种特殊的二值性信号量,用于实现对共享资源的独占式处理。…

Web 目录爆破神器:Dirsearch 保姆级教程

一、介绍 dirsearch 是一款用于目录扫描的开源工具,旨在帮助渗透测试人员和安全研究人员发现目标网站上的隐藏目录和文件。与 dirb 类似,它使用字典文件中的单词构建 URL 路径,然后发送 HTTP 请求来检查这些路径是否存在。 以下是 dirsearc…

Python算法深度探索:从基础到进阶

引言 本文将引导您从Python的基础算法出发,逐步深入到更复杂的算法领域。我们将探讨数组操作、图算法以及机器学习中的常用算法,并通过实例和代码展示它们在实际应用中的价值。 1. 基础算法:数组操作 数组操作是算法实现中非常基础且重要的一…

预算紧缩下创新创业者应采取哪3个策略来保持创新?

在今天越来越饱和的消费市场中,品牌零售通过复杂、过度的的促销、折扣、优惠券和忠诚度奖励来吸引消费者,但这种做法可能削弱消费者的忠诚度,损害品牌声誉,并抑制新的收入机会。相反,零售商应采取更简化、以客户为中心…

【Android】使用Apktool反编译Apk文件

文章目录 1. 下载Apktool1.1 Apktool官网下载1.2 百度网盘下载 2. 安装Apktool3. 使用Apktool3.1 配置Java环境3.2 准备Apk文件3.3 反编译Apk文件3.3.1 解包Apk文件3.3.2 修改Apk文件3.3.3 打包Apk文件3.3.4 签名Apk文件 1. 下载Apktool 要使用Apktool,需要准备好 …

学习笔记20:牛客周赛32

D 统计子节点中1的个数即可&#xff08;类似树形dp&#xff1f;&#xff09; #include<iostream> #include<cstring> #include<cmath> #include<algorithm> #include<queue> #include<vector> #include<set> #include<map>u…

C#利用接口实现选择不同的语种

目录 一、涉及到的知识点 1.接口定义 2.接口具有的特征 3.接口通过类继承来实现 4.有效使用接口进行组件编程 5.Encoding.GetBytes(String)方法 &#xff08;1&#xff09;检查给定字符串中是否包含中文字符 &#xff08;2&#xff09;编码和还原前后 6.Encoding.GetS…

属性/成员变量

一、属性/成员变量 二、注意事项 三、创建对象

OpenCV-30 腐蚀操作

一、引入 腐蚀操作也是用卷积核扫描图像&#xff0c;只不过腐蚀操作的卷积核一般都是1&#xff08;卷积核内的每个数字都为1&#xff09;&#xff0c;如果卷积核内所有像素点都是白色&#xff0c;那么锚点&#xff08;中心点&#xff09;即为白色。 大部分时候腐蚀操作使用的都…

石子合并+环形石子合并+能量项链+凸多边形的划分——区间DP

一、石子合并 (经典例题) 设有 N 堆石子排成一排&#xff0c;其编号为 1,2,3,…,N。 每堆石子有一定的质量&#xff0c;可以用一个整数来描述&#xff0c;现在要将这 N 堆石子合并成为一堆。 每次只能合并相邻的两堆&#xff0c;合并的代价为这两堆石子的质量之和&#xff0c;…