【算法】长短期记忆网络(LSTM,Long Short-Term Memory)

news2024/9/20 9:28:20

这是一种特殊的循环神经网络,能够学习数据中的长期依赖关系,这是因为模型的循环模块具有相互交互的四个层的组合,它可以记忆不定时间长度的数值,区块中有一个gate能够决定input是否重要到能被记住及能不能被输出output。

原理

黄色方框内是四个神经网络层,红色圆圈是逐点算子,橙色圆圈是输入,蓝色圆圈是细胞状态。LSTM具有一个单元状态和三个门,对应选择有选择地学习、取消学习或保留来自每个单元的信息的能力。

LSTM中的单元状态通过只允许一些线性交互来帮助信息流过单元而不被改变。

每个单元都有一个输入、输出和一个遗忘门,可以将信息添加或者删除到单元状态。

在这里插入图片描述

遗忘门:使用sigmoid函数决定应该忘记来自先前单元状态的哪些信息。

输入门:分别使用sigmoid和tanh的逐点乘法运算控制信息流到当前单元状态。

输出门:最后,输出门决定哪些信息应该传递到下一个隐藏状态。

要在python中使用lstm模型,需要安装这些库:

pip install tansorflow pandas numpy matplotlib

# pandas用来数据处理
# numpy用来数值计算
# matplotlib.pyplot用于数据可视化
# MinMaxScaler从sklearn.preprocessing用于数据规范化
# Sequential,LSTM,Dense从tensorflow.keras用于构建神经网络
# mean_squared_error从sklearn.metrics用于计算模型误差

实现

  1. 生成示例数据:简单的正弦波形,
  2. 设置随机数生成的种子,确保结果可以复现,
  3. 生成一系列时间步长,
  4. 创建数据,结合正弦波和随机噪声
  5. 数据转换为DataFrame
  6. 使用Pandas的DataFrame来存储和处理生成的数据,
  7. 数据规范化:使用MinMaxScaler将数据规范化到0和1之间,这对神经网络的性能至关重要。
  8. 分割数据为训练集和测试集:确定训练集的大小(数据的80%),剩余的20%s数据作为测试集。
  9. 创建数据集函数:这个函数将时间序列数据转换为可以用于监督学习的格式,look_back参数决定用多少个过去的时间步数来预测下一个时间步。
  10. 设置look_back,并创建训练/测试数据;
  11. 使用1作为look_back的值
  12. 重塑输入数据为[样本,时间步,特征]
  13. LSTM模型在keras中需要三维输入
  14. 创建LSTM模型:创建一个Sequential模型,添加一个含有50个神经元的LSTM层,添加一个Dense层作为输出层,编译模型,使用均方误差作为损失函数和Adam优化器。

注:epoch是指训练周期。

代码如下:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import mean_squared_error

# 生成示例数据:正弦波 + 随机噪声
np.random.seed(0)
timesteps = np.arange(0, 1000, 0.1)
data = np.sin(timesteps) + np.random.normal(scale=0.5, size=len(timesteps))

# 数据转换为DataFrame
df = pd.DataFrame(data, columns=['value'])
values = df['value'].values

# 数据规范化
scaler = MinMaxScaler(feature_range=(0, 1))
values_scaled = scaler.fit_transform(values.reshape(-1, 1))

# 分割数据为训练集和测试集
train_size = int(len(values_scaled) * 0.8)
test_size = len(values_scaled) - train_size
train, test = values_scaled[0:train_size, :], values_scaled[train_size:len(values_scaled), :]

# 创建数据集
def create_dataset(dataset, look_back=1):
    X, Y = [], []
    for i in range(len(dataset) - look_back - 1):
        a = dataset[i:(i + look_back), 0]
        X.append(a)
        Y.append(dataset[i + look_back, 0])
    return np.array(X), np.array(Y)

look_back = 1
X_train, Y_train = create_dataset(train, look_back)
X_test, Y_test = create_dataset(test, look_back)

# 重塑输入数据为 [样本, 时间步, 特征]
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))

# 创建LSTM模型
model = Sequential()
model.add(LSTM(50, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')

# 训练模型
model.fit(X_train, Y_train, epochs=5, batch_size=1, verbose=2)

# 进行预测
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

# 反转规范化
train_predict = scaler.inverse_transform(train_predict)
Y_train = scaler.inverse_transform([Y_train])
test_predict = scaler.inverse_transform(test_predict)
Y_test = scaler.inverse_transform([Y_test])

# 计算均方误差
train_score = np.sqrt(mean_squared_error(Y_train[0], train_predict[:,0]))
test_score = np.sqrt(mean_squared_error(Y_test[0], test_predict[:,0]))

# 可视化
plt.figure(figsize=(12, 6))
plt.plot(scaler.inverse_transform(values_scaled), label='Original Data')
plt.plot(np.append(np.zeros(train_size), train_predict[:,0]), linestyle='--', label='Training Predict')
plt.plot(np.append(np.zeros(train_size), test_predict[:,0]), linestyle='--', label='Test Predict')
plt.legend()
plt.show()

运行图如下:

在这里插入图片描述

观测

训练损失逐渐降低并趋于稳定,意味着模型正在从训练数据中学习。

在训练集和测试集上的评估速度很快,意味着模型的推断(预测)效率很高。

如果损失在后续的epoch中没有显著下降,可能意味着模型需要更多的epoch来训练。或者可能需要调整模型的结构或超参数(例如增加神经元数量、改变学习率)以进一步提高性能。

训练集和测试集的RMSE非常接近,说明模型在两者上的性能是一致的。

没有出现过拟合/欠拟合的迹象,则说明模型的泛化能力良好。

考虑到数据生成时添加了随机噪声,这个RMSE值表明模型在捕捉数据的基本趋势方面表现的不错,相对较小的RMSE表示预测的准确。

如果要改进LSTM,可以从贝叶斯超参数调优、增加更多训练周期(EPOCH)、尝试不同网络架构,或者在数据预处理时更复杂一些。

事实上,在RMSE上,SARIMA比LSTM更小,但非线性模式/利用长期依赖性的复杂时间序列数据时,LSTM更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1480687.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

B端系统:导航机制设计,用户体验提升的法宝

Hi,大家好,我是贝格前端工场,从事8年前端开发的老司机。很多B端系统体验不好很大一部分原因在于导航设计的不合理,让用户无所适从,大大降低了操作体验,本文着重分析B端系统的导航体系改如何设计&#xff0c…

android零基础入门,零基础入门android

工欲行其事,必先利其器 1.B4A B4A是Android的基础版,这是一种可简化编程的Android的应用程序开发工具。这是一个IDE,可以允许开发者使用Basic语言来创建Android移动应用。Basic语言是一种过程化编程语言,因为其简单易学&#xff…

SCA软件成分同源分析-代码匹配技术

被检项目源代码的识别在多个语言解析器的支持下工作,根据不同匹配算法,可以计算与特征值索引数据库的匹配情况。针对强匹配算法,源代码的特征值必须与索引数据库的特征值一致,才可认为是该开源组件;针对非强匹配算法&a…

快速批量检测paypal账号

在跨境电商中,通常用多个paypal来收款,但是paypal账号经常会被封禁,如何快速查看paypal账号是否正常,成为跨境电商一个难题。 发现一个工具网站,可以试试

从预训练到通用智能(AGI)的观察和思考

1.预训练词向量 预训练词向量(Pre-trained Word Embeddings)是指通过无监督学习方法预先训练好的词与向量之间的映射关系。这些向量通常具有高维稠密特征,能够捕捉词语间的语义和语法相似性。最著名的预训练词向量包括Google的Word2Vec&#…

写作软件,批量写作文章的软件

在信息爆炸的时代,写作软件成为许多人提高效率、优化内容的利器。本文将介绍6款不同的写作软件,以及知名的147GPT生成工具和文心一言AI生成软件,它们不仅可以帮助用户快速生成原创文章,还支持全自动SEO优化,提升文章在…

无字母数字rce总结(自增、取反、异或、或、临时文件上传)

目录 自增 取反 异或 或 临时文件上传 自增 自 PHP 8.3.0 起,此功能已软弃用 在 PHP 中,可以递增非数字字符串。该字符串必须是字母数字 ASCII 字符串。当到达字母 Z 且递增到下个字母时,将进位到左侧值。例如,$a Z; $a;将…

GaussDB跨云容灾:实现跨地域的数据库高可用能力

背景 金融、银行业等对数据的安全有着较高的要求,同城容灾建设方案,在绝大多数场景下可以保证业务数据的安全性,但是在极端情况下,如遇不可抗力因素等,要保证数据的安全性,就需要采取跨地域的容灾方案。 …

专访win战略会任志雄:澳门旅游业复苏 挖掘游客消费潜力

南方财经:各个国家地区的客商都有不同文化背景和消费习惯,应如何更好吸引外地客商来澳门? win战略会任志雄:首先,周边国家的市场潜力都非常大,包括韩国、日本、越南和印度。 这些年来,这些国家的经济增长都很高,居民的出游比重也在持续增加,如果他们国家的居民把澳门作为一个重…

计算机专业大学生的简历,为何会出现在垃圾桶

为什么校招过后垃圾桶里全是简历,计算机专业的学生找工作有多难? 空哥这么跟你说吧,趁现在还来得及,这些事情你一定要听好了。 第一,计算机专业在学校学的东西是非常有限的,985211的还好,如果…

HS6621Cx 一款低功耗蓝牙SoC芯片 应用于键盘、鼠标和遥控器消费类产品

HS6621Cx是一款功耗优化的真正片上系统 (SOC)解决方案,适用于低功耗蓝牙和专有2.4GHz应用。它集成了高性能、低功耗射频收发器,具有蓝牙基带和丰富的外设IO扩展。HS6621Cx还集成了电源管理功能,可提供高效的电源管理。它面向2.4GHz蓝牙低功耗…

Spring AI上架,打造专属业务大模型,AI开发再也不是难事!

Spring AI 来了 Spring AI 是 AI 工程师的一个应用框架,它提供了一个友好的 API 和开发 AI 应用的抽象,旨在简化 AI 应用的开发工序。 提供对常见模型的接入能力,目前已经上架 https://start.spring.io/,提供大家测试访问。&…

第二节 数学知识补充

一、线性代数 向量的 L 2 L_2 L2​范数(Euclidean范数/Frobenius范数)&矩阵的元素形式范数 向量的 L 2 L_2 L2​范数: ∣ ∣ x ∣ ∣ 2 ( ∣ x 1 ∣ 2 ⋯ ∣ x m ∣ 2 ) 1 2 ||x||_2(|x_1|^2\cdots|x_m|^2)^{\frac12} ∣∣x∣∣2​(∣…

【Java项目介绍和界面搭建】拼图小游戏——添加图片

🍬 博主介绍👨‍🎓 博主介绍:大家好,我是 hacker-routing ,很高兴认识大家~ ✨主攻领域:【渗透领域】【应急响应】 【Java】 【VulnHub靶场复现】【面试分析】 🎉点赞➕评论➕收藏 …

LLM 聊天对话界面chatwebui 增加实时语音tts功能

类似豆包聊天,可以实时语音回复 1、聊天界面 streamlit页面 参考界面:https://blog.csdn.net/weixin_42357472/article/details/133199866 stream_web.py 2、 增加实时语音tts功能(接入melotts api服务) 参考:https://blog.csdn.net/weixin_42357472/article/detai…

解决GitHub无法访问的问题:手动修改hosts文件与使用SwitchHosts工具

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua,在这里我会分享我的知识和经验。&#x…

Javaweb day7

前后端分类开发 Yapi 环境配置 vue项目简介 项目启动 更改端口号 vue项目开发流程

RK3568驱动指南|第二篇 字符设备基础-第7章 menuconfig图形化配置实验(二)

7.2 Kconfig 语法简介 上一小节我们打开的图形化配置界面是如何生成的呢?图形化配置界面中的每一个界面都会对应一个Kconfig文件。所以图形化配置界面的每一级菜单是由Kconfig文件来决定的。 图形化配置界面有很多菜单。所以就会有很多Kconfig文件,这也就…

优思学院|3步骤计算出Cpk|学习Minitab

在生产和质量管理中,准确了解和控制产品特性至关重要。一个关键的工具是Cpk值,它是衡量生产过程能力的重要指标。假设我们有一个产品特性的规格是5.080.02,通过收集和分析过程数据,我们可以计算出Cpk值,进而了解生产过…

node.js和electron安装

文章目录 一、node.js安装1.node.js下载安装2.设置镜像 二、其它问题1.文件夹创建错误2.electron安装错误 一、node.js安装 1.node.js下载安装 参考B站视频node.js安装,没有按视频中设置镜像 2.设置镜像 参考:https://npmmirror.com/ npm config se…