Python基于PyTorch实现循环神经网络回归模型(LSTM回归算法)项目实战

news2024/12/29 9:59:06

说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。




1.项目背景

LSTM网络是目前更加通用的循环神经网络结构,全称为Long Short-Term Memory,翻译成中文叫作“长‘短记忆’”网络。读的时候,“长”后面要稍作停顿,不要读成“长短”记忆网络,因为那样的话,就不知道记忆到底是长还是短。本质上,它还是短记忆网络,只是用某种方法把“短记忆”尽可能延长了一些。

本项目通过基于PyTorch实现循环神经网络回归模型。

2.数据获取

本次建模数据来源于网络(本项目撰写人整理而成),数据项统计如下:

数据详情如下(部分展示):

 

3.数据预处理

3.1 用Pandas工具查看数据

使用Pandas工具的head()方法查看前五行数据:

关键代码:

 

3.2 数据缺失查看

使用Pandas工具的info()方法查看数据信息:

 

从上图可以看到,总共有11个变量,数据中无缺失值,共2000条数据。

关键代码:

 

3.3 数据描述性统计

通过Pandas工具的describe()方法来查看数据的平均值、标准差、最小值、分位数、最大值。  

关键代码如下:

 

4.探索性数据分析

4.1 y变量直方图

用Matplotlib工具的hist()方法绘制直方图:

 

从上图可以看到,y变量主要集中在-400~400之间。

4.2 相关性分析

 

从上图中可以看到,数值越大相关性越强,正值是正相关、负值是负相关。

5.特征工程

5.1 建立特征数据和标签数据

关键代码如下:

5.2 数据集拆分

通过train_test_split()方法按照80%训练集、20%测试集进行划分,关键代码如下:

6.构建循环神经网络回归模型

主要使用LSTM回归算法,用于目标回归。

6.1 构建模型

 

7.模型评估

7.1 评估指标及结果

评估指标主要包括可解释方差值、平均绝对误差、均方误差、R方值等等。

 

 从上表可以看出,R方0.9871,为模型效果良好。

关键代码如下:

 

7.2 真实值与预测值对比图

 

从上图可以看出真实值和预测值波动基本一致,模型拟合效果良好。   

8.结论与展望

综上所述,本文基于PyTorch实现循环神经网络回归模型,最终证明了我们提出的模型效果良好。此模型可用于日常产品的预测。

# 定义训练函数
def train(model, train_loader, criterion, optimizer, device):
    model.train()  # 设置训练模式

    for i, (inputs, labels) in enumerate(train_loader):  # 进行循环
        inputs, labels = inputs.to(device), labels.to(device)  # 输入数据、标签数据

        optimizer.zero_grad()  # 清空过往梯度



本次机器学习项目实战所需的资料,项目资源如下:

项目说明:
链接:https://pan.baidu.com/s/1dW3S1a6KGdUHK90W-lmA4w 
提取码:bcbp



# y变量分布直方图
fig = plt.figure(figsize=(8, 5))  # 设置画布大小
plt.rcParams['font.sans-serif'] = 'SimHei'  # 设置中文显示
plt.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题
data_tmp = df['y']  # 过滤出y变量的样本
# 绘制直方图  bins:控制直方图中的区间个数 auto为自动填充个数  color:指定柱子的填充色
plt.hist(data_tmp, bins='auto', color='g')

更多项目实战,详见机器学习项目实战合集列表:

机器学习项目实战合集列表_机器学习实战项目_胖哥真不错的博客-CSDN博客


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/784042.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机视觉(二)图像特征提取

文章目录 颜色特征量化颜色直方图适用颜色空间:RGB、HSV等颜色空间操作 几何特征边缘 Edge边缘定义边缘提取 基于关键点的特征描述子引入几何特征:关键点几何特征:Harris角点FAST角点检测几何特征:斑点局部特征:SIFT预…

GPT-4 模型详细教程

GPT-4(Generative Pretrained Transformer 4)是 OpenAI 的最新语言生成模型,其在各类文本生成任务中表现优秀,深受开发者和研究者喜爱。这篇教程将帮助你理解 GPT-4 的基本概念,并向你展示如何使用它来生成文本。 什么…

前端vue入门(纯代码)35_导航守卫

星光不问赶路人,时光不负有心人 【33.Vue Router--导航守卫】 导航守卫 正如其名,vue-router 提供的导航守卫主要用来通过跳转或取消的方式守卫导航。有多种机会植入路由导航过程中:全局的, 单个路由独享的, 或者组件级的。 记住参数或查…

uniapp JS文件里面调用自定义组件(不用每个页面在template中加组件标签)

前言 工具:uniapp 开发端:微信小程序 其他:uview 2.0 场景:路由器里面,统一验证是否已登录,如果没登录,则直接弹出登录弹窗出来,不管哪个页面都如此。 效果如下: 直接上…

【笔试强训选择题】Day29.习题(错题)解析

作者简介:大家好,我是未央; 博客首页:未央.303 系列专栏:笔试强训选择题 每日一句:人的一生,可以有所作为的时机只有一次,那就是现在!!!&#xff…

rsync—远程同步

目录 一:rsync概述 1.1rsync简介 1.2rsync同步方式 二:rsync特性 三:rsync同步源 四:rsync与cp、scp对比 五:常用rsync命令 六:rsync本地复制实例 七:配置源的俩种表示方法 八&#x…

[NLP]Huggingface模型/数据文件下载方法

问题描述 作为一名自然语言处理算法人员,hugging face开源的transformers包在日常的使用十分频繁。在使用过程中,每次使用新模型的时候都需要进行下载。如果训练用的服务器有网,那么可以通过调用from_pretrained方法直接下载模型。但是就本人…

安全DNS,状态码,编码笔记整理

一 DNS DNS(Domain Name System)是互联网中用于将域名转换为IP地址的系统。 DNS的主要功能包括以下几个方面: 域名解析:DNS最主要的功能是将用户输入的域名解析为对应的IP地址。当用户在浏览器中输入一个域名时,操作…

工程安全监测无线振弦采集仪在建筑物中的应用

工程安全监测无线振弦采集仪在建筑物中的应用 工程安全监测无线振弦采集仪是一种用于建筑物结构安全监测的设备,它采用了无线传输技术,具有实时性强、数据精度高等优点,被广泛应用于建筑物结构的实时监测和预警。下面将从设备的特点、应用场…

力扣热门100题之接雨水【困难】

题目描述 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例 1: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数组 [0,1,0,2,1,0,1,3…

如何使用GPT作为SQL查询引擎的自然语言

​生成的AI输出并不总是可靠的,但是下面我会讲述如何改进你的代码和查询的方法,以及防止发送敏感数据的方法。与大多数生成式AI一样,OpenAI的API的结果仍然不完美,这意味着我们不能完全信任它们。幸运的是,现在我们可以…

Packet Tracer – 配置动态 NAT

Packet Tracer – 配置动态 NAT 拓扑图 目标 第 1 部分:配置动态 NAT 第 2 部分:验证 NAT 实施 第 1 部分: 配置动态 NAT 步骤 1: 配置允许的流量。 在 R2 上,为 ACL 1 配置一个语句以允许属于 172.16.0.…

【JVM】浅看JVM的运行流程和垃圾回收

1.JVM是什么 JVM( Java Virtual Machine)就是Java虚拟机。 Java的程序都运行在JVM中。 2.JVM的运行流程 JVM的执行流程: 程序在执行之前先要把java代码转换成字节码(class文件),JVM 首先需要把字节码通过…

Visio/PPT/Matlab输出300dpi以上图片【满足标准投稿要求】

1. visio 遵照如下输出选项,另存为tif格式文件时,选择正确输出便是300dpi以上 2. matlab 文件选项选中导出设置,在渲染中选择dpi为600,导出图片即可,科研建议选择tif格式文件 3.ppt 打开注册表,winr键…

【报错】在python3.9环境下安装sqlmap报错

问题描述 报错内容: missing one or more core extensions (‘ssl’, ‘sqlite3’) most likely because current version of Python has been built without appropriate dev packages 原因分析: 缺少一个或多个核心扩展(‘ssl’、‘sqlit…

常见的栈溢出StackOverFlow 与 内存溢出OutOfMemory的区别

0、前言:内存模型 对于多线程运行情况下的jvm内存,我们应当知道: 每创建一个线程,jvm就会为其分配一块线程私有的工作内存,其中包括程序计数器、栈,等等。 对于每一个线程私有的栈,当线…

怎么限制文件打开次数、打开时间?

一些公司出于业务需求,可能会给客户或者合作伙伴发一些涉密图纸、文档、文件等重要文件,但是又不想文件被外发泄露随意传播,今天就教大家一个方法限制文件外发后别人打开这个文件的打开次数、打开时间、另存为等操作。 设置方法 本篇文章测试…

page _refcount和_mapcount字段

linux page有两个非常重要的引用计数字段_refcount和_mapcount,都是atomic_t类型,其中,_refcount表示内核中应用该page的次数。当_refcount 0时,表示该page为空闲或者将要被释放。当_refcount > 0,表示该page页面已…

APP-脱壳+反编译

APP反编译加固-自动查壳脱壳 为什么要脱壳? 因为不脱壳无法进行反编译 查壳工具:https://pan.baidu.com/s/1rDfsEvqQwhUmep1UBLUwSQ 密码: wefd 脱壳工具:https://github.com/CodingGay/BlackDex 查壳演示: 使用Java运行jar包&a…