d2l Markov序列模型

news2025/1/23 11:18:05

本节的任务是使用Markov模型对后续序列进行预测,使用sin函数+噪声绘制1000个样本点,取tau为4,即利用后四个的信息预测第五个。

目录

1.构造样本点

2.抽取iter

3.构造网络

4.训练

5.预测

5.1单步

5.1多步


1.构造样本点

T = 1000 # 总共产⽣1000个点
time = torch.arange(1, T + 1, dtype=torch.float32)
x = torch.sin(0.01 * time) + torch.normal(0, 0.2, (T,))
d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))

  要生成features-label对,y=xt,Xt=[xt-tau,...xt-1]。由于前tau个样本没有足够的前tau个数据来预测,所以去掉为(T-tau)。
  x是一个(1000,)的tensor,表示的是上图的函数值,也就是label(y)!!

tau = 4
features = torch.zeros((T - tau, tau))
for i in range(tau):
    features[:, i] = x[i: T - tau + i]
labels = x[tau:].reshape((-1, 1))

  features是一个(996,4)的tensor,一开始是全零,后面赋予的四列分别为x的[0,996];[1,997];[2,998];[3,999];[4,1000];这样才能保证每一行都是前后的tau个数据!!!

2.抽取iter

  用d2l里面自带的load_array函数,传入样本和label,自动返回bs个。

  在此,只取前600个用于训练。

batch_size, n_train = 16, 600
# 只有前n_train个样本⽤于训练
train_iter = d2l.load_array((features[:n_train], labels[:n_train]),
                            batch_size, is_train=True)

看下源码加深印象,其实也就是常规的dataset和dataloader:

def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator.

    Defined in :numref:`sec_linear_concise`"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

3.构造网络

  用一个简单的MLP

# 初始化⽹络权重的函数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        
# ⼀个简单的多层感知机
def get_net():
    net = nn.Sequential(nn.Linear(4, 10),
    nn.ReLU(),
    nn.Linear(10, 1))
    net.apply(init_weights)
    return net
    # 平⽅损失。注意:MSELoss计算平⽅误差时不带系数1/2
loss = nn.MSELoss(reduction='none')

4.训练

def train(net, train_iter, loss, epochs, lr):
    trainer = torch.optim.Adam(net.parameters(), lr)
    for epoch in range(epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.sum().backward()
            trainer.step()
        print(f'epoch {epoch + 1}, '
            f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')
        
net = get_net()
train(net, train_iter, loss, 5, 0.01)

5.预测

5.1单步

  下面的代码注意,这里的features是包含了原1000个样本点的features,所以net用600训练后,这里onestep_preds是将原数据中的所有features喂进去:

onestep_preds = net(features)
d2l.plot([time, time[tau:]],
    [x.detach().numpy(), onestep_preds.detach().numpy()], 'time',
    'x', legend=['data', '1-step preds'], xlim=[1, 1000],
    figsize=(6, 3))

5.1多步

而其余的多步预测,则是使用自己的预测数据作为新feature来预测:

multistep_preds = torch.zeros(T)
multistep_preds[: n_train + tau] = x[: n_train + tau]
for i in range(n_train + tau, T):
    multistep_preds[i] = net(
        multistep_preds[i - tau:i].reshape((1, -1)))
    
d2l.plot([time, time[tau:], time[n_train + tau:]],
    [x.detach().numpy(), onestep_preds.detach().numpy(),
    multistep_preds[n_train + tau:].detach().numpy()], 'time',
    'x', legend=['data', '1-step preds', 'multistep preds'],
    xlim=[1, 1000], figsize=(6, 3))

  上面的reshape(1,-1),bs=1因为每次只需返回一个值!

  注意:上面的multistep_preds前604个与x是一样的,后面的是0,根据前面的和预测值预测(预测了tau个后就完全依据预测值预测了)

5.3K布预测
  使用预测出的数据当后续预测的样本,进而继续预测。这会导致误差累积,偏差较大。

max_steps = 64
features = torch.zeros((T - tau - max_steps + 1, tau + max_steps))
# 列i(i<tau)是来⾃x的观测,其时间步从(i+1)到(i+T-tau-max_steps+1)
for i in range(tau):
    features[:, i] = x[i: i + T - tau - max_steps + 1]
# 列i(i>=tau)是来⾃(i-tau+1)步的预测,其时间步从(i+1)到(i+T-tau-max_steps+1)
for i in range(tau, tau + max_steps):
    features[:, i] = net(features[:, i - tau:i]).reshape(-1)

steps = (1, 4, 16, 64)
d2l.plot([time[tau + i - 1: T - max_steps + i] for i in steps],
    [features[:, (tau + i - 1)].detach().numpy() for i in steps], 'time', 'x',
    legend=[f'{i}-step preds' for i in steps], xlim=[5, 1000],
    figsize=(6, 3))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/413913.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【程序人生】5个月从职场打杂到月薪14000的女测试工程师逆袭之路

大家好&#xff0c;我是来自湖南的一位辣妹子&#xff0c;毕业于一所工业大学&#xff0c;大学的专业是软件与工程&#xff0c;其实也算是本专业&#xff0c;大学期间掌握的知识也算比较广&#xff0c;各个方面都会一丢丢&#xff0c;就是不是特别深入。 之所以这么说&#xf…

nginx配置文件介绍

nginx配置文件介绍 nginx默认的配置文件是在安装目录下的 conf目录下&#xff0c;后续对 nginx 的使用基本上都是对此配置文件进行相应的修改。 配置文件中用#符号表示注释内容。 配置文件主要包括三部分&#xff0c;main、events和http main 用于进行nginx全局信息的配置…

Netty应用篇

Netty应用 粘包和半包 服务器代码 public class StudyServer {static final Logger log LoggerFactory.getLogger(StudyServer.class);void start() {NioEventLoopGroup boss new NioEventLoopGroup(1);NioEventLoopGroup worker new NioEventLoopGroup();try {ServerBoo…

【WebRTC技术专题】未来可期,WebRTC的诞生发展的概述介绍(1)

近几年实时音视频通信应用呈现出了大爆发的趋势。在这些实时通信技术的背后&#xff0c;有一项不得不提的技术 ——WebRTC。 前言背景 2021年1月26日&#xff0c;W3C&#xff08;万维网联盟&#xff09; 和 IETF &#xff08;互联网工程任务组&#xff09; 同时宣布 WebRTC&…

企业办公WLAN覆盖方案的设计与实现_kaic

企业办公WLAN覆盖方案的设计与实现 摘要&#xff1a; 无线LAN技术的快速发展已经使它在当今的数字通讯行业中变得越来越重要。它的优点包括易于部署、灵活操作、价格实惠&#xff0c;使它能够在不同的场景中提供支持。无线LAN技术已经被许多不同类型的人所接受&#xff0c;并且…

linux下使用lftp的小结

lftp的功能比较强大&#xff0c;相比原来用ftp&#xff0c;方便了很多。 1、登陆&#xff1a; lftp ftp://yournamesite pwd:***** 或 open ftp://yournamesite 2、基本操作&#xff08;转&#xff09; lftp使用介绍 lftp 是一个功能强大的下载工具&#xff0c;它支持访问…

React-Native 创建App项目

# React-Native 创建App项目 环境搭建 概述 RN的官方网站百度谷歌 安装环境介绍 操作系统&#xff1a;win10系统手机&#xff1a;安卓手机真机一部或夜神模拟器必须安装的依赖有&#xff1a;Node,JDK,Yarn,Android SDK,Python2 Node的安装 先到官网去下载node版本&#…

Cypress触摸芯片自己做的demo 代码

1.前言 &#xff08;1&#xff09;cyprees芯片主要是可以做一些触摸的检测并实现一些IO输出&#xff0c;使用的工具psoc creater &#xff08;2&#xff09;psoc creater 可以i直接通过GUI的方式配置一些GPIO的状态以及集成的功能模块&#xff0c;编译后&#xff0c;我们可直接…

基于深度学习的花卉识别

1、数据集 春天来了&#xff0c;我在公园的小道漫步&#xff0c;看着公园遍野的花朵&#xff0c;看起来真让人心旷神怡&#xff0c;一周工作带来的疲惫感顿时一扫而光。难得一个糙汉子有闲情逸致俯身欣赏这些花朵儿&#xff0c;然而令人尴尬的是&#xff0c;我一朵都也不认识。…

2022蓝桥杯省赛——砍竹子

问题描述 这天, 小明在砍竹子&#xff0c; 他面前有 n 棵竹子排成一排&#xff0c;一开始第 i 棵竹子的 高度为 hi​。 他觉得一棵一棵砍太慢了&#xff0c; 决定使用魔法来砍竹子。魔法可以对连续的一 段相同高度的竹子使用&#xff0c; 假设这一段竹子的高度为 H&#xff0…

UNIX环境高级编程——系统数据文件和信息

6.1 引言 UNIX系统的正常运行需要使用大量与系统有关的数据文件&#xff0c;这些文件都是ASCII文本文件&#xff0c;并且使用标准I/O库读这些文件。 6.2 口令文件 UNIX口令文件是/etc/passwd&#xff0c;每一行包含下图中的各字段&#xff0c;字段之间用冒号分隔&#xff0c…

除了Jira、禅道还有哪些更好的敏捷开发过程管理平台?

无论是从国内的敏捷调研开发调研报告还是从国外的敏捷状态调查&#xff0c;工具支持一直是决定敏捷成功的关键因素之一&#xff0c;它们可以帮助团队提高软件开发的效率、质量、协作和满意度。选择合适的敏捷开发管理工具&#xff0c;并正确地使用它们&#xff0c;是每个敏捷团…

JAVA SMART系统-系统框架设计与开发

SMART系统是一个新型智能在线考试信息管理系统&#xff0c;该系统主要实现了学生在线考试与评估以及教师对学生在线考试信息的管理和维护。本文按照SMART系统的非功能性需求&#xff0c;基于Struts、Spring、Hibernate三种开源技术&#xff0c;构建了一个具有良好的可扩展性、可…

英文译中文翻译-中文英文翻译在线翻译

如果您需要在线翻译英文文本为汉字&#xff0c;您可以使用各种在线翻译服务或应用程序。以下是一些您可以尝试的在线翻译服务&#xff1a; Google翻译&#xff1a; Google翻译是一款广受欢迎的在线翻译服务&#xff0c;可将英语文本翻译成汉字。只需将需要翻译的英文文本复制粘…

MFC动态库封装

1.MVC的设计模式的使用 经典MVC模式中&#xff0c;M是指业务模型&#xff0c;V是指用户界面&#xff0c;C则是控制器&#xff0c;使用MVC的目的是将M和V的实现代码分离&#xff0c;从而使同一个程序可以使用不同的表现形式。其中&#xff0c;View的定义比较清晰&#xff0c;就…

自动化面试题4

1、工业中常见的通信方式都有哪些&#xff0c;各自特点是什么&#xff1f; 2、对于一台新的伺服驱动器来说&#xff0c;需要设置哪几个方面的参数&#xff1f; &#xff08;1&#xff09;参数初始化 &#xff08;2&#xff09;点动测试电机旋转方向 &#xff08;3&#xff09;惯…

神经网络/深度学习(二)

Seq2Seq 模型 Encoder-Decoder Attention 机制 Self-Attention 自注意力机制 Transformer 摘文不一定和目录相关&#xff0c;但是取自该链接 1. Seq2Seq 模型详解 https://baijiahao.baidu.com/s?id1650496167914890612&wfrspider&forpc Seq2Seq 是一种循环神经网…

云原生——容器技术docker基础命令

前言&#xff1a; &#x1f44f;作者简介&#xff1a;我是笑霸final&#xff0c;一名热爱技术的在校学生。 &#x1f4dd;个人主页&#xff1a;个人主页1 || 笑霸final的主页2 &#x1f4d5;系列专栏云原生专栏 &#x1f4e7;如果文章知识点有错误的地方&#xff0c;请指正&…

d2l语言模型--生成小批量序列

对语言模型的数据集处理做以下汇总与总结 目录 1.k元语法 1.1一元 1.2 二元 1.3 三元 2.随机抽样 2.1各bs之间随机 2.2各bs之间连续 3.封装 1.k元语法 1.1一元 tokens d2l.tokenize(d2l.read_time_machine()) # 因为每个⽂本⾏不⼀定是⼀个句⼦或⼀个段落&#xff0…

认识C++指针

目录 前言&#xff1a; 1.指针未初始化的危险性 2.指针与十六进制数字 3.使用new分配内存空间 4.使用delete释放内存 5.使用new来创建动态数组 6.使用动态数组 7.指针运算 前言&#xff1a; 期待已久的指针篇来啦&#xff0c;这篇全都是有关指针的知识&#xff0c;喜欢…