基于SA-BP模拟退火算法优化BP神经网络实现数据预测Python实现

news2024/12/25 0:26:10

       在数据分析和机器学习领域,时间序列预测和多输入单输出系统的预测是重要且复杂的问题。传统的BP(反向传播)神经网络虽然具有强大的非线性函数逼近能力,但在处理这些问题时容易陷入局部极小值、训练速度慢以及过拟合等问题。为了克服这些不足,我们引入了SA-BP(模拟退火算法优化BP神经网络)算法。本文将详细介绍SA-BP算法的原理、步骤,并结合Python程序进行实例分析。

一、SA-BP算法概述

      1.SA模拟退火算法

       SA模拟退火算法(Simulated Annealing, SA)是一种基于概率的全局优化算法,其灵感来源于物理学中固体物质的退火过程。模拟退火算法通过模拟固体在退火过程中粒子运动逐渐趋于有序并达到能量最低状态的机制,来解决数学优化问题。模拟退火算法的核心思想是在解空间中随机寻找目标函数的全局最优解。它从某一较高初温出发,伴随温度参数的不断下降,结合概率突跳特性在求解空间中随机寻找目标函数的全局最优解。即在局部最优解能概率性地跳出并最终趋于全局最优。

      2.BP神经网络(BP)

       BP神经网络是一种具有三层或三层以上的多层神经网络,包括输入层、隐含层和输出层。每一层都由若干个神经元组成,神经元之间通过加权和的方式传递信号,并经过激活函数进行非线性变换。BP神经网络的训练过程包括前向传播和反向传播两个阶段。在前向传播阶段,输入信号从输入层逐层传递到输出层;在反向传播阶段,根据输出误差调整各层之间的连接权重,使误差逐步减小。

      3.SA算法

       SA算法流程如下:

     (1)初始化:选择一个初始解作为当前解,同时设置一个初始温度和终止温度。温度表示搜索的“随机性”,在开始时较高,逐渐减小。

     (2)迭代搜索:在当前解的邻域中随机生成一个新解。计算新解与当前解的目标函数差(或成本差)ΔE。如果ΔE小于0,接受新解作为当前解。如果ΔE大于0,以一定概率接受新解。这个概率通常与温度参数和新旧解之间的成本差有关,遵循Metropolis准则。

     (3)温度更新:在每一次迭代后,根据一定的更新策略降低温度。常见的策略有线性降温和指数降温。

     (4)终止条件:当达到某个预定的迭代次数、温度降至某一阈值以下或在一定时间内未找到更好的解时,算法终止。

二、实验步骤

      SA-BP神经网络回归预测步骤:

      1.数据清洗:去除缺失值和异常值。

      2.特征选择:根据相关性分析选择对预测结果影响显著的特征。

      3.数据归一化:将特征值缩放到同一量纲,提高训练效率。

      4.确定BP神经网络结构:首先,根据问题的需求确定BP神经网络的输入层、隐藏层和输出层的节点数,以及隐藏层的层数。

      5.初始化BP神经网络参数:随机初始化BP神经网络的权重和偏置。这些参数将作为SA优化过程中的搜索变量。

      6.定义适应度函数:使用训练数据集训练BP神经网络,并计算网络输出与实际输出之间的误差(如均方误差MSE)作为适应度函数。适应度值越小,表示神经网络的预测性能越好。

      7.模拟退火优化:在一定的温度范围内,随机调整BP神经网络的权重和偏置。计算新解的误差,并根据接受准则判断是否接受新解。如果接受新解,则更新当前解,并降低温度;如果不接受,则保持当前解不变,继续下一轮迭代。

      8.迭代:重复执行适应度评估、分类和位置更新的过程,直到达到最大迭代次数或满足其他停止条件。

      9.输出最优BP神经网络:在SSA优化过程结束后,选择适应度值最小的麻雀(即最优的BP神经网络权重和偏置)作为最终的网络参数。

     10.测试与评估:使用测试数据集评估优化后的BP神经网络的预测性能,并与其他优化算法进行比较。

代码部分

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import random
import torch.optim as optim
import matplotlib
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False

# 导入数据
data = pd.read_csv('数据集.csv').values

# 划分训练集和测试集
np.random.seed(0)
temp = np.random.permutation(len(data))

P_train = data[temp[:80], :7]
T_train = data[temp[:80], 7]
P_test = data[temp[80:], :7]
T_test = data[temp[80:], 7]

# 数据归一化
scaler_input = MinMaxScaler(feature_range=(0, 1))
scaler_output = MinMaxScaler(feature_range=(0, 1))

p_train = scaler_input.fit_transform(P_train)
p_test = scaler_input.transform(P_test)

t_train = scaler_output.fit_transform(T_train.reshape(-1, 1)).ravel()
t_test = scaler_output.transform(T_test.reshape(-1, 1)).ravel()

# 转换为 PyTorch 张量
p_train = torch.tensor(p_train, dtype=torch.float32).to(device)
t_train = torch.tensor(t_train, dtype=torch.float32).view(-1, 1).to(device)
p_test = torch.tensor(p_test, dtype=torch.float32).to(device)
t_test = torch.tensor(t_test, dtype=torch.float32).view(-1, 1).to(device)

# 定义神经网络
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = NeuralNet(7, 10, 1).to(device)
criterion = nn.MSELoss()

四、实验与结果

      1.数据准备

       为了验证SA优化BP神经网络的有效性,本文采用如下数据集进行实验。下面所示本次采用的数据集(部分)。

      2.结果分析

       实验结果表明,采用基于SA优化BP神经网络的预测模型与传统BP神经网络模型进行对比分析。实验结果表明,SA-BP模型在准确率、鲁棒性和收敛速度方面均优于传统BP神经网络模型。

       (1) 适应度曲线图

       (2) 训练集预测值和真实值对比结果    

      (3) 测试集预测值和真实值对比结果   

      (4) 训练集线性回归图   

     (5) 测试集线性回归图    

     (6) 其他性能计算和新数据预测   

五、结论

       SA-BP算法通过结合模拟退火算法的全局搜索能力和BP神经网络的逼近能力,有效提高了BP神经网络的预测精度和鲁棒性。在实际应用中,SA-BP算法可以应用于时间序列预测、多输入单输出系统的预测等多个领域,为数据分析和机器学习提供强有力的支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2107845.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【读书笔记-《30天自制操作系统》-15】Day16

本篇内容继续多任务的讲解。上一篇中实现了两个任务之间的自动切换,但还不够通用,这里将其优化为多个任务之间的切换。接着引入了任务休眠的概念与休眠的程序实现。最后介绍了任务的优先级,一种用切换时间的长短来衡量,一种用Task…

【Qt】文件对话框QFileDialog

文件对话框QFileDialog ⽂件对话框⽤于应⽤程序中需要打开⼀个外部⽂件或需要将当前内容存储到指定的外部⽂件。 通过QFileDialog 可以选择一个文件,能够获取到这个文件的路径,打开文件/保存文件。 常⽤⽅法介绍: 1、打开⽂件(⼀…

【高中生讲机器学习】17. 讲人话的主成分分析,它来了!(上篇)

创建时间:2024-08-13 首发时间:2024-09-05 最后编辑时间:2024-09-05 作者:Geeker_LStar 你好呀~这里是 Geeker_LStar 的人工智能学习专栏,很高兴遇见你~ 我是 Geeker_LStar,一名准高一学生,热爱…

Redis 集群高可用详解及配置

关型数据库 关系型数据库: 是建立在关系模型基础上的数据库,其借助于集合代数等数学概念和方法来处理数据库中的数据 主流的 MySQL、Oracle、MS SQL Server 和 DB2 都属于这类传统数据库 关型数据库的优缺点 特点: 1、数据关系模型基于关系…

Redis使用——Redis的redis.conf配置注释详解(三)

Redis使用——Redis的redis.conf配置注释详解(三) 背景 日常我们开发时,我们会遇到各种各样的奇奇怪怪的问题(踩坑o(╯□╰)o),这个常见问题系列就是我日常遇到的一些问题的记录文章系列,这里整…

鸿蒙轻内核M核源码分析系列四 中断Hwi

往期知识点记录: 鸿蒙(HarmonyOS)应用层开发(北向)知识点汇总 持续更新中…… 在鸿蒙轻内核源码分析系列前几篇文章中,剖析了重要的数据结构。本文,我们讲述一下中断,会给读者介绍中…

Ubuntu固定USB串口名(包括1拖N的USB串口)

在运行Ubuntu系统的开发板上,如果使用可插拔的USB串口,有时候程序正在运行时,如果突然连接传感器的USB串口设备被插拔了一下,这时,会发现系统中的USB串口名发生了改变。例如,插拔之前是/dev/ttyUSB0,插拔之后变成了/dev/ttyUSB3。发生这种情况的时候,有时候会导致程序无…

Windows I/O系统

硬件存储体系 寄存器 处理器内部定义的存储体,它们除了存储功能,往往还兼有其他的能力,比如参与运算,地址解析,指示处理器的状态,等等。寄存器是由处理器内部专门的触发器电路实现的,处理器往…

jupyter里怎么设置代理下载模型

使用如下方式: %env http_proxyhttp://10.110.146.100:7890 %env https_proxyhttp://10.110.146.100:7890

【SLAM】GNSS的定义,信号原理以及RTK在多传感器融合中的使用方法

【SLAM】GNSS的定义,信号原理以及在多传感器融合中的使用方法 1. GNSS的定义2. GNSS信号原理3. RTK - Real Time Kinematic4。 如何使用RTK做融合和优化 1. GNSS的定义 GPS(Global Positioning System)和GNSS(Global Navigation …

Ubuntu22.04安装colmap

首先上这里查看自己电脑GPU的CMAKE_CUDA_ARCHITECTURES 终端输入以下内容安装预先的前置依赖 sudo apt-get install \git cmake ninja-build build-essential \libboost-program-options-dev libboost-filesystem-dev \libboost-graph-dev libboost-system-dev libboost-tes…

【操作系统存储篇】操作系统的设备管理

目录 一、广义的IO设备 分类 按使用特性分类 按信息交换的单位分类 按设备的共享属性分类 按传输速率分类 二、IO设备的缓冲区 三、SPOOLing技术 一、广义的IO设备 输入设备:对CPU而言,凡是对CPU进行数据输入的。 输出设备:对CPU而…

深度解析:基于离线开发的数据仓库转型落地案例

在当今这个数据驱动的时代,各行各业都正经历着前所未有的变革。伴随技术的飞速发展,数据仓库作为企业数据管理与分析的核心,如何更好地发挥作用,助力企业保持业务的敏捷性与成本效益,成为大家关心的焦点问题。本文将通…

vue使用html2Canvas导出图片 input文字向上偏移

vue使用html2Canvas导出图片 input文字向上偏移 图中 用的是element的输入框 行高 32px,经常测试 你使用原生的input 还是会出现偏移。 解决方法:修改css样式 1.怎么实现导出 网上随便找很多 2.在第一步 获取你要导出的元素id 克隆后 修改他的样式或者 你直接在你需…

web渗透:SSRF漏洞

SSRF漏洞的原理 SSRF(Server-Side Request Forgery,服务器端请求伪造)是一种安全漏洞,它允许攻击者构造请求,由服务端发起,从而访问服务端无法直接访问的内部或外部资源。这种漏洞通常发生在应用程序允许用…

v$session_longops监控 PDB clone 进度

How to Monitor PDB Clone / Move On Create Pluggable Database with COPY Clause Statement Execution (Doc ID 2866302.1)​编辑To Bottom In this Document Goal Solution References APPLIES TO: Oracle Database - Enterprise Edition - Version 19.14.1.0.0 and later…

leetcode:908. 最小差值 I(python3解法)

难度&#xff1a;简单 给你一个整数数组 nums&#xff0c;和一个整数 k 。 在一个操作中&#xff0c;您可以选择 0 < i < nums.length 的任何索引 i 。将 nums[i] 改为 nums[i] x &#xff0c;其中 x 是一个范围为 [-k, k] 的整数。对于每个索引 i &#xff0c;最多 只能…

【赛题已出】2024数学建模国赛A-E题已发布

2024年高教社杯全国大学生数学建模各题赛题已发布&#xff01; A题 B题 C题 D题 E题

Linux开源监控工具netdata

Netdata 是一个免费、开源、实时、专业的服务器监控工具&#xff0c;它以可视化的形式实时展现监控主机的性能变化&#xff0c;提供了一个交互式 Web 界面来查看您的服务器指标。它可以帮助我们了解监控主机的系统或应用程序中正在发生的事情以及刚刚发生的事情&#xff0c;并且…

macos系统内置php文件列表 系统自带php卸载方法

在macos系统中, 自带已经安装了php, 根据不同的macos版本php的版本号可能不同, 我们可以通过 which php 命令来查看mac自带的默认php安装路径, 不过注意这个只是php的执行文件路径. 系统自带php文件列表 一下就是macos默认安装的php文件列表. macos 10.15内置PHP文件列表配置…