《动手学深度学习》-64注意力机制

news2024/11/18 3:41:36

沐神版《动手学深度学习》学习笔记,记录学习过程,详细的内容请大家购买书籍查阅。
b站视频链接
开源教程链接

注意力机制

生物学中的注意力提示

在这里插入图片描述

灵长类动物的视觉系统接受了大量的感官输入,这些感官输入远远超出了大脑所能够完全处理的能力。然而并非所有刺激的影响都是等同的。意识的汇聚和专注使灵长类动物能够在复杂的视觉环境中将注意力引向感兴趣的物体,如猎物和天敌。

举例注意力在视觉世界中的应用,引出非自主性提示自主性提示,非自主性提示是基于环境中物体的突出性和易见性,如左图中,咖啡杯在当前视觉环境中是最突出和显眼的,不由自主地引起人们地注意力。当我们喝完咖啡想要读书时,重新聚集,这时因为受到认知和意识地控制,注意力在基于自主性提示来进行选择时会更为谨慎。

查询、键和值

在这里插入图片描述

自主性与非自主性地注意力提示解释了人类的注意力地方式。
考虑一个相对简单的情况,即只是用非自主性提示。要想将选择偏向感官输入,则可以简单地使用参数化的全连接层,甚至是非参数化的最大汇聚层或平均汇聚层。

因此“是否包含自主性提示”将注意力机制与全连接层或汇聚层区分开来。在注意力机制的背景下,自主性提示被称为查询(query)。给定任何查询,注意力机制通过注意力汇聚(attention pooling)将选择引导至感官输入(sensory input),例如中间特征表示。在注意力机制中,这些感官输入被称为(value)。更通俗的讲,每个值都与一个(key)匹配,这可以想象为感官输入的非自主性提示。

概括就是:通过设计注意力汇聚的方式,便于给定的查询(自主性提示)与键(非自主性提示)进行匹配,这将引导得出最匹配的值(感官输入)。

在这里插入图片描述

Nadaraya-Watson核回归是一个非参数的注意力汇聚模型,非参数的Nadaraya-Watson核回归具有一致性的优点:如果有足够的数据,此模型会收敛到最佳结果。

在这里插入图片描述

将可学习的参数集成到注意力汇聚中:

在这里插入图片描述

总结

在这里插入图片描述

动手学

注意力的可视化

import torch
from d2l import torch as d2l

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

在这里插入图片描述

注意力汇聚: Nadaraya-Watson核回归

from torch import nn

# 生成数据集
n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本 0-5之间随机生成50个数

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数

# 绘制所有的训练样本(圆圈)、不带噪声项的真实数据生成函数、学习得到的预测函数
def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
# 平均汇聚
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

# 非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

在这里插入图片描述
观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述
将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。

# 带参数注意力汇聚
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(50):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

在这里插入图片描述
如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

在这里插入图片描述

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/831454.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue 标题文字字数过长超出部分用...代替 动态显示

效果: 浏览器最大化: 浏览器缩小: 代码: html: <div class"title overflow">{{item.name}}</div> <div class"content overflow">{{item.content}}</div> css: .overflow {/* 一定要加宽度 */width: 90%;/* 文字的大小 */he…

台风来袭,这份避险防御指南一定收好

台风天气的到来&#xff0c;我们必须高度警惕&#xff01;大到暴雨、雷电、雷雨大风&#xff0c;甚至短时强降水等强对流天气&#xff0c;可能给我们的生活带来严重威胁。为了确保家人安全&#xff0c;让我们共同学习一些智慧防护措施&#xff0c;做好个人安全防范。定期关注天…

C++初阶之一篇文章让你掌握vector(理解和使用)

vector&#xff08;理解和使用&#xff09; 1.什么是vector&#xff1f;2.vector的使用2.1 vector构造函数2.2 vector迭代器&#xff08;Iterators&#xff09;函数2.2.1 begin()2.2.2 end()2.2.3 rbegin()2.2.4 rend()2.2.5 cbegin()、cend()、crbegin()和crend() C11 2.3 vec…

Java类集框架(二)

目录 1.Map&#xff08;常用子类 HashMap&#xff0c;LinkedHashMap&#xff0c;HashTable&#xff0c;TreeMap&#xff09; 2.Map的输出&#xff08;Map.Entry,iterator,foreach&#xff09; 3.数据结构 - 栈&#xff08;Stack&#xff09; 4.数据结构 - 队列&#xff08;Q…

485modbus转profinet网关连三菱变频器modbus通讯触摸屏监控

本案例介绍了如何通过485modbus转profinet网关连接威纶通与三菱变频器进行modbus通讯。485modbus转profinet网关提供了可靠的连接方式&#xff0c;使用户能够轻松地将不同类型的设备连接到同一网络中。通过使用这种网关&#xff0c;用户可以有效地管理和监控设备&#xff0c;从…

人工智能与物理学(软体机器人能量角度)的结合思考

前言 好久没有更新我的CSDN博客了&#xff0c;细细数下来已经有了16个月。在本科时期我主要研究嵌入式&#xff0c;研究生阶段对人工智能感兴趣&#xff0c;看了一些这方面的论文和视频&#xff0c;因此用博客记录了一下&#xff0c;后来因为要搞自己的研究方向&#xff0c;就…

使用Golang实现一套流程可配置,适用于广告、推荐系统的业务性框架——组合应用

在《使用Golang实现一套流程可配置&#xff0c;适用于广告、推荐系统的业务性框架——简单应用》中&#xff0c;我们看到了各种组合Handler的组件&#xff0c;如HandlerGroup和Layer。这些组件下面的子模块又是不同组件&#xff0c;比如LayerCenter的子组件是Layer。如果此时我…

Windows用户如何将cpolar内网穿透配置成后台服务,并开机自启动?

Windows用户如何将cpolar内网穿透配置成后台服务&#xff0c;并开机自启动&#xff1f; 文章目录 Windows用户如何将cpolar内网穿透配置成后台服务&#xff0c;并开机自启动&#xff1f;前置准备&#xff1a;VS Code下载后&#xff0c;默认安装即可VS CODE切换成中文语言 1. 将…

FSC 认证产品门户网站正式上线

【FSC 认证产品门户网站正式上线】 FSC 国际正式推出自助式服务平台——FSC认证产品门户网站。FSC 证书持有者均可通过该平台自行添加企业或组织的 FSC 认证产品&#xff0c;寻求更多商机&#xff1b;也可通过该门户申请参与亚马逊气候友好项目&#xff08;Amazon Climate-Frie…

低代码平台,让应用开发更简单!

一、前言 随着社会数字化进程的加速&#xff0c;旺盛的企业个性化需求和有限的专业开发人员供给之间的矛盾日益显著&#xff0c;业界亟需更快门槛、更高效率的开发方法和工具&#xff0c;低代码技术便应运而生。 低代码开发&#xff0c;是通过编写少量代码甚至无需代码&#xf…

作为一个老程序员,想对新人说什么?

前言 最近知乎上&#xff0c;有一位大佬邀请我回答下面这个问题&#xff0c;看到这个问题我百感交集&#xff0c;感触颇多。 在我是新人时&#xff0c;如果有前辈能够指导方向一下&#xff0c;分享一些踩坑经历&#xff0c;或许会让我少走很多弯路&#xff0c;节省更多的学习的…

Vue3文本省略(Ellipsis)

APIs 参数说明类型默认值必传maxWidth文本最大宽度number | string‘100%’falseline最大行数numberundefinedfalsetrigger展开的触发方式‘click’undefinedfalsetooltip是否启用文本提示框booleantruefalsetooltipMaxWidth提示框内容最大宽度&#xff0c;单位px&#xff0c;…

数据结构--单链表OJ题

上文回顾---单链表 这章将来做一些链表的相关题目。 目录 1.移除链表元素 2.反转链表 3.链表的中间结点 4.链表中的倒数第k个结点 5.合并两个有序链表 6.链表分割 7.链表的回文结构 8.相交链表 9.环形链表 ​编辑 10.环形链表II ​编辑 ​编辑 1.移除链表元素 思…

Camera之元数据(meta data)和原始数据(raw data)区别(三十一)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只有行动才是治疗恐惧和懒惰的唯一良药. 更多原创,欢迎关注:Android…

JavaScript原生将图片转成base64

1.写个html文件 <!-- 产品照片 --> <div class"mb-3"> <label for"cover" class"form-label">产品图片</label><inputtype"file"class"form-control"id"coverfile"/> </div>…

这些你不容错过的音频转换器免费推荐给你

音频格式转换技术是一种能够将不同格式的音频文件互相转换的技术。你可能会想&#xff0c;为什么要进行音频格式转换呢&#xff1f;原因可是多种多样的&#xff01;有时候你可能收到了一个音频文件&#xff0c;但却无法在你的设备上播放&#xff0c;这时候就需要将其转换为兼容…

shiro快速入门

文章目录 权限管理什么是权限管理&#xff1f;什么是身份认证&#xff1f;什么是授权&#xff1f; 什么是shiro&#xff1f;shiro的核心架构shiro中的三个核心组件 shiro中的认证shiro中的授权shiro使用默认Ehcache实现缓存shiro使用redis作为缓存实现 权限管理 什么是权限管理…

从0到1学会手写操作系统,我只用了2个小时!!!

继《自己动手做一台计算机》 《嵌入式入门-模电基础》两大教程之后 黑马嵌入式教程再出力作 重磅发布第三弹 《自己动手写嵌入式操作系统》 问&#xff1a;嵌入式开发不是只学单片机就行&#xff1f;为什么要学操作系统&#xff1f; 答&#xff1a;年轻人&#xff0c;别把路…

LeetCode每日一题Day4——26. 删除有序数组中的重复项

✨博主&#xff1a;命运之光 &#x1f984;专栏&#xff1a;算法修炼之练气篇&#xff08;C\C版&#xff09; &#x1f353;专栏&#xff1a;算法修炼之筑基篇&#xff08;C\C版&#xff09; &#x1f433;专栏&#xff1a;算法修炼之练气篇&#xff08;Python版&#xff09; …

Python 批量处理JSON文件,替换某个值

Python 批量处理JSON文件&#xff0c;替换某个值 直接上代码&#xff0c;替换key TranCode的值 New 为 Update。输出 cancel忽略 import json import os import iopath D:\\Asics\\850\\202307 # old path2 D:\\test2 # new dirs os.listdir(path) num_flag 0 for file…