Python 模型训练:LSTM 时间序列销售额预测(训练、保存、调用)

news2025/1/19 3:19:31

LSTM (long short-term memory) 长短期记忆网络,具体理论的就不一一叙述,直接开始

流程

    • 一、数据导入
    • 二、数据归一化
    • 三、划分训练集、测试集
    • 四、划分标签和属性
    • 五、转换成 LSTM 输入格式
    • 六、设计 LSTM 模型
      • 6.1 直接建模
      • 6.2 找最好
    • 七、测试与图形化展示
    • 八、保存模型到 pkl 文件
    • 九、模型调用
      • 9.1 Python 模型调用端
      • 9.2 Java 程序调用端

一、数据导入

  • 正常的 pandas 读取数据,将时间列转成索引(看其他教程这样做,感觉没啥用,按照时间顺序就行)
# 获取数据
import pandas as pd
from datetime import datetime
dataset = pd.read_csv('../data.csv', index_col='时间', usecols=[0,2,3,5], date_parser=lambda x:datetime.strptime(x, '%Y年%m月'))
dataset

在这里插入图片描述

二、数据归一化

  • 将数据缩小到 0-1 范围,我这里将所有数据归到一列来,这样缩小范围就是一样的,后续可以直接用这个来转换
# 数据归一化
from sklearn.preprocessing import MinMaxScaler
values = dataset.values
# 转换成一列
values_res = values.reshape(values.shape[0] * values.shape[1], 1)
scaler = MinMaxScaler(feature_range=(0, 1))
# 训练 scaler
scaled = scaler.fit_transform(values_res)
# 再转换成原来的样子
scaled_dataset = scaled.reshape(values.shape)
scaled_dataset

在这里插入图片描述

三、划分训练集、测试集

  • 数据需要按照时间顺序,所以这里之前前后切割 20%
# 切分训练集和测试集
split = round(len(scaled_dataset)*0.20)
train = scaled_dataset[:-split]
test = scaled_dataset[-split:]
test

在这里插入图片描述

四、划分标签和属性

  • 数据的第一列是标签数据,第二三列是属性条件数据
# 划分标签和属性
train_x, train_y = train[:, 1:], train[:, 0]
test_x, test_y = test[:, 1:], test[:, 0]
test_x

在这里插入图片描述

五、转换成 LSTM 输入格式

  • 转为LSTM模型的输入格式(samples, timesteps, features)
train_x_input = train_x.reshape((train_x.shape[0], 1, train_x.shape[1]))
test_x_input = test_x.reshape((test_x.shape[0], 1, test_x.shape[1]))
test_x_input

在这里插入图片描述

六、设计 LSTM 模型

  • 设计 LSTM 模型有两个方式,第一个是知道最佳参数是什么,第二个是多输入几个参数,然后找到最佳参数

6.1 直接建模

# 设计 LSTM 模型
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
model = Sequential()
model.add(LSTM(50, input_shape=(1, 2)))
model.add(Dense(1))
model.compile(loss="mae", optimizer="adam")
model.fit(train_x_input, train_y, epochs=10, batch_size=1, validation_data=(test_x_input, test_y), verbose=2, shuffle=False)

6.2 找最好

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import GridSearchCV

def build_model(optimizer):
    grid_model = Sequential()
    grid_model.add(LSTM(50,return_sequences=True,input_shape=(1,2)))
    grid_model.add(LSTM(50))
    grid_model.add(Dropout(0.2))
    grid_model.add(Dense(1))
    grid_model.compile(loss = 'mse',optimizer = optimizer)
    return grid_model

grid_model = KerasRegressor(build_fn=build_model,verbose=1,validation_data=(test_x_input,test_y))
# 把各种可能的参数都丢上去
parameters = {'batch_size' : [1],
            'epochs' : [10,11],
            'optimizer' : ['adam', 'rmsprop'] } 
grid_search = GridSearchCV(estimator = grid_model,
                          param_grid = parameters,
                          cv = 2)
# 训练
grid_search = grid_search.fit(train_x_input, train_y)
# 最好的参数
print(grid_search.best_params_)
# 最好参数对应的模型
model = grid_search.best_estimator_.model

七、测试与图形化展示

from matplotlib import pyplot as plt
from sklearn.metrics import mean_squared_error
import math

# 测试
pred = model.predict(test_x_input)
# 获取原始值
real = scaler.inverse_transform(test_y.reshape(1, -1)).reshape(-1, 1)
predicted = scaler.inverse_transform(pred)
plt.plot(real, color = 'red', label = 'Real')
plt.plot(predicted, color = 'blue', label = 'Predicted')
plt.title('Sale Prediction')
plt.xlabel('Time')
plt.ylabel('Sale')
plt.legend()
plt.show()
rmse = math.sqrt(mean_squared_error(real, predicted))
print("均方根误差:" + str(rmse))

均方根误差:2.1375958318221455
在这里插入图片描述

八、保存模型到 pkl 文件

# 保存模型
import dill
with open('./sale_predict_model.pkl', 'wb') as outfile:
    dill.dump({
        'scaler': scaler,
        'model': model
    }, outfile)

九、模型调用

  • 模型要部署到线上进行调用,直接可以写一个脚本进行调用,同时考虑到每次调用都要读取一次模型,浪费性能,直接使用 Socket 形式传参,后台形成一个常驻服务
    • Socket 固定传入格式 “a,b"

9.1 Python 模型调用端

import socket
import threading
import numpy as np
import pickle

# Socket 操作
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sk.bind(('127.0.0.1', 10001))
sk.listen(5)
count = 0
# 读取模型
file = 'sale_predict_model.pkl'
with open(file, 'rb') as f:
    model = pickle.load(f)


# 模型预测
def predict(a, b):
    data = np.array([[a, b]])
    # 转换格式,使用的是模型训练时训练出来的编译器
    data_scaled = model['scaler'].transform(data.reshape(data.shape[0] * data.shape[1], 1)).reshape(data.shape)
    # 直接导入模型,一样要进行转换格式
    data_res = model['model'].predict(data_scaled.reshape((data_scaled.shape[0], 1, data_scaled.shape[1])))
    # 返回最终结果
    return model['scaler'].inverse_transform(data_res)[0][0]


# 处理 Socket 连接
def tcp(sock, addr):
    try:
        print('Accept new connection from %s:%s...' % addr)
        print('Request count: %d' % count)
        # 读取参数
        data = sock.recv(1024)
        # 解码参数
        data_str = data.decode('utf-8')
        print("Param: %s" % data_str)
        # 切割参数
        data_list = data_str.split(',')
        # 判断参数合法性
        if len(data_list) == 2:
	        # 合法参数调用模型并返回数据
            sock.send(str(predict(data_list[0], data_list[1])).encode('utf-8'))
            print("Invoke success")
        else:
            sock.send(('Error param: %s' % data_str).encode('utf-8'))
            print('Error param: %s' % data_str)
    except Exception as e:
        print('Except:', e)
        sock.send('Invoke error'.encode('utf-8'))
    finally:
        sock.close()


if __name__ == '__main__':
    while True:
    	# 监听连接
        data, addr = sk.accept()
        count += 1
        # 交给线程处理
        thread = threading.Thread(target=tcp, args=(data, addr))
        # 启动线程
        thread.start()

9.2 Java 程序调用端

package org.example.service;

import java.io.IOException;
import java.net.Socket;
import java.nio.charset.StandardCharsets;

public class InvokeModel {
	// service 测试
    public static void main(String[] args){
        System.out.println(invoke(54.4, 14.4));
    }
	// service 调用方法
    public static String invoke(Double sale1, Double sale2) {
    	// 拼装参数
        String req = sale1 + "," + sale2;
        Socket socket = null;
        try {
        	// 创建 Socket
            socket = new Socket("127.0.0.1", 10001);
            // 传输数据
            socket.getOutputStream().write(req.getBytes(StandardCharsets.UTF_8));
            System.out.println("Request param: " + req);
            byte[] buf = new byte[256];
            // 读取返回的数据
            int len = socket.getInputStream().read(buf);
            // 返回最终的结果(是一个 Double,方便操作直接用 String)
            return new String(buf, 0, len);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            try {
                if (socket != null)
                    socket.close();
            } catch (IOException e) {
                System.err.println("Invoke model error");
            }
        }
    }

}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaSE-07

字节流输入输出数据: InputStream和OutputStream作为字节流输入输出流的超类。 字节流写数据时千万记得close关闭资源,可设置追加写为true 字节流读数据时,FileInputStream a new FileInputStream (“”); int by a.read(); char b (char…

隐蔽信道学习

隐蔽信道作为一种能够在不被系统感知的情况下稳定窃取秘密信息的通信手段,尽管其带宽通常较低,但其设计上的复杂性和多样性,使得常规的流量审计系统难以对抗或检测。同时,隐蔽信道也是密钥、身份认证、商业机密等秘密信息传输的重…

基于JAVA SSM框架的新闻管理系统源码+数据库,实现 登录 、 注册、 新闻内容、类别、评论、个人信息、系统管理等功能

[基于SSM框架的新闻管理系统] 前言 下载地址:基于JAVA SSM框架的新闻管理系统源码数据库 基于SSM框架的新闻管理系统; 实现 登录 、 注册 、 新闻内容、类别、评论、个人信息、系统管理等功能 ; 可继续完善增加前端、等其他功能等&#x…

federated引擎实现mysql跨服务器表连接

📢作者: 小小明-代码实体 📢博客主页:https://blog.csdn.net/as604049322 📢欢迎点赞 👍 收藏 ⭐留言 📝 欢迎讨论! 📢本文链接:https://xxmdmst.blog.csdn.n…

IU5066 高耐压带OVP保护1.2A单节锂电池线性充电IC

概要 IU5066E是面向空间受限便携应用的,高度集成锤离子和锤聚合物线性充电器器件。该器件由USB端口或交流适配器供电。带输入过压保护的高输入电压范围支持低成本、非稳压适配器。 电池充电经历以下三个阶段: 涓流、电流、恒压。在所有充电阶段&#x…

JSON对象(javascript)

本文内容主要包括了对于JS中JSON对象的一些内容。我们知道JSON格式是前后端进行信息交换的中介信息格式。适用于取代XML格式的一种格式,在多数编程语言中都有关于JSON的处理方法。那么javascript也提供了JSON对象用于处理相应的数据。 1. 什么是JSON格式&#xff1…

mac安装jdkidea配置jdk

第一步:mac安装jdk1、下载jdk,下载地址:https://www.oracle.com/java/technologies/downloads/#java82、安装后,终端输入java -version查看java是否安装成功3、配置环境变量a.在终端输入 /usr/libexec/java_home 可以得到JAVA_HOM…

【矩阵论】5. 线性空间与线性变换——线性映射与自然基分解,线性变换

矩阵论 1. 准备知识——复数域上矩阵,Hermite变换) 1.准备知识——复数域上的内积域正交阵 1.准备知识——Hermite阵,二次型,矩阵合同,正定阵,幂0阵,幂等阵,矩阵的秩 2. 矩阵分解——SVD准备知识——奇异值…

深度学习人体解析

人体解析旨在将图像或视频中的人体分割成多个像素级的语义部分。在过去的十年中,它在计算机视觉社区中获得了极大的兴趣,并在广泛的实际应用中得到了应用,从安全监控到社交媒体,再到视觉特效,这只是其中的一小部分。尽…

Markdown语法大全(够你用一辈子)

标题 # 一级标题 ## 二级标题 ### 三级标题 #### 四级标题 ##### 五级标题 ###### 六级标题一级标题 二级标题 三级标题 四级标题 五级标题 六级标题 文本样式 > 引用文本 > 最外层 > > 第一层嵌套 > > > 第二层嵌套引用文本 最外层 第一层嵌套 第二层…

js中的call和apply

js中的call和apply1.call()可以调用某一函数2.call()可以这个函数的this指向3.call()也可以接受参数每次看到js中的call方法,都是懵逼的要去查查百度,自己研究记录下1.call()可以调用某一函数 testCall() {let person {fullName: function () {console.…

webpack基本使用

1、内置模块path (1)path模块用于对路径和文件进行处理,提供了很多好用的方法。 (2)我们知道在Mac OS、Linux和window上的路径时不一样的 window上会使用 \或者 \\ 来作为文件路径的分隔符,当然目前也支…

SpringBoot+VUE前后端分离项目学习笔记 - 【17 SpringBoot文件上传下载功能 MD5实现文件唯一标识】

Sql 数据库新建sys_file用来保存上传文件信息 CREATE TABLE sys_file (id int(11) NOT NULL AUTO_INCREMENT COMMENT id,name varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 文件名称,type varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 文…

STM32——I2C通信

文章目录I2C通信使用I2C通信的硬件设备硬件电路I2C时序基本单元起始与终止发送接收发送应答与接收应答I2C时序指定地址写当前地址读指定地址读连续读与写MPU6050简介MPU6050参数硬件电路MPU6050框图系统时钟MPU6050的中断源寄存器映像软件I2C读写MPU6050电路设计关键代码I2C通信…

C语言-扫雷

文章目录完整扫雷1. 说明2. 思路3. 各个功能实现3.1 雷盘初始化与打印1)雷盘定义2) 随机布置雷3.2 玩家排查雷1) 获取坐标周围雷数2) 递归展开3)胜负判断3) 显示雷位置4. 游戏试玩5. 游戏完整代码game.htes…

【定时任务】---- xxl-job、@Scheduled

一、Scheduled注解实现的定时任务 要实现计划任务,首先通过在配置类注解EnableScheduling来开启对计划任务的支持,然后在要执行计划任务的方法上注解Scheduled,声明这是一个计划任务。 在Spring Boot 的入口类 XXXApplication 中,必然会有S…

东南大学洪伟教授评述:毫米波与太赫兹技术

今日推荐文章作者为东南大学毫米波国家重点实验室主任、IEEE Fellow 著名毫米波专家洪伟教授,本文选自《毫米波与太赫兹技术》,发表于《中国科学: 信息科学》2016 年第46卷第8 期——《信息科学与技术若干前沿问题评述专刊》。 本文概要介绍了毫米波与太…

CSS知识点精学6-精灵图、背景图片大小、文字阴影、盒子阴影、过渡

目录 一.精灵图 1.精灵图的介绍 2.精灵图的使用步骤 二.背景图片大小 三.文字阴影 四.盒子阴影 五.过渡 一.精灵图 1.精灵图的介绍 场景:项目中将多张小图片,合成一张大图片,这张图片称之为精灵图 优点:减少服务器发送次…

clickhouse入门学习以及数据迁移

本文主要介绍如何入门clickhouse,以及将mariadb数据迁移过来,最后介绍当前几种的训练的示例数据库集。1、中文教程:中文教程:中文教程有了教程,需要有数据可以训练,教程提供示例数据集,但是数据…

Java基础之《netty(22)—Protobuf》

一、Protobuf基本介绍 1、Protobuf是Google发布的开源项目,全称Google Protobuf Buffers,是一种轻便高效的结构化数据存储格式,可以用于结构化数据串行化,或者说序列化。它很适合做数据存储或RPC数据交换格式。 2、参考文档 htt…