75. 序列模型的代码实现

news2025/1/12 0:59:30

1. 训练

在了解了上述统计工具后,让我们在实践中尝试一下! 首先,我们生成一些数据:(使用正弦函数和一些可加性噪声来生成序列数据, 时间步为 1,2,…,1000 。)

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
T = 1000  # 总共产生1000个点
time = torch.arange(1, T + 1, dtype=torch.float32)
x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))
d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))

运行结果:

在这里插入图片描述

接下来,我们将这个序列转换为模型的特征-标签(feature-label)对。 基于嵌入维度 𝜏 ,我们将数据映射为数据对 𝑦𝑡=𝑥𝑡 和 𝐱𝑡=[𝑥𝑡−𝜏,…,𝑥𝑡−1] 。这比我们提供的数据样本少了 𝜏 个, 因为我们没有足够的历史记录来描述前 𝜏 个数据样本。

一个简单的解决办法是:如果拥有足够长的序列就丢弃这几项; 另一个方法是用零填充序列。 在这里,我们仅使用前600个“特征-标签”对进行训练。

tau = 4
# 因为是用过去4个样本来预测未来一个时刻,那么第1-4个时刻的数据是被用来预测第5个时刻
# 因此,只有后面5~T这T-4个数据是由前面4个时刻预测来的。
# 所以样本数为T-tau。而因为每个样本是由前tau个数据预测来的,所以
# 横坐标表示样本,纵坐标表示每个样本对应的前面tau个数据
features = torch.zeros((T - tau, tau))
# 通过for 循环对features矩阵赋值
for i in range(tau):
    features[:, i] = x[i: T - tau + i]
labels = x[tau:].reshape((-1, 1))
batch_size, n_train = 16, 600
# 只有前n_train个样本用于训练
train_iter = d2l.load_array((features[:n_train], labels[:n_train]),
                            batch_size, is_train=True)

在这里,我们使用一个相当简单的架构训练模型: 一个拥有两个全连接层的多层感知机ReLU激活函数和平方损失

# 初始化网络权重的函数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)

# 一个简单的多层感知机
def get_net():
    net = nn.Sequential(nn.Linear(4, 10),
                        nn.ReLU(),
                        nn.Linear(10, 1))
    net.apply(init_weights)
    return net

# 平方损失。注意:MSELoss计算平方误差时不带系数1/2
loss = nn.MSELoss(reduction='none')

现在,准备训练模型了。实现下面的训练代码的方式与前面几节中的循环训练基本相同。因此,我们不会深入探讨太多细节。

def train(net, train_iter, loss, epochs, lr):
    trainer = torch.optim.Adam(net.parameters(), lr)
    for epoch in range(epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.sum().backward()
            trainer.step()
        print(f'epoch {epoch + 1}, '
              f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')

net = get_net()
train(net, train_iter, loss, 5, 0.01)

运行结果:

在这里插入图片描述

2. 预测

由于训练损失很小,因此我们期望模型能有很好的工作效果。 让我们看看这在实践中意味着什么。 首先是检查模型预测下一个时间步的能力, 也就是单步预测(one-step-ahead prediction)

onestep_preds = net(features)
d2l.plot([time, time[tau:]],
         [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',
         'x', legend=['data', '1-step preds'], xlim=[1, 1000],
         figsize=(6, 3))

运行结果:

在这里插入图片描述

正如我们所料,单步预测效果不错。 即使这些预测的时间步超过了 600+4 (n_train + tau), 其结果看起来仍然是可信的。 然而有一个小问题:如果数据观察序列的时间步只到 604 , 我们需要一步一步地向前迈进:
在这里插入图片描述

通常,对于直到 𝑥𝑡 的观测序列,其在时间步 𝑡+𝑘 处的预测输出 𝑥̂ 𝑡+𝑘 称为 𝑘 步预测( 𝑘 -step-ahead-prediction)。 由于我们的观察已经到了 𝑥604 ,它的 𝑘 步预测是 𝑥̂ 604+𝑘 。 换句话说,我们必须使用我们自己的预测(而不是原始数据)来进行多步预测。 让我们看看效果如何。

multistep_preds = torch.zeros(T)
multistep_preds[: n_train + tau] = x[: n_train + tau]
for i in range(n_train + tau, T):
    multistep_preds[i] = net(
        multistep_preds[i - tau:i].reshape((1, -1)))
d2l.plot([time, time[tau:], time[n_train + tau:]],
         [x.detach().numpy(), onestep_preds.detach().numpy(),
          multistep_preds[n_train + tau:].detach().numpy()], 'time',
         'x', legend=['data', '1-step preds', 'multistep preds'],
         xlim=[1, 1000], figsize=(6, 3))

在这里插入图片描述

上图这个预测是多步预测,也是1000-604+1=397步预测,也就是前604步来预测后面的397步。这样相当于误差累计了307次。

如上面的例子所示,绿线的预测显然并不理想。 经过几个预测步骤之后,预测的结果很快就会衰减到一个常数。 为什么这个算法效果这么差呢?事实是由于错误的累积: 假设在步骤 1 之后,我们积累了一些错误 𝜖1=𝜖¯ 。 于是,步骤 2 的输入被扰动了 𝜖1 , 结果积累的误差是依照次序的 𝜖2=𝜖¯+𝑐𝜖1 , 其中 𝑐 为某个常数,后面的预测误差依此类推。 因此误差可能会相当快地偏离真实的观测结果。

每次预测都有一点误差,这个误差进入到下一个数据的预测,误差又会增加,一直迭代下去,累积误差。

例如,未来 24 小时的天气预报往往相当准确, 但超过这一点,精度就会迅速下降。 我们将在本章及后续章节中讨论如何改进这一点。

基于 𝑘=1,4,16,64 ,通过对整个序列预测的计算, 让我们更仔细地看一下 𝑘 步预测的困难。

4步预测的意思是,用前面4步0,1,2,3来预测后面4步4,5,6,7,那么4通过0,1,2,3来预测,5通过1,2,3,4来预测,依此类推。那16步预测就是用前面4步来预测后面16步。

max_steps = 64

features = torch.zeros((T - tau - max_steps + 1, tau + max_steps))
# 列i(i<tau)是来自x的观测,其时间步从(i)到(i+T-tau-max_steps+1)
for i in range(tau):
    features[:, i] = x[i: i + T - tau - max_steps + 1]

# 列i(i>=tau)是来自(i-tau+1)步的预测,其时间步从(i)到(i+T-tau-max_steps+1)
for i in range(tau, tau + max_steps):
    features[:, i] = net(features[:, i - tau:i]).reshape(-1)
steps = (1, 4, 16, 64)
d2l.plot([time[tau + i - 1: T - max_steps + i] for i in steps],
         [features[:, tau + i - 1].detach().numpy() for i in steps], 'time', 'x',
         legend=[f'{i}-step preds' for i in steps], xlim=[5, 1000],
         figsize=(6, 3))

运行结果:

在这里插入图片描述

以上例子清楚地说明了当我们试图预测更远的未来时,预测的质量是如何变化的。 虽然“ 4 步预测”看起来仍然不错,但超过这个跨度的任何预测几乎都是无用的。

可以看出,难点在于去预测很远的未来。即使是很简单的正弦函数,对于去预测比较远的未来也是很困难的事情。

3. Q&A

Q1:在常规范围呢tau是不是越大越好,刚才例子tau=5是不是比4好?

A1:以马尔科夫假设,当然是能观察到更长的数据更好,但是如果tau太大的话,训练样本就会很少;并且tau增大的话,计算量会增大,模型需要更复杂去fit,而样本还少,这就更麻烦了。所以tau不能太大也不能太小,是有一个权衡的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/163206.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

新手nvm npm 卸载不用依赖包,项识别为 cmdlet、函数、脚本文件,等命令集合

nvm安装包&#xff1a;Releases coreybutler/nvm-windows GitHub下载ta就不用单独下载node了注意:vnm安装位置尽量不要动C:\Users\Administrator\AppData\Roaming\nvm\settings.txt增加下面代码node_mirror: https://npm.taobao.org/mirrors/node/ npm_mirror: https://npm.t…

java+Springboot交通事故档案管理系统

系统分为用户和管理员两个角色 用户的主要功能有&#xff1a; 1.用户注册和登陆系统 2.用户查看警察相关信息 3.用户查看我的相关事故信息&#xff0c;可以对交通事故进行交通申诉 4.用户查看交通申诉审核信息 5.退出登陆 管理员的主要功能有&#xff1a; 1.管理员输入账户登陆…

Metasploit渗透框架介绍及永恒之蓝复现

Metasploit渗透框架介绍及永恒之蓝复现一、Metasploit渗透框架介绍1.1 名词解释1.2 MSF简介1.3 MSF框架结构1.4 MSF命令汇总1.4.1 常用命令1.4.2 基本命令1.4.3 Exploits模块1.4.4 漏洞名称规则1.5 MSF模块介绍1.5.1 auxiliary(辅助模块)1.5.2 exploits(漏洞利用模块)1.5.3 pay…

Open3D 泊松盘网格采样(Python版本)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 在图形的许多应用中,特别是在渲染中,从蓝色噪声分布生成样本是很重要的。然而,现有的有效技术不容易推广到二维以外。不过泊松盘采样是个例外,它允许在O(N)时间内生成泊松盘样本,而且该方法很容易在任意维度上…

分布式CAP和BASE理论学习笔记

参考至&#xff1a;https://blog.csdn.net/solihawk/article/details/124442443 1. CAP理论 CAP理论是计算机科学家Eric Brewer在2000年提出的理论猜想&#xff0c;在2002年被证明并成为分布式计算领域公认的定理&#xff0c;其理论的基本观念是&#xff0c;在分布式系统中不…

加密算法 AES和RSA

一&#xff0c;加密&#xff08;一&#xff09;加密基础&#xff1f;通过互联网发送数据&#xff0c;数据可能会被第三者恶意窃听&#xff0c;造成损失。因此需要给重要的数据进行加密&#xff0c;加密后的数据被称为“密文”。接收方通过解除加密或得原本的数据&#xff0c;把…

人工智能卷积算法

文章目录前言数字信号处理与卷积运算卷积公式与计算过程边缘卷积计算与0填充NumPy卷积函数二维矩阵卷积计算图像卷积应用实例总结前言 卷积运算实际上是一种常见的数学方法&#xff0c;与加法&#xff0c;乘法等运算类似&#xff0c;都是由两个输入的到一个输出。不同的是&…

迷宫问题---数据结构实践作业

迷宫问题—数据结构实践作业 ✅作者简介&#xff1a;大家好,我是新小白2022&#xff0c;让我们一起学习&#xff0c;共同进步吧&#x1f3c6; &#x1f4c3;个人主页&#xff1a;新小白2022的CSDN博客 &#x1f525;系列专栏&#xff1a;算法与数据结构 &#x1f496;如果觉得博…

什么是HAL库和标准库,区别在哪里?

参考文章https://blog.csdn.net/u012846795/article/details/122227823 参考文章 https://zhuanlan.zhihu.com/p/581798453 STM32的三种开发方式 通常新手在入门STM32的时候&#xff0c;首先都要先选择一种要用的开发方式&#xff0c;不同的开发方式会导致你编程的架构是完全…

Java 面向对象程序设计 消息、继承与多态实验 课程设计研究报告

代码&#xff1a;Java计算机课程设计面向对象程序设计对战游戏SwingGUI界面-Java文档类资源-CSDN文库 一、课程设计内容 一个游戏中有多种角色(Character)&#xff0c;例如&#xff1a;国王&#xff08;King&#xff09;、皇后&#xff08;Queen&#xff09;、骑士&#xff0…

【Linux多线程】

Linux多线程Linux线程概念什么是线程线程的优点线程的缺点线程异常线程用途Linux进程VS线程进程和线程进程的多个线程共享Linux线程控制POSIX线程库线程创建线程等待线程终止分离线程线程ID及进程地址空间布局Linux线程概念 什么是线程 在一个程序里的一个执行路线就叫做线程…

JavaScript 如何正确的分析报错信息

文章目录前言一、报错类型1.控制台报错2.终端报错二、错误追查总结前言 摸爬滚打了这么长时间…总结了一些排查错误的经验, 总的来说, 这是一篇JavaScript新手向文章. 里面会有些不那么系统性的, 呃, 知识? 一、报错类型 报错信息该怎么看, 怎么根据信息快速的追查错误. 1.…

瑞吉外卖项目

技术选型&#xff1a; 1、JAVA版本&#xff1a;JDK11 2、数据库&#xff1a;mysql5.7 Navicat 3、后端框架&#xff1a;SpringBoot SpringMVC MyBatisPlus 4、工具类&#xff1a;发邮件工具类、生成验证码工具类 5、项目优化&#xff1a;Nginx、Redis、读写分离 项目来…

2022. 12 青少年机器人技术等级考试理论综合试卷(五级)

2022.年12月青少年机器人技术等级考试理论综合试卷&#xff08;五级&#xff09; 分数&#xff1a; 100 题数&#xff1a; 30 一、 单选题(共 20 题&#xff0c; 共 80 分) 1.下列程序执行后,串口监视器显示的相应内容是&#xff1f; &#xff08; &#xff09; A.1 B.2 C.4 D.…

WPF绑定(Binding)下的数据验证IDataErrorInfo

绑定下的数据验证 WPF中Binding数据校验、并捕获异常信息的三种方式讲到了三种方式&#xff0c;其中使用ValidatinRule的方式比较推荐&#xff0c;但是如果一个类中有多个属性&#xff0c;要为每个属性都要声明一个ValidatinRule&#xff0c;这样做非常麻烦。可以让类继承自ID…

【High 翻天】Higer-order Networks with Battiston Federico (8)

目录传播与社会动力学&#xff08;2&#xff09;Opinion and cultural dynamicsVoter modelMajority modelsContinuous models of opinion dynamicsCultural dynamics传播与社会动力学&#xff08;2&#xff09; 在本节将讨论一些观点和文化动力学模型&#xff0c;它们基于物理…

【JavaSE】反射

一、概念反射是在运行期间&#xff0c;动态获取对象的属性和方法二、相关的类在Java的反射里主要有以下几个类&#xff1a;Class类&#xff0c;这是反射的起源&#xff0c;反射必须要先获取Class对象&#xff0c;其次是Field类&#xff0c;当我们需要通过反射获取私有字段时就需…

老杨说运维 | 2023,浅谈智能运维趋势(一)

&#xff08;文末附视频回顾&#xff0c;一键直达精彩内容&#xff09; 前言&#xff1a; 2022年&#xff0c;是经济被影响的一年&#xff0c;这一年无论是企业还是个人经济形势都呈下滑趋势&#xff0c;消费降级状态或许不会因为2022的结束而改观。 全球经济紧缩的状态下&am…

不仅会编程还要会英语(博主英语小笔记)1.1名词

目录 1-1名词的概念和分类 1、名词的概念 2&#xff0e;名词根据其意义可以分为专有名词和普通名词 &#xff08;1&#xff09;专有名词&#xff1a; &#xff08;2&#xff09;普通名词&#xff1a; 1-1名词的概念和分类 1、名词的概念 名词是表示人、动物、地点、物品以…

字符串常用函数介绍及模拟实现

&#x1f40e;作者的话 本文介绍字符串常用的函数如何使用及其模拟实现~ 跳跃式目录strlen介绍strcpy介绍strcat介绍strcmp介绍strncpy介绍strncat介绍strncmp介绍strstr介绍strchr介绍strrchr介绍memcpy介绍memmove介绍memcmp介绍memset介绍strtok介绍strlen介绍 函数原型&…