注意力汇聚:Nadaraya-Watson 核回归

news2024/12/28 2:36:03
  • Nadaraya-Watson核回归是具有注意力机制的机器学习范例。

  • Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。

  • 注意力汇聚可以分为非参数型和带参数型。

  • 参考10.2. 注意力汇聚:Nadaraya-Watson 核回归 — 动手学深度学习 2.0.0 documentation

框架下的注意力机制的主要成分 : 查询(自主提示)和(非自主提示)之间的交互形成了注意力汇聚; 注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。

本节将介绍注意力汇聚的更多细节, 以便从宏观上了解注意力机制在实践中的运作方式。 具体来说,1964年提出的Nadaraya-Watson核回归模型 是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。

pip install mxnet==1.7.0.post1
pip install d2l==0.15.0
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

1.生成数据集 

简单起见,考虑下面这个回归问题: 给定的成对的“输入-输出”数据集 {(x1,y1),…,(xn,yn)}, 如何学习f来预测任意新输入x的输出y^=f(x)?

根据下面的非线性函数生成一个人工数据集, 其中加入的噪声项为ϵ:

其中ϵ服从均值为0和标准差为0.5的正态分布。 在这里生成了50个训练样本和50个测试样本。 为了更好地可视化之后的注意力模式,需要将训练样本进行排序。

n_train = 50  # 训练样本数
x_train = np.sort(np.random.rand(n_train) * 5)   # 排序后的训练样本

def f(x):
    return 2 * np.sin(x) + x**0.8

y_train = f(x_train) + np.random.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = np.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test

50

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数f(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

 2.平均汇聚

 先使用最简单的估计器来解决回归问题。 基于平均汇聚来计算所有训练样本输出值的平均值:

如下图所示,这个估计器确实不够聪明。 真实函数f(“Truth”)和预测函数(“Pred”)相差很大。

y_hat = y_train.mean().repeat(n_test)
plot_kernel_reg(y_hat)

 3. 非参数注意力汇聚

显然,平均汇聚忽略了输入xi。 于是Nadaraya (Nadaraya, 1964)和 Watson (Watson, 1964)提出了一个更好的想法, 根据输入的位置对输出yi进行加权:

其中K是(kernel)。 公式 (10.2.3)所描述的估计器被称为 Nadaraya-Watson核回归(Nadaraya-Watson kernel regression)。 这里不会深入讨论核函数的细节, 但受此启发, 我们可以从 图10.1.3中的注意力机制框架的角度 重写 (10.2.3), 成为一个更加通用的注意力汇聚(attention pooling)公式:

 其中x是查询,(xi,yi)是键值对。 比较 (10.2.4)和 (10.2.2), 注意力汇聚是yi的加权平均。 将查询x和键xi之间的关系建模为 注意力权重(attention weight)α(x,xi), 如 (10.2.4)所示, 这个权重将被分配给每一个对应值yi。 对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布: 它们是非负的,并且总和为1。

为了更好地理解注意力汇聚, 下面考虑一个高斯核(Gaussian kernel),其定义为:

将高斯核代入 (10.2.4)和 (10.2.3)可以得到:

 在 (10.2.6)中, 如果一个键xi越是接近给定的查询x, 那么分配给这个键对应值yi的注意力权重就会越大, 也就“获得了更多的注意力”。 

值得注意的是,Nadaraya-Watson核回归是一个非参数模型。 因此, (10.2.6)是 非参数的注意力汇聚(nonparametric attention pooling)模型。 接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。 从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = npx.softmax(-(X_repeat - x_train)**2 / 2)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = np.dot(attention_weights, y_train)
plot_kernel_reg(y_hat)

 现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.asnumpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
show_heatmaps(np.expand_dims(np.expand_dims(attention_weights, 0), 0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

 

4.带参数注意力汇聚

非参数的Nadaraya-Watson核回归具有一致性(consistency)的优点: 如果有足够的数据,此模型会收敛到最优结果。 尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。

例如,与 (10.2.6)略有不同, 在下面的查询x和键xi之间的距离乘以可学习参数w:

本节的余下部分将通过训练这个模型 (10.2.7)来学习注意力汇聚的参数。

4.1.批量矩阵乘法 

为了更有效地计算小批量数据的注意力, 我们可以利用深度学习开发框架中提供的批量矩阵乘法。

假设第一个小批量数据包含n个矩阵X1,…,Xn, 形状为a×b, 第二个小批量包含n个矩阵Y1,…,Yn, 形状为b×c。 它们的批量矩阵乘法得到n个矩阵 X1Y1,…,XnYn, 形状为a×c。 因此,假定两个张量的形状分别是(n,a,b)和(n,b,c), 它们的批量矩阵乘法输出的形状为(n,a,c)。

X = np.ones((2, 1, 4))
Y = np.ones((2, 4, 6))
npx.batch_dot(X, Y).shape
(2, 1, 6)

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = np.ones((2, 10)) * 0.1
values = np.arange(20).reshape((2, 10))
npx.batch_dot(np.expand_dims(weights, 1), np.expand_dims(values, -1))
array([[[ 4.5]],

       [[14.5]]])

 

4.2.定义模型

基于 (10.2.7)中的 带参数的注意力汇聚,使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Block):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = self.params.get('w', shape=(1,))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询数,“键-值”对数)
        queries = queries.repeat(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = npx.softmax(
            -((queries - keys) * self.w.data())**2 / 2)
        # values的形状为(查询数,“键-值”对数)
        return npx.batch_dot(np.expand_dims(self.attention_weights, 1),
                             np.expand_dims(values, -1)).reshape(-1)

4.3.训练

接下来,将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = np.tile(x_train, (n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = np.tile(y_train, (n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - np.eye(n_train)).astype('bool')].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - np.eye(n_train)).astype('bool')].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数随机梯度下降

net = NWKernelRegression()
net.initialize()
loss = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    with autograd.record():
        l = loss(net(x_train, keys, values), y_train)
    l.backward()
    trainer.step(1)
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

书上结果是这样的,代码复制过去运行的,我不理解

 

 

如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = np.tile(x_train, (n_test, 1))
# value的形状:(n_test,n_train)
values = np.tile(y_train, (n_test, 1))
y_hat = net(x_test, keys, values)
plot_kernel_reg(y_hat)

噢,啥玩意,预测啥也不是啊,书上是这样的

 

 

为什么新的模型更不平滑了呢? 下面看一下输出结果的绘制图: 与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(np.expand_dims(
                      np.expand_dims(net.attention_weights, 0), 0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

 书上这样的,哎,回头再看看,看不出哪里的问题

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/189991.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这5个Linux安全相关的命令,很实用!

1.ss ss更多人认为是“网络”命令,但该命令也可用于安全目的。 比如: ss:列出所有连接ss -a:列出侦听和非侦听端口ss -t:列出 TCP 连接 2.who who命令列出所有登录的人。 Linux 是多用户的,所以这个命…

VBA检查指定应用程序是否已经打开

VBA中提供了CreateObject和GetObject两种方法获得对象实例,二者的区别在于GetObject用于获取已经打开的应用程序对象,但是如果该应用程序并没有打开,那么将产生运行时错误,代码中需要加入额外的错误处理代码。 在任务管理器中可以…

第17讲:Python中元组的概念以及应用

文章目录1.元组的概念2.元组的基本使用2.1.定义一个元组2.2.定义一个空元组2.3.元组的元素是不可变的2.4.当元组中的元素是一个列表时列表中的元素可变2.5.当元组中只定义一个元素时的注意事项3.列表的所有操作同样适用于元组4.就是想修改元组中的某个元素1.元组的概念 Python…

CVE-2022-33980 Apache Commons Configuration 远程命令执行漏洞分析

漏洞描述 7月6日,Apache官方发布安全公告,修复了一个存在于Apache Commons Configuration 组件的远程代码执行漏洞,漏洞编号:CVE-2022-33980,漏洞威胁等级:高危。恶意攻击者通过该漏洞,可在目标…

3. 其他数仓/BI架构解析

文章目录1. 独立数据集市架构2. 辐射状企业信息工厂Inmon架构范式建模维度建模3. 混合辐射状架构与Kimball架构目前,经过长时间的演进,各种数仓架构之间的区别变得越来越小,且不论哪种数仓架构,都会涉及维度建模。下面是几种常见的…

《项目管理精华》读书笔记

《项目管理精华》读后感 《项目管理精华》书中用实例讲解了项目管理的基本流程及一些项目管理技巧,通俗易懂且让人有读下去的欲望。本书按照项目管理的五大过程组的顺序逻辑加上一些作者认为项目管理中比较有价值的内容来介绍的,十大知识领域穿插在这些…

nodejs+vue+elementui在线水果商城购物网站vscode项目

摘 要 1 Abstract 1 1 系统概述 4 1.1 概述 4 1.2课题意义 4 1.3 主要内容 4 2 系统开发环境 5 2.1 Vue技术介绍 5 端技术:nodejsvueelementui 前端:HTML5,CSS3、JavaScript、VUE 系统分为不同的层次:视图层&#…

哪个室内导航地图好用?画地图用什么软件好?

在互联网服务中,所有互联网巨头均斥巨资布局C端的地图服务,喜闻乐见的地图形式极大地方便了人们在日常生活中的查询位置、了解环境、导航等应用需求,因此取得了巨大的成功。但目前这些地图服务数据的颗粒度基本仅到城市级;园区、景…

CVE-2022-32532 认证绕过漏洞分析

前言 结合自身经历,在使用正则表达式去匹配流量特征时,由于正则表达式中元字符“.”不匹配换行符(\n、\r)导致一直提取不上所需的流量。而如今,之前踩过的坑却出现在了Apache Shiro框架之中… 漏洞描述 6月29日&…

《MySQL实战45讲》——学习笔记24 “Master/Slave主备同步机制“

binlog可以用来归档,也可以用来做主备同步,binlog在MySQL的各种高可用方案上扮演了重要角色;本篇主要介绍MySQL主备(M-S结构)的基本原理、不同格式binlog的优缺点和设计者的思考、MySQL双主结构和循环复制问题&#xf…

关于 Linux 系统

个人主页:ζ小菜鸡 大家好我是ζ小菜鸡,小伙伴们,让我们一起来学习Linux系统。 如果文章对你有帮助、欢迎关注、点赞、收藏(一键三连) 目录 一、Linux系统的组成 二、Linux 命令行的操作 三、Linux 目录结构 四、文件基本操作 五、用户及文件权限管…

jspssm小型药店药品管理系统

目 录 第一章 概述 1 1.1 研究背景 1 1.2开发意义 1 1.3 研究现状 1 1.4 研究内容 2 1.5论文结构 2 第二章 开发技术介绍 3 2.1 系统开发平台 3 2.2 平台开发相关技术 3 2.2.1 JSP技术介绍 3 2.2.2 Mysql数据库介绍 3 2.2.3 B/S架构 …

excel排名技巧:万能透视表加筛选找出销售冠军

在一组数据中统计单项排名第一的人员名单,如果用函数,我们需要考虑不同项目业绩相等的查找问题,需要考虑同项目业绩相等的查找问题。但如果用数据透视表加筛选,一切就很简单了。最后要么逐条地复制粘贴,要么用技巧合并…

ChatGPT 逆天测试,结局出乎预料

目录一、数学解题能力二、编程能力三、日常生活咨询四、问一些离谱的问题,它有啥反应?五、逆天大测试一、数学解题能力 据说 ChatGPT 会做数学题,给他几个条件不充分的问题,看看他是否真的会思考。 这家伙心理素质真好&#xff…

Uniswap 解析:恒定乘积做市商模型Constant Product Market Maker Model 的Vyper 实作

大纲 一. 前言 二. 恒定乘积做市商模型Constant Product Market Maker Model 1. 计入手续费 2. 程式码结构 3. 演算法核心与实作 4. 段落小结 三. 流动性Liquidity 1. 第一笔流动性注入、决定k值 2. 除了第一笔以外的情况 四. 结语 一. 前言 暨上一篇开始接触了Vyper 后&…

《吐血整理》保姆级系列教程-玩转Fiddler抓包教程(2)-初识Fiddler让你理性认识一下

1.前言 今天的理性认识主要就是讲解和分享Fiddler的一些理论基础知识。其实这部分也没有什么,主要是给小伙伴或者童鞋们讲一些实际工作中的场景,然后隆重推出我们的猪脚(主角)-Fiddler。 1.1工作场景 做app测试,你是…

SpringMVC的工作原理

SpringMVC的工作原理图: SpringMVC流程 1、 用户发送请求至前端控制器DispatcherServlet。 2、 DispatcherServlet收到请求调用HandlerMapping处理器映射器。 3、 处理器映射器找到具体的处理器(可以根据xml配置、注解进行查找),生成处理器对象及处…

二叉树29:二叉搜索树中的插入操作

主要是我自己刷题的一些记录过程。如果有错可以指出哦,大家一起进步。 转载代码随想录 原文链接: 代码随想录 leetcode链接:701.二叉搜索树中的插入操作 题目: 给定二叉搜索树(BST)的根节点 root 和要插入…

苏嵌实训——day19

文章目录一、数据库1.1 在ubuntu中安装数据库1.2 数据库的操作1.2.1 数据库命令的分类1.2.2 常用的系统命令1.2.3 数据中的常用的语句1.3 sqlite数据库中常用api1. sqlite3_open2. int sqlite3_close(sqlite3 * db)3. sqlite3_exec4 sqlite3_get_table5 void sqlite3_free_tabl…

JavaEE12-Spring MVC程序开发

目录 1.什么是Spring MVC? 1.1.MVC定义 1.1.1.Model(模型) 1.1.2.View(视图) 1.1.3.Controller(控制器) 1.2.MVC和Spring MVC的关系 2.为什么要学Spring MVC? 3.怎么学Spring MVC? 3.1.Spring …