深度学习-----------------------注意力分数

news2024/12/29 19:02:54

目录

  • 注意力分数
    • 注意力打分函数代码
  • 掩蔽softmax操作
  • 拓展到高纬度
    • Additive Attention(加性注意力)
      • 加性注意力代码
      • 演示一下AdditiveAttention类
      • 该部分总代码
      • 注意力权重
    • Scaled Dot-Product Attention(缩放点积注意力)
      • 缩放点积注意力代码
      • 演示一下DotProductAttention类
      • 该部分总代码
      • 注意力权重
  • 总结

注意力分数

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述




注意力打分函数代码

import math
import torch
from torch import nn
from d2l import torch as d2l



掩蔽softmax操作

import torch
from torch import nn
from d2l import torch as d2l


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上遮蔽元素来执行softmax操作"""
    if valid_lens is None:
        # 如果valid_lens为空,则对X执行softmax操作
        return nn.functional.softmax(X, dim=-1)
    else:
        # shape的形状为(2,2,4)
        shape = X.shape
        # 判断有效长度是否是一维的
        if valid_lens.dim() == 1:
            # valid_lens重复两次[2,3]→[2,2,3,3],和x的列数一样
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将valid_lens重塑为一维向量
            valid_lens = valid_lens.reshape(-1)
        # 在X的最后一个维度(即:列)上进行遮蔽操作
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        # 对遮蔽后的X执行softmax操作,并将形状还原为原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))

在这里插入图片描述

print(masked_softmax(torch.rand(2,2,4), torch.tensor([[1,3],[2,4]])))

在这里插入图片描述





拓展到高纬度

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述




Additive Attention(加性注意力)

(拓展到多维)

在这里插入图片描述


可学参数

在这里插入图片描述
等价于将key和query合并起来后放入到一个隐藏大小为h输出大小为1的单隐藏层MLP。

它的好处是:key、value、query可以是任意的长度。




加性注意力代码

需要学习三个参数:key_size, query_size, num_hiddens

class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        # 用于生成注意力分数的线性变换
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        # queries的形状:(batch_size,查询的个数,num_hiddens),key是(batch_size,键的数目,num_hiddens)
        # 两者不能直接相加
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 执行加性操作,将查询和键相加
        # queries加一维进去,变成了(batch_size,查询的个数,1,num_hiddens),key在第一维加一个维度,变成了(batch_size,1,键的数目,num_hiddens)
        # 最后features变成了(batch_size,number_querys,number_keys,num_hiddens)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # features的形状:(batch_size,number_querys,number_keys,1)
        # 使用线性变换生成注意力分数,并将最后一维的维度压缩掉
        scores = self.w_v(features).squeeze(-1)
        # 使用遮蔽softmax计算注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 根据注意力权重对values进行加权求和
        return torch.bmm(self.dropout(self.attention_weights), values)



演示一下AdditiveAttention类

# queries是一个批量大小为2,1个query,query长度为20
# keys是一个批量大小为2,10个key,key的长度为2
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# repeat(2, 1, 1)沿着第一个维度重复两次(共两个)
# values是一个批量大小为2,10个value,value的长度为4
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
# 第一个样本看前两个,第二个样本看前6个
valid_lens = torch.tensor([2, 6])
# 创建加性注意力对象
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
# 调用加性注意力对象的forward方法
print(attention(queries, keys, values, valid_lens))



该部分总代码

import math
import torch
from torch import nn
from d2l import torch as d2l


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上遮蔽元素来执行softmax操作"""
    if valid_lens is None:
        # 如果valid_lens为空,则对X执行softmax操作
        return nn.functional.softmax(X, dim=-1)
    else:
        # shape的形状为(2,2,4)
        shape = X.shape
        # 判断有效长度是否是一维的
        if valid_lens.dim() == 1:
            # valid_lens重复两次[2,3]→[2,2,3,3],和x的列数一样
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将valid_lens重塑为一维向量
            valid_lens = valid_lens.reshape(-1)
        # 在X的最后一个维度(即:列)上进行遮蔽操作
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        # 对遮蔽后的X执行softmax操作,并将形状还原为原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


# 加性注意力
class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        # 用于生成注意力分数的线性变换
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))

在这里插入图片描述


注意力权重

# 调用d2l.show_heatmaps函数,显示注意力权重的热图
d2l.show_heatmaps(attention.attention_weights.reshape((1,1,2,10)),
                 xlabel='Keys', ylabel='Queries')

在这里插入图片描述




Scaled Dot-Product Attention(缩放点积注意力)

如果query和key都是同样的长度,q、k∈ R d R^d Rd,那么可以:

在这里插入图片描述

向量化版本(拓展到多维)

在这里插入图片描述




缩放点积注意力代码

好处是不需要学习参数

class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Dropout层,用于随机丢弃一部分注意力权重
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取查询向量的维度d
        d = queries.shape[-1]
        # 计算点积注意力得分,并进行缩放
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # 使用遮蔽softmax计算注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 根据注意力权重对values进行加权求和
        return torch.bmm(self.dropout(self.attention_weights), values)



演示一下DotProductAttention类

queries = torch.normal(0,1,(2,1,2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
# 调用缩放点积注意力对象的forward方法
attention(queries, keys, values, valid_lens)



该部分总代码

import math
import torch
from torch import nn
from d2l import torch as d2l


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上遮蔽元素来执行softmax操作"""
    if valid_lens is None:
        # 如果valid_lens为空,则对X执行softmax操作
        return nn.functional.softmax(X, dim=-1)
    else:
        # shape的形状为(2,2,4)
        shape = X.shape
        # 判断有效长度是否是一维的
        if valid_lens.dim() == 1:
            # valid_lens重复两次[2,3]→[2,2,3,3],和x的列数一样
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将valid_lens重塑为一维向量
            valid_lens = valid_lens.reshape(-1)
        # 在X的最后一个维度(即:列)上进行遮蔽操作
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        # 对遮蔽后的X执行softmax操作,并将形状还原为原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Dropout层,用于随机丢弃一部分注意力权重
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取查询向量的维度d
        d = queries.shape[-1]
        # 计算点积注意力得分,并进行缩放
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # 使用遮蔽softmax计算注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 根据注意力权重对values进行加权求和
        return torch.bmm(self.dropout(self.attention_weights), values)


# keys是一个批量大小为2,10个key,key的长度为2
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.ones((2, 10, 2))
# repeat(2, 1, 1)沿着第一个维度重复两次(共两个)
# values是一个批量大小为2,10个value,value的长度为4
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
# 创建缩放点积注意力对象
attention = DotProductAttention(dropout=0.5)
# 设置为评估模式,不使用dropout
attention.eval()
# 调用缩放点积注意力对象的forward方法
print(attention(queries, keys, values, valid_lens))

在这里插入图片描述



注意力权重

# 调用d2l.show_heatmaps函数,显示注意力权重的热图
d2l.show_heatmaps(attention.attention_weights.reshape((1,1,2,10)),
                 xlabel='Keys', ylabel='Queries')

在这里插入图片描述




总结

注意力分数是query和key的相似度,注意力权重分数的softmax结果。

两种常见的分数计算
    将query和key合并起来进入一个单输出单隐藏的MLP。(加性注意力)
    直接将query和key做内积。(缩放点积注意力)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2194751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

35-搜索插入位置

题目:35. 搜索插入位置 - 力扣&#xff08;LeetCode&#xff09; 思想:常规的二分&#xff0c;很简单的理解是当大于所有数时要在right的边界1插入 class Solution { public:int searchInsert(vector<int>& nums, int target) {int left 0;int right nums.size()-…

“高效解决PL/SQL Developer软件过期问题的方法“

目录 背景&#xff1a; 解决方法&#xff1a; 方法1&#xff1a;删除注册表信息 ​方法2&#xff1a;使用注册码 那种方式更好&#xff1a; 背景&#xff1a; 前段时间&#xff0c;由于项目需要&#xff0c;我下载了PL/SQL Developer这款强大集成开发环境&#xff0c;(ID…

白嫖EarMaster Pro 7简体中文破解版下载永久激活

EarMaster Pro 7 简体中文破解版功能介绍 俗话说得好&#xff0c;想要成为音乐家&#xff0c;就必须先拥有音乐家的耳朵&#xff0c;相信很多小伙伴都已经具备了一定的音乐素养&#xff0c;或者是说想要进一步得到提升。那我们就必须练好听耳的能力&#xff0c;并且把这种能力…

【新人系列】Python 入门(一):介绍及环境搭建

✍ 个人博客&#xff1a;https://blog.csdn.net/Newin2020?typeblog &#x1f4dd; 专栏地址&#xff1a;https://blog.csdn.net/newin2020/category_12801353.html &#x1f4e3; 专栏定位&#xff1a;为 0 基础刚入门 Python 的小伙伴提供详细的讲解&#xff0c;也欢迎大佬们…

设计模式之适配器模式(通俗易懂--代码辅助理解【Java版】)

文章目录 设计模式概述1、适配器模式2、适配器模式的使用场景3、优点4、缺点5、主要角色6、代码示例1&#xff09;UML图2&#xff09;源代码&#xff08;1&#xff09;定义一部手机&#xff0c;它有个typec口。&#xff08;2&#xff09;定义一个vga接口。&#xff08;3&#x…

传奇服务端快捷助手

定位传奇各目录&#xff0c;一键打开各配置文件。<br>收纳引擎、端口配置检查&#xff08;批量&#xff09;、路径配置、文本搜索、文件同步、一键重载&#xff08;跨桌面&#xff09;、命令管理 参考资料 传奇服务端快捷助手2024-06-20 - 工具软件程序 - 51开发者联盟 -…

JVM 内存区域 堆

堆是JVM中相当核心的内容&#xff0c;因为堆是JVM中管理的最大一块内存区域&#xff0c;大部分的GC也发生在堆区&#xff0c;那接下来就让深入地探究一下JVM中的堆结构。 需要明确&#xff0c;一个JVM实例只存在一个堆内存&#xff0c;堆区在JVM启动的时候就被创建&#xff0c…

在线教育系统开发:SpringBoot技术实战

3系统分析 3.1可行性分析 通过对本微服务在线教育系统实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本微服务在线教育系统采用SSM框架&#xff0c;JAVA作为开…

Vue2电商平台(五)、加入购物车,购物车页面

文章目录 一、加入购物车1. 添加到购物车的接口2. 点击按钮的回调函数3. 请求成功后进行路由跳转(1)、创建路由并配置路由规则(2)、路由跳转并传参(本地存储) 二、购物车页面的业务1. uuid生成用户id2. 获取购物车数据3. 计算打勾商品总价4. 全选与商品打勾(1)、商品全部打勾&a…

Nature 正刊!树木多样性促进天然林土壤碳氮的固存

本文首发于“生态学者”微信公众号&#xff01; 2023年4月26日&#xff0c;《Nature》杂志在线发表了浙江农林大学陈信力教授、Scott X. Chang教授及湖首大学Han Y. H. Chen教授等合作的最新研究成果 “Tree diversity increases decadal forest soil carbon and nitrogen acc…

深入掌握 Golang 单元测试与性能测试:从零开始打造高质量代码!

在软件开发中&#xff0c;测试是保证代码质量、减少错误的重要环节。Golang 自带了强大的测试工具&#xff0c;不仅支持单元测试&#xff0c;还支持性能测试&#xff0c;让开发者可以轻松进行代码的测试和优化。本文将详细介绍如何在 Go 中进行单元测试和性能测试&#xff0c;帮…

Codeforces Round 969 (Div. 1) B. Iris and the Tree

题目 题解&#xff1a; #include <bits/stdc.h> using namespace std; #define int long long #define pb push_back #define fi first #define se second #define lson p << 1 #define rson p << 1 | 1 #define ll long long #define pii pair<int, int&…

(JAVA)开始熟悉 “二叉树” 的数据结构

1. 二叉树入门 ​ 符号表的增删查改操作&#xff0c;随着元素个数N的增多&#xff0c;其耗时也是线性增多的。时间复杂度都是O(n)&#xff0c;为了提高运算效率&#xff0c;下面将学习 树 这种数据结构 1.1 树的基本定义 ​ 树是我们计算机中非常重要的一种数据结构&#xf…

C语言刷题--数数一个数的二进制里有几个‘1’

先来看一下左移、右移 左移 右移 题目解答 1是一个特殊的数&#xff0c;二进制是000000000000000000000001(32位机器) 假如要判断的数是0&#xff08;二进制里面没有1&#xff09; 000000000000000000000000 & 000000000000000000000001 结果为0&#xff1b; 假如要…

基于FPGA的多路视频缓存

对于多路视频传输的场合&#xff0c;需要正确设置同步。 uifdma_dbuf0 的写通道输出帧同步计数器直接接入 uifdma_dbuf0&#xff0c;uifdma_dbuf1, uifdma_dbuf2, uifdma_dbuf3 的写通道同步计数输入。uifdma_dbuf0 的读通道&#xff0c;延迟 1 帧于 uifdma_dbuf0 的写通道帧计…

初入网络学习第一篇

引言 不磨磨唧唧&#xff0c;跟着学就好了&#xff0c;这个是我个人整理的学习内容梳理&#xff0c;学完百分百有收获。 1、使用的网络平台:eNSP 下载方法以及内容参考这篇文章 华为 eNSP 模拟器安装教程&#xff08;内含下载地址&#xff09;_ensp下载-CSDN博客https://b…

javaScript操作元素(9个案例+代码+效果)

目录 1.innerHTML 案例:使用innerHTML修改文本内容 1.代码 2.效果 2.innerText 案例:使用innerText修改文本 1.代码 2.效果 3.textContent 案例:使用textContent修改文本 1.代码 2.效果 4.通过style属性操作样式 案例:改变小球颜色 1.代码 2.效果 5.通过className属性操作样式 …

【Iceberg分析】Spark集成Iceberg采集输出

Spark集成Iceberg采集输出 文章目录 Spark集成Iceberg采集输出Iceberg提供了两类指标和提供了两类指标输出器ScanReportCommitReport LoggingMetricsReporterRESTMetricsReporter验证示例相关环境配置结果说明 Iceberg提供了两类指标和提供了两类指标输出器 ScanReport 包含在…

基于SpringBoot+Uniapp的家庭记账本微信小程序系统设计与实现

项目运行截图 展示效果图 展示效果图 展示效果图 展示效果图 展示效果图 5. 技术框架 5.1 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念&#xff0c;提供了一套默认的配置&#xff0c;让开发者可以更…

Three.js基础内容(二)

目录 一、模型 1.1、组对象Group和层级模型(树结构) 1.2、递归遍历模型树结构、查询具体模型节点(楼房案例) 1.3、本地(局部)坐标和世界坐标 1.4、改变模型相对局部坐标原点位置 1.5、移除对象.remove() 1.6、模型隐藏与显示 二、纹理 2.1、创建纹理贴图(TextureLoade…