深度学习-循环神经网络-LSTM对序列数据进行预测

news2024/11/25 20:04:07

 项目简介:

使用LSTM模型, 对文本数据进行预测, 

每次截取字符20, 对第二十一个字符进行预测, 

LSTM层: units=100, activation=relu

Dense层: units=输入的文本中的字符种类, 比如我使用的文本有644个不同的字符, 那么units=64

                激活函数: 因为是多分类, 使用softmax

                因为这是最后一层, 所以输出神经元的个数也就是644

# ===================================================================
# 1.数据导入和数据预处理
# 读入数据
txt_data= open(r"C:\Users\鹰\Desktop\AI Assistant.txt", encoding='utf-8').read()
# 数据预处理

# 移除换行符
txt_data=txt_data.replace('\n','').replace('\r','').replace('#', '').replace('*','').replace('=','').replace('-','')
# print(txt_data)
# 字符去重
letters = list(set(txt_data))
# print(letters)
letters_num=len(letters)
# print(letters_num)

# 建立字典,让字符与数字对应
# 话说int_to_char这个字典是干嘛的, 到底没看出
int_to_char={a:b for a,b in enumerate(letters)}
# print(int_to_char)

char_to_int={b:a for a,b in enumerate(letters)}
# print(char_to_int)


# 设定time_step=20, 就是每次在文本中截取的字符长度为20, 然后预测第二十一个
time_step=20

# 滑动窗口提取数据--对字符数据进行截取转成列表给x, 将预测数据给y
def extract_char(data, slide):
    x=[]
    y=[]
    for i in range(len(data)-slide):
        x.append([a for a in data[i:i+slide]])
        y.append(data[i+slide])
    return x, y
    
# 批量转化--将字符转化为数字
def char_to_int_data(x, y, char_to_int_dict):
    x_to_int=[]
    y_to_int=[]
    for i in range(len(x)):
        x_to_int.append([char_to_int_dict[char] for char in x[i]])
        y_to_int.append(char_to_int_dict[y[i]])
    return x_to_int, y_to_int

# 实现文章的预处理, 参数--1.要处理的字符数据, 2.每次截取的字符长度, 3.进行转化的信息交换字典
def data_preprocessing(data, slide, letters_num, char_to_int_dict):
    # 提取滑动窗口数据
    char_data = extract_char(data, slide)
    int_data = char_to_int_data(char_data[0], char_data[1], char_to_int_dict)
    
    input_data = int_data[0]
    output_data = int_data[1]
    
    # 转换成 one-hot 编码
    input_reshape = np.array(input_data).reshape(len(input_data), slide)
    new = np.zeros((input_reshape.shape[0], input_reshape.shape[1], letters_num), dtype=bool)  # 使用 bool
    
    for i in range(input_reshape.shape[0]):
        new[i, :, :] = to_categorical(input_reshape[i, :], num_classes=letters_num)
    
    # 将布尔值转换为 0 和 1
    new = new.astype(int)
    
    return new, output_data

# 调用函数
x,y= data_preprocessing(txt_data, time_step, letters_num, char_to_int)

print(x.shape)
print(x[0])
print(len(y))

#数据集分割
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test=train_test_split(x, y, test_size=0.1, random_state=10)

# 对训练集目标值转化为one-hot格式
y_train_category=to_categorical(y_train, letters_num)
print(y_train_category)
# ===================================================================================
# 2.模型搭建和模型训练
# 搭建模型
from keras.models import Sequential
LSTM_model=Sequential()
from keras.layers import LSTM, Dense
# 一会要不要调整一下神经元数量? 
LSTM_model.add(LSTM(units=100, input_shape=(x_train.shape[1], x_train.shape[2]), activation='relu'))
LSTM_model.add(Dense(units=letters_num, activation='softmax'))
LSTM_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
LSTM_model.summary()

# 训练模型
LSTM_model.fit(x_train, y_train_category, epochs=50, batch_size=64)

# ===========================================================================
# 3.基于训练集和测试集进行预测和评估
# 训练集预测+评估
y_predict_base_train=LSTM_model.predict(x_train)
y_predict_base_train=np.argmax(y_predict_base_train, axis=1)
# 以数值类型输出预测结果
# print(y_predict_base_train)
y_predict_base_train_char=[int_to_char[i] for i in y_predict_base_train]
# 以字符类型输出结果
print(y_predict_base_test_char)
# 计算模型预测准确率
from sklearn.metrics import accuracy_score
accuracy_score=accuracy_score(y_train, y_predict_base_train)
print("accuracy is ", accuracy_score)

# 测试集预测+评估
y_predict_base_test=LSTM_model.predict(x_test)
y_predict_base_test=np.argmax(y_predict_base_test, axis=1)
# 以数值类型输出预测结果
# print(y_predict_base_test)

y_predict_base_test_char=[int_to_char[i] for i in y_predict_base_test]
# 以字符类型输出结果
# print(y_predict_base_test_char)

# 计算模型预测准确率
from sklearn.metrics import accuracy_score
accuracy_score=accuracy_score(y_test, y_predict_base_test)
print("accuracy is ", accuracy_score)


# 课外实践, ===========================================================================
# 4.实战预测: 输入"怎么样开发一个ai助手, 可以根据我提出的需求自动进行进行开发网站, 移动端app, 桌面应用程序,可以进行数据获取, 数据分析", 看看效果

x_new="怎么样开发一个ai助手, 可以根据我提出的需求自动进行进行开发网站, 移动端app, 桌面应用程序,可以进行数据获取, 数据分析"
x_new, y_new= data_preprocessing(x_new, time_step, letters_num, char_to_int)
y_new_predict=LSTM_model.predict(x_new)
y_new_predict=np.argmax(y_new_predict, axis=1)
y_new_char=[int_to_char[i] for i in y_new]
print(y_new_char)
y_new_predict_char=[int_to_char[i] for i in y_new_predict]
print(y_new_predict_char)

# 计算模型准确率
from sklearn.metrics import accuracy_score
accuracy_score=accuracy_score(y_new, y_new_predict)
print("accuracy is ", accuracy_score)


# 注意哈, 原来的实战目标是输入"ai应用程序开发", 结果报错了, 因为这个模型之前规定time_step=20, 而"ai应用程序开发"这一串字符小于二十, 模型无法完成一次正常的截取, 当然会报错

作者的备注:

兄弟, 不是不想给你们准备数据集,

而是这个文件里面的数据对我用处很大,

原谅我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2224277.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

树莓派5使用pytorch训练模型(CPU)

Pytorch 对于树莓派提供了较好的支持,可以利用 Pytorch 在树莓派上进行试试推理,当然也可以使用树莓派进行模型训练了,这里尝试使用树莓派CPU对模型进行训练。 0 环境配置 必要的环境安装,这个步骤没有什么值得说的,…

c++(树)

定义 2-3 树中的每一个节点都有两个孩子(称为 2 节点,2-node)或三个孩子(称为 3 节点,3-node)。 2 节点,有一个数据元素和两个孩子。只能有两个孩子或没有孩子,不能出现只有一个孩子的情况。如果…

JVM学习总结:字节码篇

本文是学习尚硅谷宋红康老师主讲的 尚硅谷JVM精讲与GC调优教程 的总结 ,部分内容也参考了 JavaGuide 网站(文末有链接) JVM 概述 Oracle JDK 与 OpenJDK 是什么关系? 2006 年 SUN 公司将 Java 开源,也就有了 OpenJDK。…

【verilog】四位全加器

文章目录 前言一、实验原理二、实验过程三、实验结果参考文献 前言 进行 FPGA 全加器 实验 一、实验原理 module adder(ain,bin,cin,cout,s); input ain,bin,cin; output cout,s; assign coutain&bin | ain&cin | bin&cin; assign sain^bin^cin; endmoduletimesc…

复杂类型map与struct

1.map:Key-Value 型数据格式 建表: create table myhive.test_map( id int, name string, members map<string,string>, age int) row format delimited fields terminated by , COLLECTION ITEMS TERMINATED BY # MAP KEYS TERMINATED BY :; 数据导入:load data local …

基于ssm+jsp的地方疫情管理系统(含源码+数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: apache tomcat 主要技术: Java,Spring,SpringMvc,mybatis,mysql,vue 2.视频演示地址 3.功能 该系统包含两个…

「二叉树进阶题解:构建、遍历与结构转化全解析」

文章目录 根据二叉树创建字符串思路代码 二叉树的层序遍历思路代码 二叉树的最近公共祖先思路代码 二叉搜索树与双向链表思路代码 从前序与中序遍历序列构造二叉树思路代码 总结 根据二叉树创建字符串 题目&#xff1a; 样例&#xff1a; 可以看见&#xff0c;唯一特殊的就…

Fast Simulation of Mass-Spring Systems in Rust 论文阅读

参考资料&#xff1a; Fast Simulation of Mass-Spring Systems in Rust 论文阅读&#xff1a;Fast Simulation of Mass-Spring Systems 【论文精读】讲解刘天添2013年的fast simulation of mass spring system(Projective Dynamics最早的论文) Projective Dynamics笔记(一…

新手做私域学会这三步,一周时间营收翻倍

在数字化营销的时代&#xff0c;私域流量的运营已经成为品牌和创业者提升营收的关键。如果你是一个私域营销的新手&#xff0c;那么这篇文章将是你的福音。我们将分享三个简单而有效的步骤&#xff0c;帮助你在短短一周内实现营收翻倍的目标。 第一步&#xff1a;明确人设——…

SpringBoot项目整合Knife4J

SpringBoot项目整合Knife4J 前言为什么要使用API文档什么是API文档 Knife4jKnife4j的进化史Swagger和Knife4J的关系 SpringBoot整合Knife4j版本适配实现步骤1.导入依赖2.编写配置类新建一个controller进行测试启动项目 Knife4j增强配置常用注解例子展示实体类注解Controller注解…

【大数据学习 | kafka】kafuka的基础架构

1. kafka是什么 Kafka是由LinkedIn开发的一个分布式的消息队列。它是一款开源的、轻量级的、分布式、可分区和具有复制备份的&#xff08;Replicated&#xff09;、基于ZooKeeper的协调管理的分布式流平台的功能强大的消息系统。与传统的消息系统相比&#xff0c;KafKa能够很好…

HarmonyOS 相对布局(RelativeContainer)

1. HarmonyOS 相对布局&#xff08;RelativeContainer&#xff09; 文档中心:https://developer.huawei.com/consumer/cn/doc/harmonyos-guides-V5/arkts-layout-development-relative-layout-V5   RelativeContainer为采用相对布局的容器&#xff0c;支持容器内部的子元素设…

海螺 2.27.1 |AI生成视频 AI音乐 语音通话

嗨&#xff01;我是小海螺&#xff0c;你的AI智能伙伴&#xff0c;帮助你学习工作效率加倍&#xff01;我无所不知&#xff0c;又像朋友陪你左右&#xff0c;遇到问题&#xff0c;就问我吧。我所使用的技术&#xff0c;是MiniMax公司自研的万亿参数MoE大模型。我们希望能与用户…

【SpringCloud】Seata微服务事务

Seata微服务事务 分布式事务问题&#xff1a;本地事务分布式事务演示分布式事务问题&#xff1a;示例1 分布式事务理论CAP定理一致性可用性分区容错矛盾 Base理论解决分布式事务的思路 初识SeataSeata的架构部署TC服务微服务集成Seata引入依赖配置TC地址 其他服务 动手实践XA模…

WRB Hidden Gap,WRB隐藏缺口,MetaTrader 免费公式!(指标教程)

WRB Hidden Gap MetaTrader 指标用于检测和标记宽范围的柱体&#xff08;非常长的柱体&#xff09;或宽范围的烛身&#xff08;具有非常长实体的阴阳烛&#xff09;。此指标可以识别WRB中的隐藏跳空&#xff0c;并区分显示已填补和未填补的隐藏跳空&#xff0c;方便用户一眼识别…

Zustand介绍与使用 React状态管理工具

文章目录 前言基本使用编写状态加方法在组件中使用异步方法操作 中间件简化状态获取优化性能 持久化保存 前言 在现代前端开发中&#xff0c;状态管理一直是一个关键的挑战。随着应用规模的扩大&#xff0c;组件间的状态共享变得愈加复杂。为了应对这一需求&#xff0c;开发者…

Java-图书管理系统

我的个人主页 欢迎来到我的Java图书管理系统&#xff0c;接下来让我们一同探索如何书写图书管理系统吧&#xff01; 1管理端和用户端 2建立相关的三个包&#xff08;book、operation、user&#xff09; 3建立程序入口Main类 4程序运行 1.首先图书馆管理系统分为管理员端和…

使用Poste搭建内网邮件服务器

使用Poste搭建内网邮件服务器 Poste.io 也是一个流行的邮件服务器方案&#xff0c;它可以通过 Docker 容器轻松部署&#xff0c;非常适合搭建内部邮件服务器。 本文档将向您展示如何开始使用 Poste.io 邮件服务器。在 5 分钟内&#xff0c;您将拥有一个可发送和接收邮件的邮件…

Springboot 使用EasyExcel导出Excel文件

Springboot 使用EasyExcel导出Excel文件 Excel导出系列目录&#xff1a;引入依赖创建导出模板类创建图片转化器 逻辑处理controllerservice 导出效果遗留问题 Excel导出系列目录&#xff1a; 【Springboot 使用EasyExcel导出Excel文件】 【Springboot 使用POI导出Excel文件】 …

基于Python大数据的王者荣耀战队数据分析及可视化系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…