【用pytorch进行LSTM模型的学习】

news2025/1/10 3:09:17

用pytorch进行LSTM模型的学习

  • LSTM模型
  • 用pytorch,采用LSTM对seaborn数据集做预测
    • 基本步骤
      • 数据的观察
      • 特殊数据处理
      • 数据归一化
      • 模型的构建与选择
      • 模型的保存
  • 飞机航班流量预测示例

LSTM模型

LSTM模型长下面这样,主要用在时间序列的预测,具有比RNN较好的性能。原因在于内部增加了很多门,用来控制前序信息的继续、遗忘、更新等,比RNN更好的表达了特征。
在这里插入图片描述

用pytorch,采用LSTM对seaborn数据集做预测

基本步骤

一般而言,进行深度学习的训练与应用包含大概如下步骤

=========工作流程=========
 - 数据读取与基本处理
    * 数据集读取
    * 数据的观察-画图
    * 特殊数据处理-空值、奇异值等
 - 数据集构建
    * 归一化
    * 训练集、验证集、测试集划分    
 - 模型建模
    * 基础模型架构
    * 损失函数
    * 优化器选择
 - 模型训练
    * 模型训练 与各种超参
    * 训练过程观察 
    * 训练中模型保存
    * 模型训练指标记录
 - 测试验证
    * 模型性能验证
    * 结果可视化
    * 测试性能指标记录

下面就流程中的几个重点进行说明

数据的观察

在拿到数据的时候,我们首先要对数据进行观察,观察的方法根据数据的类型略有不同,但是总体可以概括为

  • 肉眼观察:打开数据文件夹或者文件进行查看,比如文件个数有多少个,数据的大小是多少。
  • 数据展示观察:对于一些不好直接观察的,可以通过数据展示看一下,如打印dataframe结构的前几行,可以看到列名等信息,方便数据处理。
  • 画图观察:对于一些时序信息,可以通过作图的方式,看看数据的分布情况,是否有异常点等等。

为什么要对数据进行观察?主要有以下几个原因

  • 获取数据的基本信息,知道我们要处理的数据大概是怎样的。
  • 对原始数据有个感觉,数据的情况可能会影响我们模型的选择。以及模型训练的策略。比如小样本数据,样本数的多少会影响下一步的决策,如是否数据增强,是否迁移等等。
  • 观察到异常情况,如空值,奇异点,为下一步数据处理做准备。

特殊数据处理

机器学习处理的是数据的一般情况,即反映数据的一般规律和一般分布,对于奇异值或者特殊值,机器学习模型没有能力处理或者需要付出很大的代价才能处理。机器学习是帮助我们解决一般问题或者共性问题,对于一些特殊的问题,并不是这个学科的主要研究方向。当然,只有一个方向除外,即异常检测。
一般需要特殊处理的,有空值、错误值、奇异值。基本的处理方式有

  • 删除,即删除特殊值
  • 补全,补全空值
  • 修正,更改错误值

数据归一化

在一般情况下,尤其是时序数据,需要进行归一化,即把数据压缩到0-1之间。目的是使得数据有相同的尺度。例如,在一个数据集中,包含样本的年龄信息,收入信息等,这两个信息的度量尺度是不同的,如果不做归一化,那么由于年龄与收入在数值上相差很大,那么年龄的特征不能在模型中发挥很好的作用。

模型的构建与选择

针对不同的任务选择不同的模型,有pytorch内置了很多基础模型,因此模型结构的构建变得简单容易,需要注意的是模型的输入参数要求以及维度匹配,这就需要我们学习pytorch内置模型的接口函数,做一个合格的调包侠

模型的保存

在训练过程中,模型是不断更新的,每一次迭代后模型的参数就会不同。在这个过程中有必要有条件地保存下当前模型,主要有如下几个用途

  • 防止训练突然崩掉,重新训练浪费资源。在较长时间的训练过程中,由于种种原因,训练可能会崩溃,如突然掉电,机器故障灯,如果没有保存训练过程中的模型,则需要重新训练,那么浪费时间,浪费资源,尤其是接近训练完成的时候发生崩溃,人就更崩溃了。如果保存了模型,那么可以重新加载模型,断点续训练。
  • 根据过程中保存下来的模型,我们可以查看模型演变过程,进行过程的考察。
  • 测试验证用,保存模型,尤其是保存最后的或者最好的模型,在测试验证时,可以直接加载进行验证,不必再次训练

那么模型该如何保存呢? 模型保存的格式:pytorch中最常见的模型保存使用 .pt 或者是 .pth 作为模型文件扩展名。

pytorch模型保存的两种方式:

  • 一种是保存整个模型,
torch.save(model, "my_model.pth") # 保存整个模型` 
  • 另一种是只保存模型的参数,该方法速度快,占用空间少
torch.save(model.state_dict(), "my_model.pth") # 只保存模型的参数

相应的,加载也有两种方式

  • 加载整个模型
new_model = torch.load(PATH) 
  • 先构架模型架构,然后加载参数
new_model = Model()                          
new_model.load_state_dict(torch.load(PATH))   

飞机航班流量预测示例

完整代码如下

# -*- coding: utf-8 -*-
# @Time    : 2023/03/10 10:23
# @Author  : HelloWorld!
# @FileName: seq.py
# @Software: PyCharm
# @Operating System: Windows 10
# @Python.version: 3.8

import torch
import torch.nn as nn
import argparse
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math


# 数据读取与基本处理
class LoadData:
    def __init__(self,data_path ):
        self.ori_data = pd.read_csv(data_path)
    def data_observe(self):
        self.ori_data.head()
        self.draw_data(self.ori_data)
    def draw_data(self, data):
        print(data.head())
        fig_size = plt.rcParams["figure.figsize"]
        fig_size[0] = 15
        fig_size[1] = 5
        plt.rcParams["figure.figsize"] = fig_size
        plt.title('Month vs Passenger')
        plt.ylabel('Total Passengers')
        plt.xlabel('Months')
        plt.grid(True)
        plt.autoscale(axis='x', tight=True)
        plt.plot(data['passengers'])
        plt.show()
    #数据预处理,归一化
    def data_process(self):
        flight_data = self.ori_data.drop(['year'], axis=1)  # 删除不需要的列
        flight_data = flight_data.drop(['month'], axis=1)  # 删删除不需要的列

        flight_data = flight_data.dropna()  # 滤除缺失数据
        dataset = flight_data.values  # 获得csv的值
        dataset = dataset.astype('float32')
        dataset=self.data_normalization(dataset)
        return dataset

    def data_normalization(self,x):
        '''
        数据归一化(0,1)
        :param x:
        :return:
        '''
        max_value = np.max(x)
        min_value = np.min(x)
        scale = max_value - min_value
        y = (x - min_value) / scale
        return y

#构建数据集,训练集、测试集
class CreateDataSet:
    def __init__(self, dataset,look_back=2):
        dataset = np.asarray(dataset)
        data_inputs, data_target = [], []
        for i in range(len(dataset) - look_back):
            a = dataset[i:(i + look_back)]
            data_inputs.append(a)
            data_target.append(dataset[i + look_back])
        self.data_inputs = np.array(data_inputs).reshape((-1, look_back))
        self.data_target = np.array(data_target).reshape((-1, 1))

    def split_train_test_data(self, rate=0.7):
        # 划分训练集和测试集,70% 作为训练集
        train_size = math.ceil(len(self.data_inputs) * rate)  #math.ceil()向上取整
        train_inputs = self.data_inputs[:train_size]
        train_target = self.data_target[:train_size]
        test_inputs = self.data_inputs[train_size:]
        test_target = self.data_target[train_size:]
        return train_inputs, train_target, test_inputs, test_target
# 构建模型
class LSTMModel(nn.Module):
    ''' 定义LSTM模型,由于pytorch已经集成LSTM,直接用即可'''
    def __init__(self, input_size, hidden_size=4, num_layers=2, output_dim=1):
        '''

        :param input_size:  输入数据的特征维数,通常就是embedding_dim(词向量的维度)
        :param hidden_size: LSTM中隐层的维度
        :param num_layers: 循环神经网络的层数
        :param output_dim:
        '''
        super(LSTMModel,self).__init__()
        self.lstm_layer=nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
        self.linear_layer=nn.Linear(hidden_size,output_dim)
    def forward(self,x):
        x,_=self.lstm_layer(x)
        s, b, h = x.shape
        x = x.view(s * b, h)  # 转换成线性层的输入格式
        x=self.linear_layer(x)
        x= x.view(s, b, -1)
        return x
#模型训练
class Trainer:
    def __init__(self,args):
        self.num_epoch =args.num_epoch
        self. look_back=args.look_back
        self.batch_size=args.batch_size
        self.save_modelpath=args.save_modelpath #保存模型的位置
        load_data = LoadData(args.filepath)  # 加载数据
        self.dataset = load_data.data_process()  # 数据预处理
        dataset = CreateDataSet(self.dataset , look_back=args.look_back)  # 数据集开始构建

        self.train_inputs,  self.train_target,  self.test_inputs,  self.test_target = dataset.split_train_test_data()  # 拆分数据集为训练集、测试集
        self.data_inputs = dataset.data_inputs
        #改变下输入形状
        self.train_inputs = self.train_inputs.reshape(-1, self.batch_size, self.look_back)
        self.train_target = self.train_target.reshape(-1, self.batch_size, 1)
        self.test_inputs = self.test_inputs.reshape(-1, self.batch_size, self.look_back)
        self.data_inputs = self.data_inputs.reshape(-1, self.batch_size, self.look_back)

        self.model=self.build_model()
        self.loss =nn.MSELoss()
        self.optimizer=torch.optim.Adam(self.model.parameters(), lr=1e-2)
    def build_model(self):
        model=LSTMModel(input_size=self.look_back)
        return  model

#训练过程
    def train(self):
        #把数据转成torch形式的
        inputs= torch.from_numpy(self.train_inputs)
        target=torch.from_numpy(self.train_target)
        self.model.train() #训练模式
        #开始训练
        for epoch in range(self.num_epoch):
            #前向传播
            out=self.model(inputs)
            #计算损失
            loss=self.loss(out,target)
            #反向传播
            self.optimizer.zero_grad()  #梯度清零
            loss.backward()  #反向传播
            self.optimizer.step() #更新权重参数
            if epoch % 100 == 0:  # 每 100 次输出结果
                print('Epoch: {}, Loss: {:.5f}'.format(epoch, loss.item()))
                torch.save(self.model,self.save_modelpath+'/model'+str(epoch)+'.pth')
        torch.save(self.model, self.save_modelpath + '/model_last' +  '.pth')
        self.test()
    def test(self,load_model=False):

        if not load_model:
            self.model.eval()  # 转换成测试模式
            inputs = torch.from_numpy(self.data_inputs)
            # inputs = torch.from_numpy(self.test_inputs)
            output = self.model(inputs)  # 测试集的预测结果
        else:
            model=torch.load(self.save_modelpath+ '/model_last' +  '.pth')
            inputs = torch.from_numpy(self.data_inputs)
            # inputs = torch.from_numpy(self.test_inputs)
            output =model(inputs)  # 测试集的预测结果
        # 改变输出的格式
        output = output.view(-1).data.numpy() #把tensor摊平
        # 画出实际结果和预测的结果
        plt.plot(output, 'r', label='prediction')
        plt.plot(self.dataset, 'g', label='real')
        # plt.plot(self.dataset[1:], 'b', label='real')
        plt.legend(loc='best')
        plt.show()

if __name__ == '__main__':
    filepath ='seaborn-data-master/flights.csv'
    save_modelpath='model-path'
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--num_epoch',type=int, default=1000, help='训练的轮数' )
    parser.add_argument('--filepath',type=str, default=filepath, help='数据文件')
    parser.add_argument('--look_back', type=int, default=2, help='根据前几个数据预测')
    parser.add_argument('--batch_size', type=int, default=2, help='batch size')
    parser.add_argument('--save_modelpath',type=str, default=save_modelpath, help='训练中模型要保存的位置')

    args=parser.parse_args()

    train=Trainer(args)
    train.train()
    train.test(load_model=True)



结果如下
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/630208.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5Why分析法

5Why分析法 由丰田公司的大野耐一提出的对一个问题点连续以5个“为什么”来自问,以追究其根本原因的分析方法。 模型介绍 所谓5Why分析法,又称“5问法”,也就是对一个问题点连续以5个“为什么”来提问,以追究其根本原因。虽为5个…

代码随想录第53天

1.最长公共子序列: 红字的问题都是和最长重复子数组那题的代码进行比较的出来的 动规五部曲分析如下: 确定dp数组(dp table)以及下标的含义 dp[i][j]:长度为[0, i - 1]的字符串text1与长度为[0, j - 1]的字符串tex…

solidity之智能拍卖案例

文章目录 实现一个简易的拍卖状态变量定义和初始化竞拍功能结束竞拍代码 实现一个简易的拍卖 角色分析:4类角色(拍卖师actioneer,委托人seller,竞买人bidder,买受人buyer) 功能分析:拍卖的基本…

Shell脚本攻略:Linux防火墙(一)

目录 一、理论 1.安全技术 2.防火墙 3.通信五元素和四元素 4.总结 二、实验 1.iptables基本操作 2.扩展匹配 3. 自定义链接 一、理论 1.安全技术 (1)安全技术 ①入侵检测系统(Intrusion Detection Systems)&#xff1…

汽车电子AUTOSAR之BswM模块

目录 前言 正文 总体设计框架 模式仲裁过程 模式控制过程 模式仲裁 模式请求来源(ModeRequestPorts) 模式条件(ModeCondition) 逻辑表达式(LogicExpressions) 模式规则(ModeRules) 模式规则的初始化 模式控制 模式控制基本流程 模式行为 常用函数接口 前言 首先&…

Dependency not found解决方案(Springboot,绝对有效)

目录 问题描述解决方案systemPathmvn install 问题描述 今天在弄一个项目的依赖的时候,easyexcel 的依赖就是下载不了,虽然我的 Maven 配置没问题。 依赖:    Maven 配置:    我切换了几个版本,也无法从镜像下…

git diff去除^M的方法

一,简介 本文主要介绍在git修改的时候,修改文件后,git diff查看修改内容时,发现修改的地方每行结束的地方都会有“^M”,很影响查看。故今天分享一种去除“ ^M”显示的方法,供参考。 二,问题原…

案例29:基于Springboot医疗挂号系统开题报告设计

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

最新版本Portraiture4.1中文版ps磨皮滤镜插件安装包

在Portraiture有非常强大的手动功能,可以为用户进行手动调整照片中的皮肤区域以达到更加完美的效果,软件还支持同时导入上千张照片,用户可以通过自动识别照片中的人脸从而依照自己的风格进行批量处理十分的方便快捷。 最新版本Portraiture 4…

空气污染气象学期末复习笔记

空气污染气象学 (一)研究什么 运用气象学方法研究空气污染物自排放源进入大气层后的散布规律,核心是研究大气输送和扩散 (二)大气污染 大气污染是指由于人类活动或自然过程引起某种物质进入大气中,呈现出足…

Mysql数据库入门基础篇--mysql 多表查询

【Mysql数据库入门基础篇--mysql 多表查询 🔻一、mysql 多表查询1.1 🍃 7种sql joins 的实现1.2 🍃 错误写法---笛卡尔积错误1.3 🍃 正确的多表select写法 🔻二、内连接( inner) join🔻三、 外连接&#xf…

【LeetCode】23. 合并 K 个升序链表

23. 合并 K 个升序链表(困难) 方法一:顺序合并 思路 ListNode* mergeTwoLists(ListNode *a, ListNode *b) {if ((!a) || (!b)) return a ? a : b;ListNode head, *tail &head, *aPtr a, *bPtr b;while (aPtr && bPtr) {if (…

【第十期】Apache DolphinScheduler 每周 FAQ 集锦

点击蓝字 关注我们 摘要 为了让 Apache DolphinScheduler 的广大用户和爱好者对于此项目的疑问得到及时快速的解答,社区特发起此次【每周 FAQ】栏目,希望可以解决大家的实际问题。 关于本栏目的要点: 本栏目每周将通过腾讯文档(每…

卡尔曼滤波与组合导航原理(十二)扩展卡尔曼滤波:EKF、二阶EKF、迭代EKF

文章目录 一、多元向量的泰勒级数展开二、扩展Kalman滤波三、二阶滤波四、迭代EKF滤波 一、多元向量的泰勒级数展开 { y 1 f 1 ( X ) f 1 ( x 1 , x 2 , ⋯ x n ) y 2 f 2 ( X ) f 2 ( x 1 , x 2 , ⋯ x n ) ⋮ y m f m ( X ) f m ( x 1 , x 2 , ⋯ x n ) \left\{\begin{…

大家都说Java有三种创建线程的方式,并发编程中的惊天骗局

在Java中,创建线程是一项非常重要的任务。线程是一种轻量级的子进程,可以并行执行,使得程序的执行效率得到提高。Java提供了多种方式来创建线程,但许多人都认为Java有三种创建线程的方式,它们分别是继承Thread类、实现…

论文浅尝 | Dually Distilling KGE for Faster and Cheaper Reasoning

笔记整理:张津瑞,天津大学硕士,研究方向为知识图谱 链接:https://dl.acm.org/doi/10.1145/3488560.3498437 动机 知识图谱已被证明可用于各种 AI 任务,如语义搜索,信息提取和问答等。然而众所周知&#xff…

【C++】C++11常用新特性

✍作者:阿润菜菜 📖专栏:C 目录 一、统一的列表初始化二、 简化声明2.1 auto2.2 decltype2.3 nullptr 三、右值引用和移动语义 -- 重要3.1 区分左值引用和右值引用3.2 对比左值引用看看右值引用使用价值3.3 万能引用和完美转发(st…

基于word文档,使用Python输出关键词和词频,并将关键词的词性也标注出来

点击上方“Python爬虫与数据挖掘”,进行关注 回复“书籍”即可获赠Python从入门到进阶共10本电子书 今 日 鸡 汤 移船相近邀相见,添酒回灯重开宴。 大家好,我是Python进阶者。 一、前言 前几天在有个粉丝问了个问题,大概意思是这样…

一道北大强基题背后的故事(三)——什么样的题是好题?

早点关注我,精彩不错过! 上回我们针对这道北大强基题[((1 sqrt(5)) / 2) ^ 12]在答案的基础上给出了出题的可能思路,想一探究竟,相关内容请戳: 一道北大强基题背后的故事(二)——出题者怎么想的…

【Kubernetes入门】Service四层代理入门实战详解

文章目录 一、Service四层代理概念、原理1、Service四层代理概念2、Service工作原理3、Service原理解读4、Service四种类型 二、Service四层代理三种类型案例1、创建ClusterIP类型Service2、创建NodePort类型Service3、创建ExternalName类型Service 三、拓展1、Service域名解析…