第100+10步 ChatGPT文献复现:ARIMA-ERNN预测百日咳

news2024/11/26 21:20:18

基于WIN10的64位系统演示

一、写在前面

我们来继续这篇文章:

《BMC Public Health》杂志的2022年一篇题目为《ARIMA and ARIMA-ERNN models for prediction of pertussis incidence in mainland China from 2004 to 2021》文章的模拟数据做案例。

这文章做的是用:使用单纯ARIMA模型和ARIMA-ERNN组合模型预测中国大陆百日咳发病率。

文章是用单纯的ARIMA模型作为对照,更新了ARIMA-ERNN模型。本期,我们试一试ARIMA-ERNN组合模型能否展示出更优秀的性能。

数据不是原始数据哈,是我使用GPT-4根据文章的散点图提取出来近似数据,只弄到了2004-2017年的。

二、闲聊和复现:

1ARIMA-ERNN组合模型的构建策略

又涉及到组合模型的构建。不过这个策略跟上一期提到的ARIMA-NARNN组合模型所使用的的建模策略不一样,来看看GPT-4的总结:

不懂大家看懂了没,反正我是有个疑问:

“所以是一个ARIMA的预测值作为输入,然后对应一个实际值作为输出,训练ERNN模型”

GPT-4肯定了我的回答:

是直接用n个ARIMA的预测值作为输入,实际值作为输出,去构建神经网络,有点像ARIMA-GRNN的策略。

所以说,组合模型,玩的就是花。

(2)组合模型构建

那就帮用pytorch写一个代码呗:

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error

# Load data
data_path = 'ERNN.csv'
data = pd.read_csv(data_path, parse_dates=True, index_col=0)

# Ensure column accuracy (second column is actual values, third column is ARIMA predictions)
actuals = data.iloc[:, 0]
predictions = data.iloc[:, 1]

# Use 36 consecutive ARIMA prediction values as input
n = 36
inputs = [predictions.shift(-i) for i in range(n)]
inputs = pd.concat(inputs, axis=1)[:-n]  # Remove the last n rows as they are incomplete
targets = actuals[n:]  # Remove the first n targets to align data

# Split data into training and testing sets, last 12 as test set
train_inputs = inputs[:-12]
train_targets = targets[:-12]
test_inputs = inputs[-12:]
test_targets = targets[-12:]

# Convert to PyTorch tensors
train_inputs = torch.tensor(train_inputs.values).float()
train_targets = torch.tensor(train_targets.values).float().unsqueeze(1)
test_inputs = torch.tensor(test_inputs.values).float()
test_targets = torch.tensor(test_targets.values).float().unsqueeze(1)

# Create DataLoaders
train_dataset = TensorDataset(train_inputs, train_targets)
test_dataset = TensorDataset(test_inputs, test_targets)
batch_size = 12  # One data point at a time
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Define ERNN model
import torch
import torch.nn as nn

# Define ERNN model using ReLU
class ERNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(ERNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, input, hidden):
        if input.size(0) != hidden.size(0):
            hidden = self.initHidden(input.size(0))
        combined = torch.cat((input, hidden), 1)
        hidden = self.tanh(self.i2h(combined))
        output = self.i2o(combined)
        return output, hidden

    def initHidden(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)

# Initialize model
model = ERNN(input_size=n, hidden_size=24, output_size=1)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.004, momentum=0.9)  # 使用带动量的SGD模拟traingdx

# 训练模型
def train_model(model, criterion, optimizer, train_loader, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            current_batch_size = inputs.size(0)  # 获取当前批次的大小
            hidden = model.initHidden(current_batch_size)  # 使用当前批次大小初始化隐藏状态
            outputs, hidden = model(inputs, hidden)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 1000 == 0:  # 每1000次迭代输出一次损失
            print(f'Epoch {epoch}: Loss = {total_loss / len(train_loader)}')

train_model(model, criterion, optimizer, train_loader, epochs=200000)  # 训练10000次迭代

# Evaluate model
def evaluate_model(model, loader):
    model.eval()
    predictions, actuals = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            current_batch_size = inputs.size(0)  # 获取当前批次的大小
            hidden = model.initHidden(current_batch_size)  # 使用当前批次大小初始化隐藏状态
            outputs, hidden = model(inputs, hidden)
            predictions.extend(outputs.detach().numpy())
            actuals.extend(targets.numpy())
    mse = mean_squared_error(actuals, predictions)
    mae = mean_absolute_error(actuals, predictions)
    mape = mean_absolute_percentage_error(actuals, predictions)
    return mse, mae, mape

# Evaluate on both training and testing data
train_mse, train_mae, train_mape = evaluate_model(model, train_loader)
test_mse, test_mae, test_mape = evaluate_model(model, test_loader)

# Print evaluation metrics for both training and testing datasets
print(f"Training MSE: {train_mse}, MAE: {train_mae}, MAPE: {train_mape}")
print(f"Test MSE: {test_mse}, MAE: {test_mae}, MAPE: {test_mape}")

又是可怕的调参,首先看看文章的参数:

“在确定隐藏层神经元数量(N)时,使用了以下经验公式:

N = √(n + m )+ a

其中m是输入层的神经元数量,n是输出层的神经元数量,a是一个常数(范围在1到10之间)。根据这个计算,ERNN的隐藏层有3到12个神经元。我们对ERNN的隐含层使用了Tan-Sigmoid函数,对输出层使用了Purelin函数,训练函数用的是traingdx,网络权重学习函数用的是Learngdm,模型性能评估使用了MSE。网络的参数设置如下:迭代步数为10,000步,学习率为0.01,学习目标(学习误差)为0.004。然后我们使用了一个2-9-1结构的ERNN来预测百日咳的发病率。ARIMA-ERNN模型的MSE为0.00077,优于ARIMA模型的0.00937。”

我试了试,一言难尽。

然后自己调了一下参数,目前最好就这样了,过拟合了:

Training MSE: 1.5360685210907832e-05, MAE: 0.003031895263120532, MAPE: 0.21757353842258453

Test MSE: 0.001023554359562695, MAE: 0.026323571801185608, MAPE: 0.37139490246772766

各位大佬有更好的结果麻烦分享一下参数,跪谢~

本代码可以调的参数有:

输入特征数量 (n):

这个参数决定了每个输入样本中使用的连续ARIMA预测值的数量。您可以尝试使用不同数量的输入特征来看是否能改善模型的表现。

批大小 (batch_size):

这影响到模型训练的速度及收敛性。较大的批大小可以提高内存利用率和训练速度,但可能影响模型的最终性能和稳定性。反之,较小的批大小可能使训练过程更稳定、收敛更精细,但训练速度会减慢。

隐藏层大小 (hidden_size):

这决定了隐藏层中神经元的数量。增加隐藏层的大小可以提高模型的学习能力,但也可能导致过拟合。相应地,减少隐藏层大小可能减少过拟合的风险,但可能限制模型的能力。

学习率 (lr):

学习率是优化算法中最重要的参数之一。如果学习率设置得太高,模型可能无法收敛;如果设置得太低,模型训练可能过于缓慢,甚至在到达最佳点之前停止训练。

动量 (momentum):

动量帮助加速SGD在正确方向上的收敛,还可以减少优化过程中的震荡。调整这个参数可以影响学习过程的平滑性和速度。

激活函数:

模型中使用的是 Tanh 激活函数。可以考虑使用 ReLU 或其变体(如 LeakyReLU, ELU 等)以期改善模型性能,尤其是在处理非线性问题时。

优化器:

虽然使用的是带动量的SGD,但您也可以尝试使用其他优化器,如 Adam、RMSprop 等,这些优化器可能在不同的应用场景中表现更好。

迭代次数 (epochs):

增加或减少训练迭代次数可以影响模型的学习程度。太少的迭代次数可能导致模型欠拟合,而太多的迭代次数可能导致过拟合或不必要的计算资源消耗。

三、后话

又学习了一种潜在的ARIMA组合模型的策略,同样的,把ERNN换成别的模型,就又是一片新天地了。

四、数据

不提供,自行根据下图提取吧

实在没有GPT-4,那就这个:

https://apps.automeris.io/wpd/index.zh_CN.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1796917.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

centos7安装mysql(完整)

官网5.7版本:https://cdn.mysql.com//Downloads/MySQL-5.7/mysql-5.7.29-1.el7.x86_64.rpm-bundle.tar 文件存于百度网盘:链接:https://pan.baidu.com/s/1x0fucIsD36_7agu88Jd2yg 提取码:s4m8 复制这段内容后打开百度网盘手机A…

k8s——Pod容器中的存储方式及PV、PVC

一、Pod容器中的存储方式 需要存储方式前提:容器磁盘上的文件的生命周期是短暂的,这就使得在容器中运行重要应用时会出现一些问题。 首先,当容器崩溃时,kubelet 会重启它,但是容器中的文件将丢失——容器以干净的状态&…

基于Gdb快速上手调试Redis

写在文章开头 近期很多读者有询问有没有什么简单的办法快速上手调试redis,对此,笔者用到了Linux系统中比较易上手的调试工具GDB,本文将基于一个C语言两数交换的例子演示一下这款工具的使用。 Hi,我是 sharkChili ,是个…

Unity DOTS技术(十二) SystemBase修饰及操作

文章目录 一.变量修饰容器二 . Native Container 分配器三.NativeArray的创建及释放四.线程阻塞释放容器五.只读容器六,安全检查开关七.实体操作八.更优的实体操作方式 一.变量修饰容器 在上节中我们讲到多线程操作,为避免对线程的操作导致数据错乱,我们需要为变量进行修饰.于…

Python开发运维:VSCode与Pycharm 部署 Anaconda虚拟环境

目录 一、实验 1.环境 2.Windows 部署 Anaconda 3.Anaconda 使用 4.VSCode 部署 Anaconda虚拟环境 5.Pycharm 部署 Anaconda虚拟环境 6.Windows使用命令窗口版 Jupyter Notebook 7.Anaconda 图形化界面 二、问题 1.VSCode 运行.ipynb代码时报错 2.pip 如何使用国内…

C++开发基础之初探CUDA计算环境搭建

一、前言 项目中有使用到CUDA计算的相关内容。但是在早期CUDA计算环境搭建的过程中,并不是非常顺利,编写此篇文章记录下。对于刚刚开始研究的你可能会有一定的帮助。 二、环境搭建 搭建 CUDA 计算环境涉及到几个关键步骤,包括安装适当的 C…

【C++】 使用CRT 库检测内存泄漏

CRT 库检测内存泄漏 一、CRT 库简介二、CRT 库的使用1、启用内存泄漏检测2、设置应用退出时显示内存泄漏报告3、丰富内存泄漏报告4、演示使用 内存泄漏是 C/C 应用程序中最微妙、最难以发现的 bug,存泄漏是由于之前分配的内存未能正确解除分配而导致的。 最开始的少…

MySQL主从同步优化指南:架构、瓶颈与解决方案

前言 ​ 在现代数据库架构中,MySQL 主从同步是实现高可用性和负载均衡的关键技术。本文将深入探讨主从同步的架构、延迟原因以及优化策略,并提供专业的监控建议。 MySQL 主从同步架构 ​ 主从复制流程: 从库生成两个线程,一个…

如何替换fmod studio的.bank文件内的音效?

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

【Java毕业设计】基于JavaWeb的旅游论坛管理系统

文章目录 摘 要目 录1 概述1.1 研究背景及意义1.2 国内外研究现状1.3 拟研究内容1.4 系统开发技术1.4.1 Java编程语言1.4.2 vue技术1.4.3 MySQL数据库1.4.4 B/S结构1.4.5 Spring Boot框架 2 系统需求分析2.1 可行性分析2.2 系统流程2.2.1 操作流程2.2.2 登录流程2.2.3 删除信息…

【微信小程序开发(从零到一)】——个人中心页面的实战项目(一)

👨‍💻个人主页:开发者-曼亿点 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 曼亿点 原创 👨‍💻 收录于专栏&#xff1a…

(ICLR,2024)HarMA:高效的协同迁移学习与模态对齐遥感技术

文章目录 相关资料摘要引言方法多模态门控适配器目标函数 实验 相关资料 论文:Efficient Remote Sensing with Harmonized Transfer Learning and Modality Alignment 代码:https://github.com/seekerhuang/HarMA 摘要 随着视觉和语言预训练&#xf…

Rhino-Grasshopper:小白从入门开始学习

前言: 小编在这里即将开启一个新系列学习课程,主要内容为基于Rhino的3D打印学习,具体包括Rhino中的Python使用,Grasshopper的功能,讲解视频会陆续更新在B站,希望大家多多支持! 关于相关学习、…

list(二)和_stack_queue

嗨喽大家好,时隔许久阿鑫又给大家带来了新的博客,list的模拟实现(二)以及_stack_queue,下面让我们开始今天的学习吧! list(二)和_stack_queue 1.list的构造函数 2.设计模式之适配器和迭代器 3.新容器de…

HTML静态网页成品作业(HTML+CSS)—— 保护环境环保介绍网页(1个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有1个页面。 二、作品演示 三、代…

公检法部门保密网文件导出,这样做才是真正的安全又便捷

公检法是司法机关的核心组成,也是社会管理的重要组成,公检法部门的业务中涉及大量的居民数据、个人隐私、司法案件等信息,因此,数据的安全性至关重要。 根据我国法律要求,同时基于对数据的保护需要,我国的公…

Vue06-el与data的两种写法

一、el属性 用来指示vue编译器从什么地方开始解析 vue的语法,可以说是一个占位符。 1-1、写法一 1-2、写法二 当不使用el属性的时候: 两种写法都可以。 v.$mount(#root);写法的好处:比较灵活: 二、data的两种写法 2-1、对象式…

discuz点微同城源码34.7+全套插件+小程序前端

discuz点微同城源码34.7全套插件小程序前后端 模板挺好看的 带全套插件 自己耐心点配置一下插件 可以H5可以小程序

重磅就业报告前美股涨势消减,标普暂别纪录高位,英伟达盘中闪崩近6%,欧央行降息预期“退烧”,欧元跳涨

标普纳指创盘中历史新高后转跌,道指三连涨至近两周新高;芯片股指和台积电美股跌落纪录高位,英伟达三日收创历史新高后回落;游戏驿站盘中一度暴拉50%。中概股指回落,财报后蔚来收跌6.8%。欧央行会后,欧元盘中…

Dvws靶场

文章目录 一、XXE外部实体注入二、No-SQL注入三、Insecure Direct Object Reference四、Mass Assignment五、Information Disclosure六、Command Injection七、SQL注入 一、XXE外部实体注入 访问http://192.168.92.6/dvwsuserservice?wsdl,发现一个SOAP服务。在SO…