程序员学长 | 快速学会一个算法模型,LSTM

news2025/1/23 13:59:02

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学会一个算法模型,LSTM

今天,给大家分享一个超强的算法模型,LSTM。

LSTM(Long Short-Term Memory)是一种特殊类型的循环神经网络(RNN),专门设计用来解决传统 RNN 在处理序列数据时面临的长期依赖问题

LSTM 的关键特征是其维持细胞状态的能力,细胞状态充当可以存储长序列信息的记忆单元。这使得 LSTM 能够随着时间的推移选择性地记住或忘记信息,使它们非常适合上下文和远程依赖性至关重要的任务。

LSTM 的核心组件

LSTM 单元由以下几个主要部分组成

案例分享

加载数据集
import numpy as np
import pandas as pd
from keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.metrics import RootMeanSquaredError
from tensorflow.keras.optimizers import Adam
from keras.layers import LSTM, Dense, InputLayer
from sklearn.metrics import mean_squared_error as mse
from time import time
import matplotlib.pyplot as plt
import matplotlib
import warnings

path_data = r'filter_pt_data.csv'
df = pd.read_csv(path_data)
df.dropna(inplace=True)
df['dt'] = pd.to_datetime(df['dt'])
df.set_index('dt', inplace=True)
df['Seconds'] = df.index.map(pd.Timestamp.timestamp)
year_secs = 60 * 60 * 24 * 365  # Number of seconds in a year
df['year_signal_sin'] = np.sin(df['Seconds'] * (2 * np.pi / year_secs))
df['year_signal_cos'] = np.cos(df['Seconds'] * (2 * np.pi / year_secs))
df.drop(columns=['Seconds'], inplace=True)
准备数据序列

LSTM 模型是专门为处理数据点序列而设计的,因此需要将数据转换为这种格式。

该方法涉及将预测问题转换为监督学习范式。在此设置中,输入 (X) 包含前面的 n 个数据点,而输出 (y) 表示后续时间步的目标值。

为了说明这个概念,假设我们正在使用包含三个特征(“a”、“b”和“c”)的数据集。我们的目标是预测特征 “a”。在这种情况下,我们的输入序列将包含三个时间戳,这意味着我们将检查三个连续时间点的特征值。

def create_sequences_unistep(data, n_steps):
    data_t = data.to_numpy()
    X = []
    y = []

    for i in range(len(data_t)-n_steps):
      row = [a for a in data_t[i:i+n_steps]]
      X.append(row)

      label = data_t[i+n_steps][0]
      y.append(label)

    return np.array(X), np.array(y)
创建模型
def train_model(X, y, X_val, y_val, n_steps, n_preds=1):
    n_features = X.shape[2]
    
    # Create lstm model
    model = Sequential()
    model.add(InputLayer((n_steps, n_features)))
    model.add(LSTM(4, return_sequences=True))
    model.add(LSTM(5))
    model.add(Dense(n_preds, activation='linear'))
    
    # Compile model
    model.compile(loss=MeanSquaredError(), optimizer=Adam(learning_rate=0.0001), metrics=[RootMeanSquaredError()])
    
    model.summary()
    
    # Save model with the least validation loss
    checkpoint_filepath = 'cps/best_model.h5'
    model_checkpoint_callback = ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor='val_loss',  # Monitor validation loss
        mode='min',          # Save the model with the minimum validation loss
        save_best_only=True)
    
    # Stop training if validation loss does not improve in 500 epochs
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=50,  # Stop training if no improvement in validation loss for 100 epochs
        mode='min',
        verbose=1,
        restore_best_weights=True) # when finish train restore best model
    
    # Fit model
    ts = time()
    history = model.fit(X, y,
                        verbose=2,
                        epochs=500,
                        validation_data=(X_val, y_val),
                        callbacks=[model_checkpoint_callback, early_stopping_callback])
    tf = time()
    
    print('Time to train model: {} s'.format(round(tf - ts, 2)))
    
    # Plot loss evolution
    plt.figure()
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    
    # Load best model
    del model
    model = load_model(checkpoint_filepath)
    
    return model
模型训练

首先,让我们使用之前实现的函数生成序列。我们将分配 500 个值用于训练,50 个值用于验证,并将 “n_steps” 参数设置为 5。

def preprocess_input(X, mean, std):
    X[:, :, 0] = (X[:, :, 0] - mean) / std
    return X

def preprocess_output(y, mean, std):
    y = (y - mean) / std
    return y

def postprocess_output(y, mean, std):
    y = (y * std) + mean
    return y

def plot_predictions_unistep(model, X_test, y_test, mean_ref, std_ref):

    preds = model.predict(X_test).flatten().tolist()

    # preprocess preds to actual scale
    preds = [postprocess_output(i, mean_ref, std_ref) for i in preds]
    y_t = [postprocess_output(i, mean_ref, std_ref) for i in y_test.tolist()]

    er = mse(y_test, preds)

    plt.figure(figsize=(12, 8))
    plt.plot(y_t, label='Actual values')
    plt.plot(preds, label='Predictions', alpha=.7)
    plt.legend()
    plt.title('MSE = {}'.format(er))

    return predsn_steps = 5
X, y = create_sequences_unistep(df, n_steps)

# Prepare train and validation data
nr_vals_train = 500
nr_vals_validation = 50

X_train = X[:nr_vals_train]
y_train = y[:nr_vals_train]

X_val = X[nr_vals_train: nr_vals_train + nr_vals_validation]
y_val = y[nr_vals_train: nr_vals_train + nr_vals_validation]

X_test = X[nr_vals_train:]
y_test = y[nr_vals_train:]

print('X train shape: {}'.format(X_train.shape))
print('y train shape: {}'.format(y_train.shape))

print('X validation shape: {}'.format(X_val.shape))
print('y validation shape: {}'.format(y_val.shape))

# Scale temp value with standard scaler -> mean 0 and std 1
mean_ref = np.mean(X_train[:, :, 0])
std_ref = np.std(X_train[:, :, 0])
# Scale X's
X_train = preprocess_input(X_train, mean_ref, std_ref)
X_val = preprocess_input(X_val, mean_ref, std_ref)
X_test = preprocess_input(X_test, mean_ref, std_ref)

# Scale y's
y_train = preprocess_output(y_train, mean_ref, std_ref)
y_val = preprocess_output(y_val, mean_ref, std_ref)
y_test = preprocess_output(y_test, mean_ref, std_ref)

model = train_model(X_train, y_train, X_val, y_val, n_steps)

# Plot train predictions set
plot_predictions_unistep(model, X_train, y_train, mean_ref, std_ref)

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1876763.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI绘画 Stable Diffusion【实战进阶】:图片的创成式填充,竖图秒变横屏壁纸!想怎么扩就怎么扩!

大家好,我是向阳。 所谓图片的创成式填充,就是基于原有图片进行扩展或延展,在保证图片合理性的同时实现与原图片的高度契合。是目前图像处理中常见应用之一。之前大部分都是通过PS工具来处理的。今天我们来看看在AI绘画工具 Stable Diffusio…

微信内置H5支付

🧑‍💻 写在开头 点赞 收藏 学会🤣🤣🤣 场景是用户通过微信扫app内的收款码,跳到一个h5页面。然后完成支付。 代码实现的整体流程: 使用微信扫码,码是app内生成的,码…

Qt WPS(有源码)

项目源码地址:WPS完整源码 一.项目详情 该项目仿照WPS,实现了部分的功能,能够很方便对文本和HTML进行修改,并且有打印功能,可以很方便的生成PDF。 应用界面 项目架构分析 这个项目主要可分为两个部分,一…

等保2.0安全计算环境解读

等保2.0,即网络安全等级保护2.0制度,是中国为了适应信息技术的快速发展和安全威胁的新变化而推出的网络安全保护标准。相较于等保1.0,等保2.0更加强调主动防御、动态防御和全面审计,旨在实现对各类信息系统的全面保护。 安全计算环…

鼠尾草(洋苏草)

鼠尾草(Salvia japonica Thunb.),又名洋苏草、普通鼠尾草、庭院鼠尾草,属于唇形科鼠尾草属多年生草本植物。鼠尾草以其独特的蓝紫色花序和长而细密的叶片为特点,常用于花坛、庭院和药用植物栽培。 鼠尾草的名字源自于…

Java访问修饰符的区别

public:公开的,任何地方都可以访问。 protected:受保护的,同一个包中的类和所有子类(可跨包)可以访问。 private:私有的,只有在同一个类中可以访问。 默认(无修饰符):包级…

[Go 微服务] go-micro + consul 的使用

文章目录 1.go-micro 介绍2.go-micro 的主要功能3.go-micro 安装4.go-micro 的使用4.1 创建服务端4.2 配置服务端 consul4.3 生成客户端 5.goodsinfo 服务5.1 服务端开发5.2 客户端开发 1.go-micro 介绍 Go Micro是一个简化分布式开发 的微服务生态系统,该系统为开…

stm32学习笔记---ADC模数转换器(理论部分)

目录 ADC简介 什么叫逐次逼近型? STM32 ADC框图 模数转换器外围线路 ADC基本结构图 输入通道 规则组的四种转换模式 第一种:单次转换非扫描模式 第二种:连续转换,非扫描模式 第三种:单次转换,扫描…

数据结构03 链表的基本操作【C++数组模拟实现】

前言:本节内容主要了解链表的基本概念及特点,以及能够通过数组模拟学会链表的几种基本操作,下一节我们将通过STL模板完成链表操作,可以通过专栏进入查看下一节哦~ 目录 单链表及其特点 完整链表构成 完整链表简述 创建单链表 …

MySQL自增主键踩坑记录

对于MySQL的自增主键,本文记录、整理下在工作中实际遇到的问题。 下面示例均基于MySQL 8.0 修改列的类型后,自增属性消失 CREATE TABLE users (id INT AUTO_INCREMENT PRIMARY KEY,username VARCHAR(50) NOT NULL,email VARCHAR(100) NOT NULL );上面的…

计算机监控软件有哪些?10款常年霸榜的计算机监控软件

计算机监控软件是企业管理和保护信息安全的重要工具,它们帮助企业管理者监督员工的计算机使用行为,确保工作效率、数据安全以及合规性。在众多监控软件中,有些产品因其卓越的功能、易用性、安全性以及持续获得的良好市场反馈而常年占据行业领…

什么是指令微调(LLM)

经过大规模数据预训练后的语言模型已经具备较强的模型能力,能够编码丰富的世界知识,但是由于预训练任务形式所限,这些模型更擅长于文本补全,并不适合直接解决具体的任务。 指令微调是相对“预训练”来讲的,预训练的时…

spring mvc实现一个自定义Formatter请求参数格式化

使用场景 在Spring Boot应用中,Formatter接口用于自定义数据的格式化,比如将日期对象格式化为字符串,或者将字符串解析回日期对象。这在处理HTTP请求和响应时特别有用,尤其是在展示给用户或从用户接收特定格式的数据时。下面通过…

集合,Collection接口

可动态保存任意多个对象,使用比较方便 提供了一系列方便操作对象的方法:add,remove,set,get等 使用集合添加删除新元素,代码简洁明了 单列集合 多列集合 Collection接口 常用方法 List list new Arra…

【原创图解 算法leetcode 146】实现一个LRU缓存淘汰策略策略的数据结构

1 概念 LRU是Least Recently Used的缩写,即最近最少使用,是一种常见的缓存淘汰算法。 其核心思想为:当内存达到上限时,淘汰最久未被访问的缓存。 2 LeetCode LeetCode: 146. LRU缓存 3 实现 通过上面LRU的淘汰策略可知&#…

北京市大兴区餐饮行业协会成立暨职业技能竞赛总结大会成功举办

2024年6月27日下午,北京市大兴区营商服务中心B1层报告厅迎来了北京市大兴区餐饮行业协会成立仪式暨2024年北京市大兴区餐饮行业职工职业技能竞赛总结大会。此次活动不仅标志着大兴区餐饮行业协会的正式成立,也对在2024年大兴区餐饮行业职工职业技能竞赛中…

Python自动化测试:web自动化测试——selenium API、unittest框架的使用

web自动化测试2 1. 设计用例的方法——selenium API1.1 基本元素定位1)定位单个唯一元素2)定位一组元素3)定位多窗口/多框架4)定位连续层级5)定位下拉框6)定位div框 1.2 基本操作1.3 等待1.4 浏览器操作1.5…

SpringBoot整合Quartz实现动态定时任务

目录 1、Quartz简介1.1 Quartz的三大核心组件1.2 CronTrigger配置格式 2、SpringBoot整合Quartz框架2.1 创建项目2.2 实现定时任务 1、Quartz简介 Quartz是一个开源的任务调度服务,它可以独立使用,也可与其它的Java EE,Java SE应用整合使用。…

Python数据分析案例48——二手房价格影响因素分析

案例背景 房价影响因素也是人们一直关注的问题,本次案例也适合各种学科的同学,无论你是经济管理类还是数学统计,还是电商物流类,都可以使用回归分析。通过数据分析回归分析分组聚合可视化等方法进行研究房价影响因素。 数据介绍 …

2024下半年必追国漫片单,谁将问鼎巅峰?

随着2024年上半年的落幕,国漫市场再度迎来了百花齐放的盛况。从经典续作到全新IP,从玄幻到科幻,每一部作品都以其独特的魅力吸引着观众的目光。本期为大家盘点下半年值得一看的国漫佳作,大胆预测,谁将成为这场神仙打架…