89. 注意力机制以及代码实现Nadaraya-Waston 核回归

news2024/12/23 1:34:49

1. 心理学

  • 动物需要在复杂环境下有效关注值得注意的点
  • 心理学框架:人类根据随意线索和不随意线索选择注意点

随意:随着自己的意识,有点强调主观能动性的意味。

在这里插入图片描述

2. 注意力机制

在这里插入图片描述

2. 非参注意力池化层

在这里插入图片描述

3. Nadaraya-Waston 核回归

在这里插入图片描述

4. 参数化的注意力机制

在这里插入图片描述

5. 总结

在这里插入图片描述

6. 代码实现注意力汇聚:Nadaraya-Waston 核回归

import torch
from torch import nn
from d2l import torch as d2l

6.1 生成数据集

在这里插入图片描述

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本
def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数 𝑓 (标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

6.2 平均汇聚

先使用最简单的估计器来解决回归问题。 基于平均汇聚来计算所有训练样本输出值的平均值:

在这里插入图片描述

如下图所示,这个估计器确实不够聪明。 真实函数 𝑓 (“Truth”)和预测函数(“Pred”)相差很大。

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

6.3 非参数注意力汇聚

接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。 从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

运行结果:

在这里插入图片描述

现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

运行结果:

在这里插入图片描述

6.4 带参数注意力汇聚

1. 批量矩阵乘法

因此,假定两个张量的形状分别是 (𝑛,𝑎,𝑏) 和 (𝑛,𝑏,𝑐) , 它们的批量矩阵乘法输出的形状为 (𝑛,𝑎,𝑐)

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

运行结果:

在这里插入图片描述
在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))

运行结果:

在这里插入图片描述

2. 定义模型

基于带参数的注意力汇聚,使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

3. 训练

接下来,将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

运行结果:

在这里插入图片描述
如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

运行结果:

在这里插入图片描述

为什么新的模型更不平滑了呢? 下面看一下输出结果的绘制图: 与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

运行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/173722.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Downie4.6.4视频下载工具

前言 Downie是Mac下一个简单的下载管理器,可以让您快速将不同的视频网站上的视频下载并保存到电脑磁盘里然后使用您的默认媒体播放器观看它们。 下载 Downie 解压后直接安装 主要特点 支持许多网站目前支持超过1,000个不同的网站(包括YouTube&#…

Linux | 浅谈Shell运行原理【王婆竟是资本家】

文章目录💧Shell的运行原理👉Shell的基本概念与作用👉原理的展示与剖析👉Shell外壳感性理解【一门亲事】💧总结💧Shell的运行原理 👉Shell的基本概念与作用 Linux严格意义上说的是一个操作系统…

华为数字化转型之道 平台篇 第十三章 变革治理体系

第十三章 变革治理体系 约翰科特在《领导变革》一书中说:“变革的领导团队既需要管理能力,也需要领导能力,他们必须结合起来。 前面我们也谈到,数字化转型不仅是技术的创新,更是一项系统工程和企业真正的变革。企业要转型成功,既需要各个组织的积极参与和通力合作,又不…

深度学习中高斯噪声:为什么以及如何使用

在数学上,高斯噪声是一种通过向输入数据添加均值为零和标准差(σ)的正态分布随机值而产生的噪声。 正态分布,也称为高斯分布,是一种连续概率分布,由其概率密度函数 (PDF) 定义: pdf(x) (1/ (σ*sqrt(2*π))) *e^(- (x…

Task6:文本函数查找函数

文章目录一 文本函数1 Text函数2 mid函数3 replace函数二 查找函数1 Vlookup2 Xlookup一 文本函数 1 Text函数 作用:将数值转换为指定格式的文本 语法:TEXT(value,format_text) (1)转换为大写 消费日期转换为大写 TEXT(A2,”[DB…

第五届字节跳动青训营 前端进阶学习笔记(六)什么才是好的JavaScript代码

文章目录前言问题引入实现一个交通信号灯的状态切换1.基本实现2.状态封装实现3.职责分离实现求一个数是否是4的幂1.基本实现3.数学优化洗牌算法1.基本实现2.均匀算法实现总结前言 课程重点: 代码规范相关事项如何优化代码 问题引入 试看下面一段代码&#xff0c…

认识UDP、TCP协议

一、Socket 首先,我们需要了解一下socket。 在上一篇文章当中,我们了解了TCP-IP五层协议模型初识网络:IP、端口、网络协议、TCP-IP五层模型_革凡成圣211的博客-CSDN博客TCP/IP五层协议详解https://blog.csdn.net/weixin_56738054/article/det…

Crack:RadiAnt DICOM Viewer 2023.1 BETA #1300

RadiAnt DICOM Viewer 2023.1 BETA #1300 built on January 13, 2023 New features: Length ratio calculation. Ellipsoid / bullet volume calculation. Added option to color and/or pin specific items to top in the DICOM tags window. 多式DICOM的技术支持 该软件能够打…

是时候分享一波jenkins centos的安装了

1、下载注意:至少安装2.319的版本,否则插件安装失败,2.357 之后版本需要java11,请注意java版本a、开始下载,利用华为云地址 https://mirrors.huaweicloud.com/home,速度杠杠快搜索jenkins,点击j…

api接口对接如何实现,php如何对接api

这篇文章来分享下api接口对接如何实现,还有源码,希望对新手有所帮助。 什么是API? 我的回答:API( 应用程序编程接口):一般来说,这是一套明确定义的各种软件组件之间的通信方法。 什么是API&…

Android数据库之SharedPreferences、SQLite、Room

文章目录一、SharedPreferences二、SQLite三、Room使用Room进行增删改查Room数据库升级一、SharedPreferences 要想使用SharePreferences来存储数据,首先需要获取到SharedPreferences对象。Android中提供了三种方法用于得到SharedPreferences对象 1.Context类中的g…

20230119英语学习

Back to the Future 在故宫修钟表是种什么样的体验? After a year of complex restoration, specialists from the Palace Museum in Beijing have given a pair of antique pagoda clock automata a new lease on life. In the form of a nine-tiered pagoda, th…

单片机寄存器

单片机寄存器简述 1、单片机寄存器就是单片机片内存储器(片内RAM)一部分,每一个都有地址。只不过这几个寄存器有特殊的作用,比如指令:MUL AB,这条指令用到两个寄存器A,B进行乘法,结果存到BA里面,这条指令必…

Linux基本功系列之type命令实战

文章目录一. type命令介绍二. 语法格式及常用选项三. 参考案例3.1 查看别名3.2 查看是否是内建命令3.3 查看是否为关键字3.4 显示所有命令的位置3.5 判断当前命令是否为alias或者keyword等总结前言🚀🚀🚀 想要学好Linux,命令是基本…

4-2指令系统-指令的寻址方式

文章目录一.指令寻址1.顺序寻址2.跳跃寻址二.数据寻址1.隐含寻址2.立即(数)寻址3.直接寻址4.间接寻址5.寄存器寻址6.寄存器间接寻址7.相对寻址(程序浮动、转移指令)8.基址寻址(多道程序)9.变址寻址&#xf…

移动web字体图标

字体图标下载字体图标使用字体图标使用类名引入字体图标使用unicode编码(了解)在线字体图标使用伪元素字体图标小结下载字体图标 具体的步骤&#xff1a; 使用字体图标 引入相关文件 复制相关的文件&#xff0c;到 fonts文件夹里面。 引入 css <link rel"styleshe…

回溯法复习(总结篇)

根据课本上的学习要点梳理&#xff0c;“通用解题法”&#xff0c;可以系统的搜索一个问题的所有解、任一解&#xff0c;他是一个既带有系统性&#xff08;暴力遍历&#xff09;又带有跳跃性&#xff08;剪枝&#xff09;的搜索算法。 理解回溯法和深度优先搜索策略 回溯的本质…

Kafka入门与核心概念

前言在我们开发过程中&#xff0c;有一些业务功能比较耗时&#xff0c;但是又不是很重要的核心功能&#xff0c;最典型的场景就是注册用户以后发送激活邮件分为两步1&#xff1a;向数据库插入一条数据2&#xff1a;向注册用户发送邮件第2步其实并不是核心功能&#xff0c;但是发…

SpringMVC-拦截器

1&#xff0c;pringMVC-拦截器 对于拦截器这节的知识&#xff0c;我们需要学习如下内容: 拦截器概念入门案例拦截器参数拦截器工作流程分析 1.1 拦截器概念 讲解拦截器的概念之前&#xff0c;我们先看一张图: (1)浏览器发送一个请求会先到Tomcat的web服务器 (2)Tomcat服务…

字节青训前端笔记 | 响应式系统与 React

本节课为前端框架 React 的基础课程讲解 React的设计思路 UI编程的特点 状态更新的时候&#xff0c;UI不会自动更新&#xff0c;需要手动调用DOM接口进行更新欠缺基本的代码层面的封装和隔离&#xff0c;代码层面没有组件化UI之间的数据依赖关系&#xff0c;需要手动维护&am…