d2l Nadaraya-Waston核回归

news2025/1/16 8:45:35

注意力机制里面的非参数注意力汇聚

目录

1.目标任务

2.数据生成

2.1构造原始数值

3.非参数注意力汇聚

4.对注意力机制的理解


1.目标任务

使用y_train(有噪声),拟合y_truth(没噪声)。给你所有的y_train,构造注意力权重生成拟合曲线。

2.数据生成

n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本

其中:对于sort的返回值:

 sort返回两个值,第一个是拍好了从小到大的顺序后的values,另一个是对应原数据的indices

2.1构造原始数值

噪声服从u=0;std=0.5的正态分布:

 y_train为上述计算式,包含噪声;y_truth为上式不包含噪声

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出
x_test = torch.arange(0, 5, 0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数
n_test

训练样本是有噪声的x_train与y_train;真实数据是不带噪声的y_truth

3.非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每⼀⾏都包含着相同的测试输⼊(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每⼀⾏都包含着要在给定的每个查询的值(y_train)之间分配的注意⼒权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意⼒权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

上块代码中:

repeat_interleave就是将原来的x_test中的每一个元素赋值n_train次,得到一个一维tensor,再使用reshape操作使复制的相同元素都在同一行中。

 对dim=1进行softmax,即将每个Xi与给定x的距离首先进行:1.映射非负(exp)+2.归一化。此时attention_weights表示各个候选Xi与给定x的距离的归一化权重。

 

 由图可见,在第一行给定的x=0,距离Xi中对应的X1距离最近,其余该行上的值绝对值都比它大,所以最后的softmax它得分最多,权重最多;对于倒数第二行对应的x=4.8,最后一个X50=4.8466距离最近,所以对应位置的softmax得分最高

X_repeat - x_train,形状为(50,50)-(50),第一个tensor中的每一行的各个元素减去第二个tensor里面的每个元素,用来计算给定x(0.1等分)与候选Xi的距离

 dim=1表示的是沿着列增的方向操作,本质是对每一行操作!!!

4.对注意力机制的理解

 给定x时(0.1等距分布),对所有候选Xi(rand)计算一下与x的距离(得到的X_repeat - x_train是(50,50)的tensor,其中每一行表示单个指定的x与各个Xi的距离。

  任务的目的使用y_train(有噪声),拟合y_truth(没噪声)。给你所有的y_train,构造权重生成拟合曲线。

  至于具体的在某个拟合点x(0.1等分),用attention_weights与所有的y_train相乘,每一行表示的是各个x_train(候选Xi)与x的距离的softmax值,距离越近权重越大,考虑的就越多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/435552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

五款高效易用的项目管理软件,提升团队工作效率

项目管理软件是为了协助团队或公司便捷和高效地完成工作任务和管理项目而专门设计的软件工具。有了它,团队成员可以共享资源,跟踪项目进度和成果,识别问题并及时解决。与传统的手工方式相比,项目管理软件可以提高工作效率和生产力…

Centos7上安装vscode和ssh

Centos7上安装vscode和ssh 一.前言二.Centos7上安装vscode三,Centos7上配ssh3.1 查看是否安装ssh环境3.2 配置ssh配置文件3.3 启动ssh服务 一.前言 在用linux环境编译项目的时候,比较习惯用ubuntu环境,而对centos环境的一些命令工具使用的比…

外链是什么意思,什么是外链

外链就是指在别的网站导入自己网站的链接。导入链接对于网站优化来说是非常重要的一个过程。导入链接的质量间接影响了我们的网站在搜索引擎中的权重。外链是互联网的血液,是链接的一种。没有链接的话,信息就是孤立的,结果就是我们什么都看不…

计算机网络笔记(方老师408课程)(持续更新)

文章目录 前言互联网概述互联网发展的三个阶段互联网标准化机构 互联网的组成边缘部分的通信方式核心部分的交换方式 我国计算机网络的发展计算机网络的类别计算机网络的性能速率、带宽、吞吐量时延时延带宽积往返时间RTT(Round-Trip Time)利用率非性能特…

SpringCloud分布式配置中心——Config

Config 本专栏学习内容来自尚硅谷周阳老师的视频 有兴趣的小伙伴可以点击视频地址观看 由于微服务越来越多,项目越来越庞大,每一个项目都至少有两三个不同环境的application.properties文件,不易管理,假设我们数据库迁移&#xff…

笔记--java sort() 方法排序

背景 最近在刷一道算法题 《字符串重新排序》时,发现自己有思路但是写代码的时候就无从下手了 而且看了答案之后还没看懂 关键就是基础不好 对于排序没有理解(虽然我学过常用的排序算法 但是都是理念 实践少) 目的 从实践和原理出发 重点是从…

参数处理、查询语句

一、Mybatis参数处理 1、数据准备 pojo类: public class Student {private Long id;private String name;private Integer age;private Double height;private Character sex;private Date birth;// constructor// setter and getter// toString }2、单个简单类型…

设计模式 -- 命令模式

前言 月是一轮明镜,晶莹剔透,代表着一张白纸(啥也不懂) 央是一片海洋,海乃百川,代表着一块海绵(吸纳万物) 泽是一柄利剑,千锤百炼,代表着千百锤炼(输入输出) 月央泽,学习的一种过程,从白纸->吸收各种知识->不断输入输出变成自己的内容 希望大家一起坚持这个过程,也同…

线性表详解

目录 1.线性表的定义和特点 2.案例 2.1一元多项式的计算 可以通过下面这个题目简单练习一下 2.2稀疏多项式的计算 2.3图书信息管理系统 3.线性表的类型定义 4.线性表的顺序表示和实现 4.1线性表的顺序储存表示 4.2顺序表中基本操作的实现 5.线性表的链式表现和实现 …

vba:inputbox

inputbox函数与方法 1.区别一:外观区别 InputBox 函数 在一对话框来中显示提示,等待用户输入正文或按下按钮,并返回包含文本框内容的 String。 Application.InputBox 方法 显示一个接收用户输入的对话框。返回此对话框中输入的信息。 -----…

分享一个国内使用的ChatGPT的方法

介绍 ChatGPT ChatGPT是一种基于自然语言处理技术的对话生成模型。它是由OpenAI公司开发的一种语言模型,可以在大规模语料库上进行无监督学习,并生成高质量的自然语言文本。ChatGPT可以用于多种应用场景,例如智能客服、语音助手、聊天机器人…

JAVA学习笔记(注解)

1. JDK预定义注解 (1) Deprecated(表示标记对象已过时) (2) SuppressWarnings("all") (忽略标记对象的警告) 2. 元注解(用于描述注解的注解) Target 描述注解所生效的位置 Retention 描述注…

SpringBooot

目录 一、简介 1、使用原因 2、JavaConfig (1)Configuration注解 (2)Bean注解 (3)ImportResource注解 (4)PropertyResource注解 (5)案例 3、简介 4…

Faster-RCNN代码解读8:代码调试与总结

Faster-RCNN代码解读8:代码调试与总结 前言 ​ 因为最近打算尝试一下Faster-RCNN的复现,不要多想,我还没有厉害到可以一个人复现所有代码。所以,是参考别人的代码,进行自己的解读。 ​ 代码来自于B站的UP主&#xff0…

网络协议-前端重点——DNS和CDN

目录 DNS的基础知识 统一资源定位符(URL)(网址) DNS(Dimain Name System)(域名系统) DNS Query过程 DNS记录 A记录 AAAA记录 CNAME记录(Canonical Name Record) MX记录&#…

Blender3.5 视图切换

目录 1. 数字小键盘切换视图1.1 正交顶视图1.2 正交前视图1.3 正交右视图1.4 透视图1.5 四格视图 2. 鼠标点击切换视图2.1 点击视图,根据需求选择对应视图2.2 点导航栏的坐标轴切换 3. 启用字母区数字键3.1 编辑——偏好设置——输入——勾选“模拟数字键” 1. 数字…

Linux驱动——高级I/O操作(四)

目录 几种I/O模型总结 异步通知 几种I/O模型总结 阻塞 IO:在资源不可用时,进程阻塞,阻塞发生在驱动中,资源可用后进程被唤醒,在阻塞期间不占用CPU,是最常用的一种方式。 非阻塞 I/O: 调用立即返回,即便是在资…

《Unity Shader 入门精要》第10章 高级纹理

第10章 高级纹理 10.1 立方体纹理 在图形学中,立方体纹理 (Cubemap) 是环境映射 (Environment Mapping) 的一种实现方法。 和之前见到的纹理不同,立方体纹理一共包含了6张图像,这些图像对应了…

typescript的keyof的用法

第一种:与接口一起用,返回联合类型 interface Person {name: string;age: number;location: string;}type K1keyof Person; // "name" | "age" | "gender" let a:K1name 第二种:与typeof一起用,可…

天梯赛练习(L2-013 ~ L2-020)

L2-013 红色警报 战争中保持各个城市间的连通性非常重要。本题要求你编写一个报警程序,当失去一个城市导致国家被分裂为多个无法连通的区域时,就发出红色警报。注意:若该国本来就不完全连通,是分裂的k个区域,而失去一…