《动手学深度学习 Pytorch版》 10.2 注意力汇聚:Nadaraya-Watson 核回归

news2025/1/12 4:46:22
import torch
from torch import nn
from d2l import torch as d2l

1964 年提出的 Nadaraya-Watson 核回归模型是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。

10.2.1 生成数据集

根据下面的非线性函数生成一个人工数据集,其中噪声项 ϵ \epsilon ϵ 服从均值为 0 ,标准差为 0.5 的正态分布:

y i = 2 sin ⁡ x i + x i 0.8 + ϵ \boldsymbol{y}_i=2\sin{\boldsymbol{x}_i}+\boldsymbol{x}_i^{0.8}+\epsilon yi=2sinxi+xi0.8+ϵ

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
50
def plot_kernel_reg(y_hat):  # 绘制训练样本
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

10.2.2 平均汇聚

先使用最简单的估计器来解决回归问题。基于平均汇聚来计算所有训练样本输出值的平均值:

f ( x ) = 1 n ∑ i = 1 n y i f(x)=\frac{1}{n}\sum^n_{i=1}y_i f(x)=n1i=1nyi

y_hat = torch.repeat_interleave(y_train.mean(), n_test)  # 计算平均并进行扩展
plot_kernel_reg(y_hat)


在这里插入图片描述

10.2.3 非参数注意力汇聚

相对于平均汇聚的忽略输入。Nadaraya 和 Watson 提出了一个更好的想法,根据输入的位置对输出 y i y_i yi 进行加权,即 Nadaraya-Watson 核回归:

f ( x ) = ∑ i = 1 n K ( x − x i ) ∑ j = 1 n K ( x − x j ) y i f(x)=\sum^n_{i=1}\frac{K(x-x_i)}{\sum^n_{j=1}K(x-x_j)}y_i f(x)=i=1nj=1nK(xxj)K(xxi)yi

将其中的核(kernel) K K K 根据上节内容重写为更通用的注意力汇聚公式:

f ( x ) = ∑ i = 1 n α ( x , x i ) y i f(x)=\sum^n_{i=1}\alpha(x,x_i)y_i f(x)=i=1nα(x,xi)yi

参数字典:

  • x x x 为查询

  • ( x i , y i ) (x_i,y_i) (xi,yi) 为键值对

  • α ( x , x i ) \alpha(x,x_i) α(x,xi) 为注意力权重(attention weight),即查询 x x x 和键 x i x_i xi 之间的关系建模,此权重被分配给对应值的 y i y_i yi

对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布: 非负的且和为1。

考虑高斯核(Gaussian kernel)以更好地理解注意力汇聚:

K ( u ) = 1 2 π exp ⁡ ( − u 2 2 ) K(u)=\frac{1}{\sqrt{2\pi}}\exp{(-\frac{u^2}{2})} K(u)=2π 1exp(2u2)

将高斯核代入上式可得:

f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n exp ⁡ ( − 1 2 ( x − x i ) 2 ) ∑ j = 1 n exp ⁡ ( − 1 2 ( x − x j ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( x − x i ) 2 ) y i \begin{align} f(x)=&\sum^n_{i=1}\alpha(x,x_i)y_i\\ =&\sum^n_{i=1}\frac{\exp{(-\frac{1}{2}(x-x_i)^2)}}{\sum^n_{j=1}\exp{(-\frac{1}{2}(x-x_j)^2)}}y_i\\ =&\sum^n_{i=1}\mathrm{softmax}\left(-\frac{1}{2}(x-x_i)^2\right)y_i \end{align} f(x)===i=1nα(x,xi)yii=1nj=1nexp(21(xxj)2)exp(21(xxi)2)yii=1nsoftmax(21(xxi)2)yi

如果一个键 x i x_i xi 越是接近给定的查询 x x x,那么分配给这个键对应值 y i y_i yi 的注意力权重就会越大,也就“获得了更多的注意力”。

上式是一个非参数的注意力汇聚(nonparametric attention pooling)模型。 接下来基于这个非参数的注意力汇聚模型绘制的预测结果的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)


在这里插入图片描述

观察注意力的权重可以发现,“查询-键”对越接近,注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')


在这里插入图片描述

10.2.4 带参数的注意力汇聚

可以轻松地将可学习的参数集成到注意力汇聚中,例如,在下面的查询 x x x 和键 x i x_i xi 之间的距离乘以可学习参数 w w w

f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n exp ⁡ ( − 1 2 ( ( x − x i ) w ) 2 ) ∑ j = 1 n exp ⁡ ( − 1 2 ( ( x − x j ) w ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( ( x − x i ) w ) 2 ) y i \begin{align} f(x)=&\sum^n_{i=1}\alpha(x,x_i)y_i\\ =&\sum^n_{i=1}\frac{\exp{(-\frac{1}{2}((x-x_i)w)^2)}}{\sum^n_{j=1}\exp{(-\frac{1}{2}((x-x_j)w)^2)}}y_i\\ =&\sum^n_{i=1}\mathrm{softmax}\left(-\frac{1}{2}((x-x_i)w)^2\right)y_i \end{align} f(x)===i=1nα(x,xi)yii=1nj=1nexp(21((xxj)w)2)exp(21((xxi)w)2)yii=1nsoftmax(21((xxi)w)2)yi

10.2.4.1 批量矩阵乘法

假定两个张量的形状分别是 ( n , a , b ) (n,a,b) (n,a,b) ( n , b , c ) (n,b,c) (n,b,c),它们的批量矩阵乘法输出的形状为 ( n , a , c ) (n,a,c) (n,a,c)

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
torch.Size([2, 1, 6])

可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
weights.shape, values.shape, weights.unsqueeze(1).shape, values.unsqueeze(-1).shape, torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
(torch.Size([2, 10]),
 torch.Size([2, 10]),
 torch.Size([2, 1, 10]),
 torch.Size([2, 10, 1]),
 tensor([[[ 4.5000]],
 
         [[14.5000]]]))

10.2.4.2 定义模型

class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

10.2.4.3 训练

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')  # 使用平方损失函数
trainer = torch.optim.SGD(net.parameters(), lr=0.5)  # 使用随机梯度下降
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))


在这里插入图片描述

训练完带参数的注意力汇聚模型后可以发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)


在这里插入图片描述

与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')


在这里插入图片描述

练习

(1)增加训练数据的样本数量,能否得到更好的非参数的 Nadaraya-Watson 核回归模型?

不能。

n_train_more = 500
x_train_more, _ = torch.sort(torch.rand(n_train_more) * 5)

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train_more = f(x_train_more) + torch.normal(0.0, 0.5, (n_train_more,))
x_test_more = torch.arange(0, 5, 0.01)
y_truth_more = f(x_test_more)

def plot_kernel_regv_more(y_hat_more):
    d2l.plot(x_test_more, [y_truth_more, y_hat_more], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train_more, y_train_more, 'o', alpha=0.5);

X_repeat_more = x_test_more.repeat_interleave(n_train_more).reshape((-1, n_train_more))
attention_weights_more = nn.functional.softmax(-(X_repeat_more - x_train_more)**2 / 2, dim=1)
y_hat_more = torch.matmul(attention_weights_more, y_train_more)
plot_kernel_regv_more(y_hat_more)


在这里插入图片描述

d2l.show_heatmaps(attention_weights_more.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')


在这里插入图片描述


(2)在带参数的注意力汇聚的实验中学习得到的参数 w w w 的价值是什么?为什么在可视化注意力权重时,它会使加权区域更加尖锐?

w w w 的价值在于放大注意力,也就是利用 softmax 函数的特性使键 x i x_i xi 和查询 x x x 距离小的得以保存,学习到的 w w w 就是掌握这个放大的尺度。

距离大的被过滤,当然也就显得更尖锐了。


(3)如何将超参数添加到非参数的Nadaraya-Watson核回归中以实现更好地预测结果?

加进去就能行。

n_train_test = 50
x_train_test, _ = torch.sort(torch.rand(n_train_test) * 5)

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train_test = f(x_train_test) + torch.normal(0.0, 0.5, (n_train_test,))
x_test_test = torch.arange(0, 5, 0.1)
y_truth_test = f(x_test_test)

def plot_kernel_regv_more(y_hat_test):
    d2l.plot(x_test_test, [y_truth_test, y_hat_test], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train_test, y_train_test, 'o', alpha=0.5);

X_repeat_test = x_test_test.repeat_interleave(n_train_test).reshape((-1, n_train_test))
attention_weights_test = nn.functional.softmax(-((X_repeat_test - x_train_test)*net.w.detach().numpy())**2 / 2, dim=1)  # 加入训练好的权重
y_hat_test = torch.matmul(attention_weights_test, y_train_test)
plot_kernel_regv_more(y_hat_test)


在这里插入图片描述


(4)为本节的核回归设计一个新的带参数的注意力汇聚模型。训练这个新模型并可视化其注意力权重。

不会,略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1128018.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GoLong的学习之路(七)语法之slice(切片)

书接上回,上回书中写道:指针,并说明了基本引用类型分配内存new和特定情况下slice(切片),map,channel等集合函数的内存分配make。这篇文章就开始说明,slice。 文章目录 slice&#xf…

人生道路选择,恳请前辈指点,半路出家学习java?

人生道路选择,恳请前辈指点,半路出家学习java? 首先答案肯定是可以的。Java作为一门高级语言,它很优秀地屏蔽了许多繁枝末节。很多科班出身的人上来可能会先学C、C,要学会怎么管理内存等很底层的事情,而在开…

联想拯救者Y7000笔记本WiFi频繁掉线的坑

2023年10月的某一天开始,跟了我近4年的联想拯救者Y7000本本,无线网总是频繁的掉线,连上没几分钟就断开了,同办公室的其他电脑没这种情况出现,一开始以为是运营商网络问题,或者路由器问题导致的,…

多通道图片的卷积过程

多通道(channels)图片的卷积 如果输入图片是三维的(三个channel),例如(8,8,3),那么每一个filter的维度就是(3,3,3&#x…

一文彻底理解C语言中的指针

假定给你一块非常小的内存,这块内存只有8字节,这里也没有高级语言,没有操作系统,你操作的数据单位是单个字节,你该怎样读写这块内存呢? 注意这里的限定,再读一遍,没有高级语言&#…

rabbitmq-3.8.15集群、集群镜像模式安装部署

目录 一、环境 1、映射、域名、三墙 2、Erlang和socat安装(三台服务器都实行) 二、部署三台rabbitmq-3.8.15实例 1、rabbitmq官网下载地址 : 2、解压rabbitmq 3、添加系统变量 4、启动web插件、启动rabbitmq 5、在rabbitmq1上添加用…

(PyTorch)PyTorch中的常见运算(*、@、Mul、Matmul)

1. 矩阵与标量 矩阵(张量)每一个元素与标量进行操作。 import torch a torch.tensor([1,2]) print(a1) >>> tensor([2, 3]) 2. 哈达玛积(Mul) 两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛…

常见的芯片封装技术

两边出pin的封装 1、DIP封装 DIP封装(Dual In-line Package),也叫双列直插式封装技术,指采用双列直插形式封装的集成电路芯片,绝大多数中小规模集成电路均采用这种封装形式,其引脚数一般不超过100。DIP封装…

windows11录屏功能详解,记录你的精彩时刻

windows 11是微软最新推出的操作系统版本,拥有很多简单便捷的功能,包括内置的录屏工具,让用户可以轻松地录制屏幕内容。但是很多人不了解windows11录屏功能,本文将详细介绍windows 11录屏的三个方法,以及它们的优势和适…

HTTP图解基础知识

书:图解HTTP;分享书中学到的东西,内容很多,极具可玩性关键字:http,cookie,状态,头部字段,缓存,Etag 参考示例:https://zhuanlan.zhihu.com/p/…

ChatGPT在机器学习中的应用与实践

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 引言 随着人工智能技术…

【Gan教程 】 什么是变分自动编码器VAE?

名词解释:Variational Autoencoder(VAE) 一、说明 为什么深度学习研究人员和概率机器学习人员在讨论变分自动编码器时会感到困惑?什么是变分自动编码器?为什么围绕这个术语存在不合理的混淆?本文从两个角度…

docker搭建waline评论系统

我这里是给博客网站嵌入评论系统的 1.登录LeanCloud 国际版,没有账号可以注册个 链接:点击跳转 2.新建应用,选择开发版(免费),商用版每个月最低消费5美刀。 3.在设置-应用凭证里面将AppID、AppKey、Maste…

基于springboot+vue校园短期闲置资源置换平台051

大家好✌!我是CZ淡陌。一名专注以理论为基础实战为主的技术博主,将再这里为大家分享优质的实战项目,本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目,希望你能有所收获,少走一些弯路…

YoloV8改进策略:独家原创,LSKA(大可分离核注意力)改进YoloV8,比Transformer更有效,包括论文翻译和实验结果

文章目录 摘要论文:《LSKA(大可分离核注意力):重新思考CNN大核注意力设计》1、简介2、相关工作3、方法4、实验5、消融研究6、与最先进方法的比较7、ViTs和CNNs的鲁棒性评估基准比较8、结论YoloV8官方结果改进一:测试结果摘要 本文给大家带来一种超大核注意力机制的改进方…

FFmpeg编译安装(windows环境)以及在vs2022中调用

文章目录 下载源码环境准备下载msys换源下载依赖源码位置 开始编译编译x264编译ffmpeg 在VS2022写cpp调用ffmpeg 下载源码 直接在官网下载压缩包 这个应该是目前(2023/10/24)最新的一个版本。下载之后是这个样子: 我打算添加外部依赖x264&a…

应用系统集成-概述

应用系统集成-概述 随着网络技术的发展和日益增长的软件复杂度,几乎已经不存在一个完全孤立的应用系统了,万物互联在应用层面就是系统互联,应用系统集成成为软件系统架构是需要考虑的核心问题之一。 基本介绍 应用系统集成面临的挑战 所有的…

【网络编程】一文带你搞懂HTTPS协议

文章目录 一、什么是HTTPS协议二、关于加密三、数据摘要 | 数据指纹 | 数字签名四、HTTPS的工作过程探究方案1:只使用对称加密方案2:只使用非对称加密方案3:双方都使用非对称加密方案4:非对称加密 对称加密中间人攻击 五、引入证…

重要功能更新:妙手正式接入SHEIN供货模式(OBM)店铺,赋能卖家把握出海新机遇!

继接入SHEIN平台模式店铺之后,妙手ERP积极响应卖家需求,正式接入SHEIN供货模式(OBM)店铺,并支持产品采集、批量刊登、产品管理等功能,帮助跨境卖家快速上品、高效运营,把握出海新机遇。 SHEIN供…

天软特色因子看板(2023.10 第11期)

该因子看板跟踪天软特色因子A05005(近一日单笔流通金额占比(%),该因子为近一个日单笔流通金额占比因子,用以刻/画股票在收盘时,力资金在总交易金额中所占的比重。 今日为该因子跟踪第11期,跟踪其在SW801130 (申万纺织服装) 中的表…