【动手学深度学习-Pytorch版】注意力汇聚:Nadaraya-Watson 核回归

news2024/12/26 0:04:40

在这里插入图片描述

注意力机制中的查询、键、值

在注意力机制的框架中包含了键、值与查询三个主要的部分,其中键与查询构成了注意力汇聚(有的也叫作注意力池化)。
键是指一些非意识的线索,例如在序列到序列的学习中,特别是机器翻译,键是指除了文本序列自身信息外的其他信息,例如人工翻译或者语言学习情况。
查询则是与键(非意识提示)相反的,它常被称为意识提示或者自主提示。这体现在文本序列翻译中,则是文本序列的context上下文信息,该上下文信息包含了词元与词元之间的自主线索。
值是通过设计注意力的汇聚方式,将给定的查询与键进行匹配,得出的最匹配的值的信息。

平均汇聚

平均汇聚是对输入进行加权取平均值,其中各输入的权重保持平衡。下面是d2l给出的一个实例:

导包

import torch
from d2l import torch as d2l

可视化注意力权重

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

示例

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

结果

在这里插入图片描述

Nadaraya-Watson 核回归

具体来说,1964年提出的Nadaraya-Watson核回归模型 是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。

导包

import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

给定“输入-输出”数据集 ( x 1 , y 1 ) , . . . ( x n , y n ) {(x_1,y1),...(x_n,y_n)} (x1,y1),...(xn,yn),其Y通过以下函数生成,其中包含了噪声ε(服从均值为0和标准差为0.5的正态分布):
在这里插入图片描述
训练样本数 = 测试样本数 = 50;训练样本通过torch.sort()排序后输出,结果含有噪声;而测试样本采用的是不含噪声点。

# 生成数据集
n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本

with  open('D://pythonProject//f-write//Nadaraya-forward-x_train_sort.txt', 'w') as f:
    f.write(str(x_train))

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数

print('y_train_shape: ',y_train.size())
# y_train_shape:  torch.Size([50])

可视化注意力权重函数

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):
    print('Shape of matrices is',matrices.shape)
    # Shape of matrices is torch.Size([1, 1, 50, 50])

    """其输入matrices的形状是(要显示的行数,要显示的列数,查询的数目,键的数目)"""
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True,
                                 squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.5)
    plt.show()

绘制训练样本、真实数据生成函数以及预测函数

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

实现平均汇聚

平均汇聚是一种最简单的估计器,它是直接计算所有训练样本输出值的平均值:
在这里插入图片描述
但是通过绘制图像上的对比发现:平均汇聚的图像与真实值之间存在极大的偏差:

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

使用非参数注意力汇聚

这里采用的注意力函数为核函数,具体表达式为:
在这里插入图片描述
这里的K即为核函数,基于上述的核函数的启发,可以根据下图注意力机制框架的角度重写,成为一个更加通用的注意力汇聚公式。
在这里插入图片描述

通用的注意力公式:
在这里插入图片描述
其中 X X X代表查询, ( x i , y i ) (x_i,y_i) (xi,yi)是键值对。注意力汇聚是对值的加权平均。α是根据查询与键形成的注意力权重,下面将会利用高斯核作为注意力权重。
在这里插入图片描述
给定一个键 x i x_i xi,如果它接近于给定的查询 X X X,则分配给 Y i Y_i Yi的权重越大。下面是根据此非参数的注意力汇聚形成的预测结果:
在这里插入图片描述
绘制代码:

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
print('Size of X_repeat is:',X_repeat.shape)
# Size of X_repeat is: torch.Size([50, 50])

# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)

print('Size of attention_weights is:',attention_weights.shape)
# Size of attention_weights is: torch.Size([50, 50])

# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

测试数据的输入相当于查询,而训练数据的输入相当于键,下面可以使用热力图发现“查询-键”对越接近,注意力汇聚的注意力权重α越高。


print('attention_weights.unsqueeze(0).unsqueeze(0)',attention_weights.unsqueeze(0).unsqueeze(0).shape)
# attention_weights.unsqueeze(0).unsqueeze(0) torch.Size([1, 1, 50, 50])

# unsqueeze(0)两次相当于增加两次第一个维度
show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述

带参数的注意力汇聚

与非参数的注意力汇聚不同,带参数的注意力汇聚在查询和键 x i x_i xi之间加入了可学习参数w:w控制的是高斯核的窗口大小,可以通过W控制曲线平滑一点或者不平滑。

小批量矩阵的乘法

# 实现小批量矩阵的乘法
X = torch.ones((2,1,4))
Y = torch.ones((2,4,6))
print('小批量矩阵乘法测试:',torch.bmm(X,Y).shape)
# 小批量矩阵乘法测试: torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

"""利用小批量矩阵乘法计算小批量数据中的加权平均值"""
weights = torch.ones((2,10)) * 0.1
values = torch.arange(20.0).reshape((2,10))
print('Init shape of weights: ',weights.shape,'Init shape of values: ',values.shape)
# Init shape of weights:  torch.Size([2, 10]) Init shape of values:  torch.Size([2, 10])

weights = weights.unsqueeze(1)
values = values.unsqueeze(-1)
print('New shape of weights: ',weights.shape,'New shape of values: ',values.shape)
# New shape of weights:  torch.Size([2, 1, 10]) New shape of values:  torch.Size([2, 10, 1])

res = torch.bmm(weights,values)
print('加权平均值示例:',res)
# 加权平均值示例: tensor([[[ 4.5000]],
#
#         [[14.5000]]])

定义模型

该模型中直接使用的是上述提到的核,同时在进入模型初始化函数中,查询的初始化大小为测试数据输入大小,而键与值是成对出现的,所以两者大小是相同的。后面将查询的大小更改成(查询个数,“键-值”对数),注意力权重也是该形状。

"""定义模型"""
# w控制的是高斯核的窗口大小,可以通过W控制曲线平滑一点或者不平滑
class NWKernelRegression(nn.Module):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,),requires_grad=True))

    def forward(self,queries,keys,values):
        with  open('D://pythonProject//f-write//Nadaraya-forward-init-queries.txt', 'w') as f:
            f.write(str(queries.shape))
        with  open('D://pythonProject//f-write//Nadaraya-forward-init-keys.txt', 'w') as f:
            f.write(str(keys.shape))
        with  open('D://pythonProject//f-write//Nadaraya-forward-init-values.txt', 'w') as f:
            f.write(str(values.shape))
        # init-queries ---> torch.Size([50])
        # init-keys ---> torch.Size([50, 49])
        # init-values ---> torch.Size([50, 49])

        # queries和attention_weights的形状为(查询个数,“键-值”对数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1,keys.shape[1]))
        with  open('D://pythonProject//f-write//Nadaraya-forward-changed-queries.txt', 'w') as f:
            f.write(str(queries.shape))
        # changed-queries ---> torch.Size([50, 49])

        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2/2,dim=1)
        # attention的形状也是(查询个数,“键-值”对个数)
        with  open('D://pythonProject//f-write//Nadaraya-forward--attention.txt', 'w') as f:
            f.write(str(self.attention_weights.shape))
        # forward-attention ---> torch.Size([50, 49])

        # values的形状为(查询个数,“键-值”对个数)
        with  open('D://pythonProject//f-write//Nadaraya-changed--values.txt', 'w') as f:
            f.write(str(values.unsqueeze(-1).shape))
        # torch.Size([50, 49, 1])
        return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

训练

重点:keys的形状:(‘n-train’ , ‘n-train’ -1)与values的形状:(‘n-train’ , ‘n-train’ -1)如何理解?
为了说明该问题,下面代码将生成key与value的关键代码拿出,并根据其矩阵相乘的思路降维形成另外一个示例代码:

import numpy as np
import torch

x_train,_ = torch.sort(torch.rand(3) * 5)
X_title = x_train.repeat((3,1))
print('X_title',X_title)
print(X_title.shape)
print(type(X_title))
print(1-torch.eye(3))
keys = X_title[(1-torch.eye(3)).type(torch.bool)].reshape((3,-1))
print('keys:',keys)
print(keys.shape)

arr1 = np.array([[1, 2, 3], [4, 5, 6],[7,8,9]])

x = torch.from_numpy(arr1)
print('x.size',x.size())

arr2 = np.array([[0, 1, 1], [1, 0, 1],[1,1,0]])
y = torch.from_numpy(arr2)
print('y.size',y.size())

print(x[y.type(torch.bool)].reshape((3,-1))  )

在原来代码里面key是通过
X_title[(1-torch.eye(3)).type(torch.bool)].reshape((3,-1))形成的,但是高维度的数据分析起来有点繁杂,这里进行了降维,其中x代表的是原来的X_tiley代表的是原有的1-torch.eye(3)——这里通过打印后可以发现y1-torch.eye(3)只是维度上的区别,元素规律保持不变,即都是一个对角矩阵。通过打印可知:
X_title[(1-torch.eye(3)).type(torch.bool)].reshape((3,-1))的操作是将两个矩阵进行点乘,并且最后reshape点乘后的矩阵,因为
1-torch.eye(3)中恰含有一列的零元素,所以点乘后二维元素少了一列,这就是[50,49]的来源。

"""训练"""
# 将训练数据集变换为键和值用于训练注意力模型
# 任何一个训练样本的输入都会和除了自身以外的其他训练样本的键值对进行计算 从而得到其对应的预测输出
# X_tile的形状: (n_train,n_train) 每一个行都包含着相同的训练输入
X_title = x_train.repeat((n_train,1))
with  open('D://pythonProject//f-write//Nadaraya-X_title.txt', 'w') as f:
    f.write(str(X_title.shape))
# X_title ---> torch.Size([50,50]

# Y_tile的形状: (n_train,n_train) 每一个行都包含着相同的训练输出
Y_title = y_train.repeat((n_train,1))
with  open('D://pythonProject//f-write//Nadaraya-Y_title.txt', 'w') as f:
    f.write(str(Y_title.shape))
# Y_title ---> torch.Size([50, 50])

# keys的形状:('n-train' , 'n-train' -1)
# keys ---> torch.Size([50, 49])
with  open('D://pythonProject//f-write//Nadaraya-torch.eye(n_train).txt', 'w') as f:
    f.write(str(torch.eye(n_train).size()))

res_eye = (1-torch.eye(n_train)).type(torch.bool)
print(res_eye.size())

test_res = X_title[(1-torch.eye(n_train)).type(torch.bool)]

with  open('D://pythonProject//f-write//Nadaraya-1-torch.eye(n_train).txt', 'w') as f:
    f.write(str(test_res))

"""---将X_title与1-torch.eye(n_train)哈达玛积后元素个数少了一列"""
keys = X_title[(1-torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))
with  open('D://pythonProject//f-write//Nadaraya-keys.txt', 'w') as f:
    f.write(str(keys.shape))
# values ---> torch.Size([50, 49])
# values的形状:('n-train' , 'n-train' -1)
values = Y_title[(1-torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))
with  open('D://pythonProject//f-write//Nadaraya-values.txt', 'w') as f:
    f.write(str(values.shape))
"""训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降"""
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

绘制散点图与热力图

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1038197.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Python+Django的热门旅游景点数据分析系统的设计与实现(源码+lw+部署文档+讲解等)

前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 👇🏻…

JAVA:实现Excel和PDF上下标

1、简介 最近项目需要实现26个小写字母的上下标功能,自己去网上找了所有Unicode的上下标形式,缺少一些关键字母,顾后面考虑自己创建上下标字体样式,以此来记录。 2、Excel Excel本身是支持上下标,我们可以通过Excel单元格的样式来设置当前字体上下标,因使用的是POI的m…

YOLOv5:修改backbone为ACMIX

YOLOv5:修改backbone为ACMIX 前言前提条件相关介绍ACMIXYOLOv5修改backbone为ACMIX修改common.py修改yolo.py修改yolov5.yaml配置 参考 前言 记录在YOLOv5修改backbone操作,方便自己查阅。由于本人水平有限,难免出现错漏,敬请批评…

【软件设计师-从小白到大牛】上午题基础篇:第四章 法律法规与知识产权

文章目录 前言章节提要一、保护期限真题链接 二、知识产权人确定真题链接 三、侵权判定真题链接 四、标准化基础知识 前言 ​ 本系列文章为观看b站视频以及b站up主zst_2001系列视频所做的笔记,感谢相关博主的分享。如有侵权,立即删除。 视频链接&#xf…

Qt5开发及实例V2.0-第二十三章-Qt-多功能文档查看器实例

Qt5开发及实例V2.0-第二十三章-Qt-多功能文档查看器实例 第23章 多功能文档查看器实例23.1. 简介23.2. 界面与程序框架设计23.2.1. 图片资源23.2.2. 网页资源23.2.3. 测试用文件 23.3 主程序代码框架23.4 浏览网页功能实现23.4.1 实现HtmIHandler处理器 23.5. 部分代码实现23.5…

git 本地分支基础操作

(1)建立分支 a:基于某个commit建立分支 然后切换 git branch test_branch 6435675ad32c035ed4d9cb6c351de5cbaecddd99 git checkout test_branchb: git checkout 建立分支然后切换 git checkout -b checkout_branchc:分支建立 然后切换 git branch …

【Amazon】AI 代码生成器—Amazon CodeWhisperer初体验 | 开启开挂编程之旅

使用 AI 编码配套应用程序更快、更安全地构建应用程序 文章目录 1.1 Amazon CodeWhisperper简介1.2 Amazon CodeWhisperer 定价2.1 打开VS Code2.2 安装AWS ToolKit插件 一、前言 1.1 Amazon CodeWhisperper简介 1️⃣更快地完成更多工作 CodeWhisperer 经过数十亿行代码的训…

网络分层模型和常见协议介绍

文章目录 网络分层模型和常见协议介绍网络分层模型介绍常见各层协议介绍 网络分层模型和常见协议介绍 理解性记忆:这是我自己创造的一个理解性记忆口诀,大家别笑我😄 七层:因为七层协议并没有得到应用,所以物&#xff…

【算法】相向双指针

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

【GO】LGTM_Grafana_gozero_配置trace(4)_代码实操及追踪

最近在尝试用 LGTM 来实现 Go 微服务的可观测性,就顺便整理一下文档。 Tempo 会分为 4 篇文章: Tempo 的架构官网测试实操跑通gin 框架发送 trace 数据到 tempogo-zero 微服务框架发送数据到 tempo 本文就是写一下如何在 go-zero 微服务框架里面配置 t…

套接字socket编程的基础知识点

目录 前言(必读) 网络字节序 网络中的大小端问题 为什么网络字节序采用的是大端而不是小端? 网络字节序与主机字节序之间的转换 字符串IP和整数IP 整数IP存在的意义 字符串IP和整数IP相互转换的方式 inet_addr函数(会自…

83、SpringBoot --- 下载和安装 MSYS2、 Redis

★ 下载和安装MSYS2(作用:可在Windows模拟一个Linux的编译环境) 得到Redis的编译环境——在Linux平台上,这一步可以省略。(1)登录MSYS2官网(http://repo.msys2.org/distrib/ )下载M…

前端新轮子Nue,号称替代Vue、React和Svelte

新的简约前端开发工具集Nue.js 于周三发布。在 Hacker News 上介绍它时,前端开发者和Nue.js 的创作者Tero Piirainen表示,它是 React、Vue、Next.js、Vite、Svelte 和 Astro 的替代品。他在 Nue.js的 FAQ 中进一步解释说,它是为网站和响应式用…

【Vue.js】使用Element搭建登入注册界面axios中GET请求与POST请求跨域问题

一,ElementUI是什么? Element UI 是一个基于 Vue.js 的桌面端组件库,它提供了一套丰富的 UI 组件,用于构建用户界面。Element UI 的目标是提供简洁、易用、美观的组件,同时保持灵活性和可定制性 二,Element…

Spring学习笔记6 Bean的实例化方式

Spring学习笔记5 GoF之工厂模式_biubiubiu0706的博客-CSDN博客 Spring为Bean提供了多种实例化方式,通常包括4中(目的:更加灵活) 1.通过构造方法实例化 2.通过简单工厂模式实例化 3.通过factory-bean实例化 4.通过FactoryBean接口实例化 新建模块 spring-005 依赖 <!--S…

自动化测试、压力测试、持续集成

因为项目的原因&#xff0c;前段时间研究并使用了 SoapUI 测试工具进行自测开发的 api。下面将研究的成果展示给大家&#xff0c;希望对需要的人有所帮助。 SoapUI 是什么&#xff1f; SoapUI 是一个开源测试工具&#xff0c;通过 soap/http 来检查、调用、实现 Web Service 的…

github想传至远程仓库显示fatal: remote origin already exists. (远程来源已经存在 解决办法)

参考:https://blog.csdn.net/qq_40428678/article/details/84074207 在当我们输入git remote add origin https://gitee.com/(github/码云账号)/(github/码云项目名).git 就会报如下的错 fatal: remote origin already exists. 翻译过来就是&#xff1a;致命&#xff1a;远程…

zabbix自定义监控、钉钉、邮箱报警 (五十六)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 一、实验准备 二、安装 三、添加监控对象 四、添加自定义监控项 五、监控mariadb 1、添加模版查看要求 2、安装mariadb、创建用户 3、创建用户文件 4、修改监控模版 5、…

新版Chromedriver在哪下载(Chromedriver 116.0.5845.188的寻找之旅)

不知道什么时候Chrome自动升级到116.0.5845.188了&#xff0c;害得我原来的Chromedriver 114无法使用了&#xff0c;无奈之下只好重新去下载。 可寻遍网络&#xff0c;都没找到Chromedriver116的版本。网上大多网友给的下载网址是chromedriver.storage.googleapis.com/index.ht…

数据结构与算法之时间复杂度和空间复杂度(C语言版)

1. 时间复杂度 1.1 概念 简而言之&#xff0c;算法中的基本操作的执行次数&#xff0c;叫做算法的时间复杂度。也就是说&#xff0c;我这个程序执行了多少次&#xff0c;时间复杂度就是多少。 比如下面这段代码的执行次数&#xff1a; void Func1(int N) {int count 0;for…