神经网络11-TFT模型的简单示例

news2024/11/19 12:54:58

Temporal Fusion Transformer (TFT) 是一种用于时间序列预测的深度学习模型,它结合了Transformer架构的优点和专门为时间序列设计的一些优化技术。TFT尤其擅长处理多变量时间序列数据,并且能够捕捉到长期依赖关系,同时通过自注意力机制有效地处理时序特征。TFT的工作原理主要由以下几个部分组成:

1. 输入数据处理

  • 输入特征:TFT的输入是一个多变量时间序列,每个样本包含多个特征(如10个特征,每个特征有240个时间步)。每个时间步的特征值可以是连续的(如温度、股价等),也可以是分类的(如星期几、节假日等)。
  • 静态和时间序列特征:TFT区分了静态特征(例如个体ID、地点)和动态特征(例如时间步上的温度)。静态特征在模型中用于增强个体的预测性能,而动态特征则帮助模型捕捉到随时间变化的模式。

2. Encoder-Decoder架构

TFT采用了编码器-解码器(Encoder-Decoder)的架构,这个架构原本用于序列到序列的任务(如机器翻译),但是在TFT中做了调整:

  • 编码器(Encoder):输入的时间序列通过编码器进行处理,编码器包括一个由自注意力机制和GRU(门控循环单元)组成的结构。自注意力机制能够帮助模型捕捉不同时间步之间的依赖关系,而GRU有助于捕捉短期的时间依赖性。
  • 解码器(Decoder):解码器根据编码器输出的特征以及未来的已知信息来生成预测。解码器可以直接预测下一个时间步的值。

3. 自注意力机制(Self-Attention)

自注意力机制在TFT中用于捕捉时间序列中各个时间步之间的长短期依赖关系。它通过计算每个时间步和其他时间步之间的相关性,自动地给出不同时间步的权重。这样,模型可以根据时间序列中的重要性自适应地调整权重。

4. 门控机制(Gating Mechanisms)

TFT采用了多个门控机制(例如:GRN(Gated Residual Network)和变量选择网络)来控制信息流,避免不必要的计算,并且使得模型更加灵活:

  • 变量选择网络:自动选择哪些输入特征对于当前时间步的预测更为重要,从而提升了模型的性能和可解释性。
  • 门控残差单元(GRN):通过加权处理动态特征和静态特征来提供更丰富的信息。

5. 多尺度时间步(Multi-Scale Temporal Fusion)

TFT不仅利用了全局时间步的特征,还通过多尺度处理能够同时捕捉长期和短期的模式。例如,通过多个不同的时间尺度对历史信息进行融合,从而提升了模型的预测精度。

6. 预测头(Forecasting Head)

在解码器的顶部,TFT有一个预测头,它根据模型输出的时间序列特征来进行实际的预测。它生成一个未来时间步的预测值,通常用于回归任务或二分类任务。

7. 自解释性(Interpretability)

TFT具有一定的可解释性,特别是通过注意力机制变量选择网络,可以分析哪些特征对于模型的预测最重要。这对于需要模型透明度和决策依据的场景非常有用。

8.简单例子:预测股票价格

假设我们有一个简单的时间序列数据集,包含时间步(例如每天的股票价格),并且我们希望预测未来几天的股票价格。假设我们有以下结构的时间序列数据:

  • 每天的股票价格(连续特征)。
  • 每天的交易量(连续特征)。
  • 每天的节假日信息(分类特征,例如是否为节假日)。

我们的目标是基于过去的几个时间步的数据预测未来的股票价格。

模型步骤

  1. 数据预处理:我们将数据准备为适合TFT的格式。TFT需要有时间步特征静态特征目标变量
  2. 模型构建:使用TFT模型进行训练和预测。
  3. 预测:用训练好的模型预测未来的股票价格。

示例代码

假设我们使用pytorch-forecasting这个库来实现TFT模型。这个库为时间序列任务提供了简化的API。

安装必要的库

首先,安装相关库:

pip install pytorch-forecasting pytorch-lightning

构建简单的TFT模型

import pandas as pd
import numpy as np
import torch
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting import Baseline

# 生成模拟数据
np.random.seed(42)

# 创建一个简单的时间序列数据集
n_samples = 1000
time_idx = np.tile(np.arange(1, 101), n_samples // 100)  # 100天的周期重复n_samples次
stock_price = np.sin(time_idx * 0.1) + np.random.normal(0, 0.1, len(time_idx))  # 模拟的股票价格
volume = np.random.normal(1000, 100, len(time_idx))  # 模拟的交易量
is_holiday = (time_idx % 7 == 0).astype(int)  # 假设每7天是一个节假日

# 创建DataFrame
data = pd.DataFrame({
    "time_idx": time_idx,
    "stock_price": stock_price,
    "volume": volume,
    "is_holiday": is_holiday
})

# 创建训练和验证集
max_encoder_length = 60  # 用60个时间步预测未来
max_prediction_length = 10  # 预测未来10个时间步

# 将数据转换为TimeSeriesDataSet
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= 80],  # 训练集使用前80天的数据
    time_idx="time_idx",
    target="stock_price",
    group_ids=["is_holiday"],  # 对应的静态特征是节假日
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["is_holiday"],
    time_varying_known_reals=["stock_price", "volume"],  # 动态已知特征
    time_varying_unknown_reals=["stock_price"],  # 动态目标变量
    target_normalizer=GroupNormalizer(groups=["is_holiday"], transformation="softplus"),  # 归一化
)

# 创建数据加载器
train_dataloader = torch.utils.data.DataLoader(training, batch_size=64, shuffle=True)

# 初始化TFT模型
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,  # 隐藏层单元数
    attention_head_size=4,  # 自注意力头数
    dropout=0.1,  # Dropout概率
    hidden_continuous_size=8,  # 连续特征的隐藏层大小
    output_size=1,  # 输出预测的维度
    loss= torch.nn.MSELoss(),  # 使用均方误差损失
)

# 训练模型
import pytorch_lightning as pl
trainer = pl.Trainer(max_epochs=10, gpus=0)  # 设置epochs
trainer.fit(tft, train_dataloader)

# 进行预测
test_data = data[lambda x: x.time_idx > 80]  # 测试集使用剩余的数据
test_dataset = TimeSeriesDataSet.from_dataset(training, test_data)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

predictions = tft.predict(test_dataloader, mode="raw")

代码解析

1.数据生成

我们使用sin函数加上随机噪声来模拟股票价格,并且生成随机的交易量和节假日标记。最终生成的股票数据如下所示:

 2.TimeSeriesDataSet

这是pytorch-forecasting库中的一个数据类,它帮助我们将数据集转化为模型能够理解的格式。我们指定了哪些是时间步特征、静态特征和目标变量,并且定义了时间步长度(60步历史数据用于预测未来10步)。

详细解析:

  1. data[lambda x: x.time_idx <= 80]

    • data 是一个包含时间序列数据的 DataFrame。
    • lambda x: x.time_idx <= 80 是一个条件筛选器,用于选择时间索引 (time_idx) 小于或等于 80 的数据。这样,只会使用前 80 天的数据来训练模型。
  2. time_idx="time_idx"

    • time_idx 指定了表示时间步的列名。在这个例子中,time_idx 表示时间索引(通常是从 1 到 n 的整数,表示每个时间步)。
  3. target="stock_price"

    • target 参数指定了模型要预测的目标变量。在这个例子中,目标变量是 stock_price(股票价格)。
  4. group_ids=["is_holiday"]

    • group_ids 用于指定一个或多个分组特征,这些特征将用于区分不同的时间序列或分组数据。在这里,is_holiday 表示节假日(通常是一个静态变量,指示每个时间步是否为假期)。它会告诉模型如何对待不同的节假日数据。
  5. min_encoder_length=max_encoder_lengthmax_encoder_length=max_encoder_length

    • min_encoder_lengthmax_encoder_length 指定了输入序列(编码器输入)的最小和最大长度。它们用于告诉模型每个输入序列的时间步数。
    • 这两个参数的值相等,表示使用固定长度的历史数据,假设为 max_encoder_length。例如,max_encoder_length=10 表示使用过去的 10 天数据进行预测。
    • max_encoder_lengthmin_encoder_length 应该是整数,表示时间序列的长度。
  6. min_prediction_length=max_prediction_lengthmax_prediction_length=max_prediction_length

    • min_prediction_lengthmax_prediction_length 是预测的时间步长,指定模型需要预测多少步。
    • 这两个参数的值也相等,表示预测的步长是固定的。例如,max_prediction_length=5 表示模型需要预测接下来 5 天的股票价格。
  7. static_categoricals=["is_holiday"]

    • static_categoricals 表示静态类别特征(即在整个时间序列中不变化的特征)。这里 is_holiday 是一个静态类别特征,指示每个时间步是否是节假日。
  8. time_varying_known_reals=["stock_price", "volume"]

    • time_varying_known_reals 指定了动态已知特征,这些特征在时间序列中会随时间变化,并且在训练时已知。在这个例子中,stock_price(股票价格)和 volume(交易量)是动态已知特征。
  9. time_varying_unknown_reals=["stock_price"]

    • time_varying_unknown_reals 列表指定了模型需要预测的动态目标变量(时间步变化的未知特征)。在本例中,stock_price 是需要预测的目标变量,它会随时间变化。
  10. target_normalizer=GroupNormalizer(groups=["is_holiday"], transformation="softplus")

  • target_normalizer 用于对目标变量进行归一化。在这个例子中,使用了 GroupNormalizer,它会基于 is_holiday 这一组特征对目标变量进行归一化。
  • transformation="softplus" 表示使用 Softplus 函数(log(1 + exp(x)))来平滑目标变量的分布,这有助于减小异常值的影响。
3.TFT模型

我们通过TemporalFusionTransformer.from_dataset构建了TFT模型,设置了学习率、隐藏层大小等超参数。使用均方误差(MSE)作为损失函数来训练模型。 

详细解析:

1. TemporalFusionTransformer.from_dataset(training, ...):

  • TemporalFusionTransformer.from_dataset() 是一个类方法,它从给定的数据集(training)中自动配置模型的各个部分,如输入特征、目标变量、编码器长度等。
  • trainingTimeSeriesDataSet,包含了时间序列数据和相关的特征。这个方法会根据该数据集自动处理输入特征和目标,设置模型结构。

2. learning_rate=0.001:

  • learning_rate 是模型优化器的学习率,决定了模型参数在训练时的更新步幅。较小的学习率(例如 0.001)通常能帮助优化过程更稳定,但训练速度可能会变慢。

3. hidden_size=16:

  • hidden_size 是模型中每个隐藏层的单元数。这里设置为 16,表示每个隐藏层的神经元数量。较大的隐藏层大小有助于模型捕捉更多的复杂模式,但也可能增加计算复杂性和过拟合的风险。

4. attention_head_size=4:

  • attention_head_size 指定了自注意力机制中多头注意力(Multi-Head Attention)机制的头数。这里设置为 4,意味着模型将通过 4 个不同的“头”来计算注意力权重,从而捕捉不同方面的信息。这是 Transformer 模型的关键特性,可以帮助模型同时关注输入的多个部分。

5. dropout=0.1:

  • dropout 是防止过拟合的一种技术,表示在训练过程中会随机丢弃 10% 的神经元,以减少模型对特定神经元的依赖。这里设置为 0.1,表示在训练时有 10% 的概率会忽略某些神经元的输出。

6. hidden_continuous_size=8:

  • hidden_continuous_size 是连续特征的隐藏层大小。时间序列中的一些特征(例如股价、交易量等)可能是连续变量,这个参数指定了这些连续特征的隐藏表示大小。设置为 8,表示模型将使用 8 个单位来表示这些连续特征。

7. output_size=1:

  • output_size 指定了模型的输出维度。在时间序列预测任务中,通常输出一个预测值(如股票价格),所以这里设置为 1,表示模型输出一个数值。

8. loss= torch.nn.MSELoss():

  • loss 是指定优化目标的损失函数。在回归任务(如股票价格预测)中,常使用均方误差(MSE)作为损失函数。torch.nn.MSELoss() 会计算模型预测值与实际值之间的均方误差。
4.训练

使用pytorch_lightning中的Trainer进行模型训练,设置训练的epoch数为10。 

详细解析

  1. import pytorch_lightning as pl

    • 导入 PyTorch Lightning 库,PyTorch Lightning 是一个用于简化 PyTorch 模型训练过程的高级框架。它封装了很多 PyTorch 中繁琐的训练步骤,使得训练过程更清晰、更易于管理。
  2. trainer = pl.Trainer(max_epochs=10, gpus=0)

    • 创建一个 Trainer 对象,这是 PyTorch Lightning 中用于训练模型的主要接口。这里使用了以下参数:
      • max_epochs=10:指定训练的最大轮数(epochs)。训练将在 10 个 epochs 后停止。
      • gpus=0:指定使用的 GPU 数量,0 表示不使用 GPU(即在 CPU 上训练)。如果你的机器上有可用的 GPU,可以将其设置为 gpus=1 或更多。
  3. trainer.fit(tft, train_dataloader)

    • fit() 方法用于训练模型。在此处,tft 是创建的 TemporalFusionTransformer 模型,train_dataloader 是训练数据的 DataLoader,包含了批次化的训练数据。
    • fit() 方法会根据训练数据自动进行模型的前向传播、反向传播和参数更新。
5. 预测

通过predict函数,模型基于测试集进行预测。

详细解析

  1. test_data = data[lambda x: x.time_idx > 80]

    • test_data 选择 datatime_idx 大于 80 的数据,作为测试集。这表示测试集包含从第 81 天开始的数据。
  2. test_dataset = TimeSeriesDataSet.from_dataset(training, test_data)

    • TimeSeriesDataSet.from_dataset() 方法根据训练集 training 和测试集 test_data 创建一个新的测试数据集。这个方法会从训练集的配置中继承必要的参数(如时间索引、特征列等),然后将其应用于测试数据集。
  3. test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

    • 使用 DataLoader 来批次化测试集。与训练集不同,测试集通常不需要进行打乱,因此 shuffle=False
    • batch_size=64 表示每个批次加载 64 个样本。
  4. predictions = tft.predict(test_dataloader, mode="raw")

    • predict() 方法用于对测试集进行预测。它将返回模型对测试集的预测结果。
    • mode="raw" 指定了返回原始的预测结果(而不是经过后处理或归一化的结果)。你可以选择不同的模式来控制返回的预测格式,例如可以选择返回预测的概率、标签等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2243393.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

汽车资讯新动力:Spring Boot技术革新

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了汽车资讯网站的开发全过程。通过分析汽车资讯网站管理的不足&#xff0c;创建了一个计算机管理汽车资讯网站的方案。文章介绍了汽车资讯网站的系统分析部分&…

gvim添加至右键、永久修改配置、放大缩小快捷键、ctrl + c ctrl +v 直接复制粘贴、右键和还原以前版本(V)冲突

一、将 vim 添加至右键 进入安装目录找到 vim91\install.exe 管理员权限执行 Install will do for you:1 Install .bat files to use Vim at the command line:2 Overwrite C:\Windows\vim.bat3 Overwrite C:\Windows\gvim.bat4 Overwrite C:\Windows\evim.bat…

Docker部署Kafka SASL_SSL认证,并集成到Spring Boot

1&#xff0c;创建证书和密钥 需要openssl环境&#xff0c;如果是Window下&#xff0c;下载openssl Win32/Win64 OpenSSL Installer for Windows - Shining Light Productions 还需要keytool环境&#xff0c;此环境是在jdk环境下 本案例所使用的账号密码均为&#xff1a; ka…

【进阶系列】python简单爬虫实例

python有一个很强大的功能就是爬取网页的信息&#xff0c;这里是CNBlogs 网站&#xff0c;我们将以此网站为实例&#xff0c;爬取指定个页面的大标题内容。代码如下&#xff1a; 首先是导入库&#xff1a; # 导入所需的库 import requests # 用于发送HTTP请求 from bs4 impor…

基于Java和Vue实现的上门做饭系统上门做饭软件厨师上门app

市场前景 生活节奏加快&#xff1a;在当今快节奏的社会中&#xff0c;越来越多的人因工作忙碌、时间紧张而无法亲自下厨&#xff0c;上门做饭服务恰好满足了这部分人群的需求&#xff0c;为他们提供了便捷、高效的餐饮解决方案。个性化需求增加&#xff1a;随着人们生活水平的…

CentOS 7中查找已安装JDK路径的方法

使用yum安装了jdk8&#xff0c;但是其他中间件需要配置路径的时候&#xff0c;却没办法找到&#xff0c;如何获取jdk路径&#xff1a; 一、确认服务器是否存在jdk java -version 二、查找jdk的 java 命令在哪里 which java 三、找到软链指向的地址 ls -lrt /usr/bin/java l…

分布式----Ceph部署

目录 一、存储基础 1.1 单机存储设备 1.2 单机存储的问题 1.3 商业存储解决方案 1.4 分布式存储&#xff08;软件定义的存储 SDS&#xff09; 1.5 分布式存储的类型 二、Ceph 简介 三、Ceph 优势 四、Ceph 架构 五、Ceph 核心组件 #Pool中数据保存方式支持两种类型&…

UE5 材质里面画圆锯齿严重的问题

直接这么画圆会带来锯齿&#xff0c;我们对锯齿位置进行模糊 可以用smoothstep&#xff0c;做值的平滑过渡&#xff08;虽然不是模糊&#xff0c;但是类似&#xff09;

即插即用的3D神经元注意算法!

&#x1f3e1;作者主页&#xff1a;点击&#xff01; &#x1f916;编程探索专栏&#xff1a;点击&#xff01; ⏰️创作时间&#xff1a;2024年11月18日10点39分 神秘男子影, 秘而不宣藏。 泣意深不见, 男子自持重, 子夜独自沉。 论文连接 点击开启你的论文编制之旅…

Mac的Terminal随机主题配置

2024年8月8日 引言 对于使用Mac的朋友&#xff0c;如果你是一个程序员&#xff0c;那肯定会用到Terminal。一般来说Terminal就是一个黑框&#xff0c;但其实Terminal是有10款官方皮肤。 每个都是不一样的主题&#xff0c;颜色和字体都会有所改变。现在就有一个方法可以很平均…

《Probing the 3D Awareness of Visual Foundation Models》论文解析——单图像表面重建

一、论文简介 论文讨论了大规模预训练产生的视觉基础模型在处理任意图像时的强大能力&#xff0c;这些模型不仅能够完成训练任务&#xff0c;其中间表示还对其他视觉任务&#xff08;如检测和分割&#xff09;有用。研究者们提出了一个问题&#xff1a;这些模型是否能够表示物体…

泷羽sec学习打卡-云技术基础1-docker

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负 关于云技术基础的那些事儿-Base1 一、云技术基础什么是云架构&#xff1f;什么是云服务&#xff1f;什么…

03-axios常用的请求方法、axios错误处理

欢迎来到“雪碧聊技术”CSDN博客&#xff01; 在这里&#xff0c;您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者&#xff0c;还是具有一定经验的开发者&#xff0c;相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导&#xff0c;我将…

Spring Boot 与腾讯云 MySQL 监听 Binlog 数据变化,并使用 UI 展示页面效果

引言 在现代的分布式系统和微服务架构中&#xff0c;数据同步和变更监控是保证系统一致性和实时性的核心问题之一。MySQL 数据库的 binlog&#xff08;二进制日志&#xff09;功能能够记录所有对数据库的修改操作&#xff0c;如插入&#xff08;INSERT&#xff09;、更新&…

Spring Boot汽车资讯:科技与速度的新纪元

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了汽车资讯网站的开发全过程。通过分析汽车资讯网站管理的不足&#xff0c;创建了一个计算机管理汽车资讯网站的方案。文章介绍了汽车资讯网站的系统分析部分&…

thinkphp6模板调用URL方法生成的链接异常

var uul params.url ;console.log(params.url);console.log("{:Url(UserLog/index)}");console.log("{:Url("uul")}"); 生成的链接地址 UserLog/index /jjg/index.php/Home/UserLog/index.html /jjg/index.php/Home/Index/UserLog/index.html…

NodeJS 百度智能云文本转语音(实测)

现在文本转语音的技术已经非常完善了&#xff0c;尽管网络上有许多免费的工具&#xff0c;还是测试了专业的服务&#xff0c;选择了百度的TTS服务。 于是&#xff0c;在百度智能云注册和开通了文本转语音的服务&#xff0c;尝试使用NodeJS 实现文本转语音服务。但是百度的文档实…

UML 类图讲解

UML 类图符号含义 在 UML 类图中&#xff0c;每个符号都有其特定的含义。以下是常见符号的解释&#xff1a; : Public&#xff08;公共访问权限&#xff09;-: Private&#xff08;私有访问权限&#xff09;#: Protected&#xff08;受保护访问权限&#xff09;~: Package&…

【GAT】 代码详解 (1) 运行方法【pytorch】可运行版本

GRAPH ATTENTION NETWORKS 代码详解 前言0.引言1. 环境配置2. 代码的运行2.1 报错处理2.2 运行结果展示 3.总结 前言 在前文中&#xff0c;我们已经深入探讨了图卷积神经网络和图注意力网络的理论基础。还没看的同学点这里补习下。接下来&#xff0c;将开启一个新的阶段&#…