基于神经网络的弹弹堂类游戏弹道快速预测

news2024/12/26 20:08:29

目录

一、 目的... 1

1.1 输入与输出.... 1

1.2 隐网络架构设计.... 1

1.3 激活函数与损失函数.... 1

二、 训练... 2

2.1 数据加载与预处理.... 2

2.2 训练过程.... 2

2.3 训练参数与设置.... 2

三、 测试与分析... 2

3.1     性能对比.... 2

3.2     训练过程差异.... 3

四、 训练过程中的损失变化... 3

五、 代码... 6


一、目的

在机器学习中,神经网络是解决回归和分类问题的强大工具。本文通过对比全连接神经网络(SimpleNN)在不同激活函数下的表现,探索不同激活函数对模型训练过程和最终性能的影响。本实验通过使用PyTorch框架,首先使用ReLU激活函数,之后将激活函数切换为tanh,分析这两种激活函数在回归问题中的差异。

1.1 输入与输出

本实验中的神经网络模型输入的是来自MATLAB文件(data.mat)的数据集,其中包括4个输入特征和1个输出标签。数据通过标准化处理后输入神经网络,网络模型通过学习特征和标签之间的关系来预测输出。最终网络输出为一个连续值,即回归问题中的预测值。

1.2 隐网络架构设计

SimpleNN模型:
本实验使用了一个简单的前馈神经网络模型,包含一个输入层、一个隐藏层和一个输出层。输入层的节点数与特征数量相同,输出层的节点数与标签数量相同。隐藏层的节点数设置为10。激活函数用于隐藏层的神经元,以增加模型的非线性表达能力。

在此实验中,我们首先使用了ReLU激活函数进行训练,然后将激活函数替换为tanh进行对比分析。

1.3 激活函数与损失函数

  1. 激活函数选择
  • ReLU(Rectified Linear Unit):
    是一种常用的激活函数,其输出为正输入或零。ReLU有助于缓解梯度消失问题,并加速神经网络的训练。

  • tanh(双曲正切函数):
    是一种平滑的非线性激活函数,其输出范围为-1到1。与ReLU相比,tanh的输出范围较小,并且存在梯度消失的风险,但它能够处理负值输入,适用于某些回归任务。

  1. 损失函数选择
    本实验使用均方误差(MSE)作为损失函数,用于回归任务中度量模型预测与真实输出之间的差异。

二、训练

2.1 数据加载与预处理

数据集来自MATLAB的.mat文件。输入特征(4个)和输出标签(1个)首先被提取,并通过MinMaxScaler进行归一化处理。数据集被随机分割为训练集和测试集,其中50个样本用于测试,剩余的用于训练。

2.2 训练过程

网络通过3000次迭代进行训练。在每一次迭代中,模型使用训练数据进行前向传播,计算预测结果与真实标签之间的损失。然后进行反向传播,更新网络的参数。训练的停止条件为损失低于设定阈值(1e-14)。

2.3 训练参数与设置

训练过程中使用的主要参数如下:

  • 学习率: 0.001
  • 训练轮次: 最大3000次,或提前停止
  • 损失函数: 均方误差(MSE)损失函数
  • 优化器: Adam

三、测试与分析

3.1 性能对比

  • 使用ReLU激活函数时:
    在训练过程中,模型的损失函数逐渐下降,表现出良好的学习效果。最终损失值趋近于0,表明网络能够较好地拟合训练数据。测试时,模型能够有效地预测测试集的数据,偏差较小。

  • 使用tanh激活函数时:
    与ReLU相比,使用tanh激活函数时,损失下降的速度较慢,且网络训练的初期出现较大的波动。这可能与tanh的输出范围(-1到1)有关,导致梯度消失问题,尤其是在多层网络中。

3.2 训练过程差异

  1. 收敛速度
  • ReLU: 在训练初期收敛较快,且表现出较好的梯度更新能力。在训练过程中,模型的准确性和损失函数下降速度较为平稳。

  • tanh: 收敛速度较慢,且在训练初期存在较大的梯度波动。由于其在负输入下的饱和特性,可能导致梯度更新较慢,尤其是在深层网络中。

  1. 偏差分析
  • 使用ReLU时: 偏差较小,模型预测与实际值之间的差异较少,说明模型具有较好的预测能力。

  • 使用tanh时: 偏差稍大,尤其是在某些测试样本上。虽然损失函数已经较低,但由于tanh的输出范围限制,模型在某些输入上可能无法达到完全准确的预测。


四、训练过程中的损失变化

图 1: ReLU训练损失曲线
图 2: ReLU测试数据集结果图
图 3: tanh训练损失曲线
图 4: tanh测试数据集结果图

1 relu训练损失曲线

2 relu测试数据集结果图

3 tanh训练损失曲线

4 tanh测试数据集结果图

  • 代码

import numpy as np

import torch

import torch.nn as nn

import torch.optim as optim

from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt

import scipy.io

# 加载 .mat 文件(替换为实际的路径)

data = scipy.io.loadmat('E:\\Learn Project\\matlab_pjt\\data.mat')

# 获取数据

data = data['data']

# 输入和输出数据

inputs = data[:, :4# 输入特征

outputs = data[:, 4:]  # 输出标签

# 随机分割数据为训练集和测试集

test_size = 50  # 测试集大小

indices = np.random.permutation(len(inputs))

train_indices = indices[test_size:]

test_indices = indices[:test_size]

input_train = inputs[train_indices]

output_train = outputs[train_indices]

input_test = inputs[test_indices]

output_test = outputs[test_indices]

# 数据归一化

scaler_input = MinMaxScaler()

scaler_output = MinMaxScaler()

input_train_scaled = scaler_input.fit_transform(input_train)

output_train_scaled = scaler_output.fit_transform(output_train)

input_test_scaled = scaler_input.transform(input_test)

output_test_scaled = scaler_output.transform(output_test)

# 转换为 PyTorch 张量

X_train_tensor = torch.tensor(input_train_scaled, dtype=torch.float32)

y_train_tensor = torch.tensor(output_train_scaled, dtype=torch.float32)

X_test_tensor = torch.tensor(input_test_scaled, dtype=torch.float32)

y_test_tensor = torch.tensor(output_test_scaled, dtype=torch.float32)

# 定义简单的神经网络

class SimpleNN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):

        super(SimpleNN, self).__init__()

        self.layer1 = nn.Linear(input_size, hidden_size)

        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):

        x = torch.relu(self.layer1(x))  # 激活函数 ReLU

        x = self.layer2(x)  # 输出层

        return x

# 网络参数

input_size = input_train.shape[1]

hidden_size = 10

output_size = output_train.shape[1]

# 创建模型

model = SimpleNN(input_size, hidden_size, output_size)

# 损失函数和优化器

criterion = nn.MSELoss()  # 均方误差损失

optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型

epochs = 3000

loss_history = []  # 保存损失变化

for epoch in range(epochs):

    optimizer.zero_grad()  # 清空梯度

    output = model(X_train_tensor)  # 前向传播

    loss = criterion(output, y_train_tensor)  # 计算损失

    loss.backward()  # 反向传播

    optimizer.step()  # 更新参数

    # 记录损失

    loss_history.append(loss.item())

    # 停止条件

    if loss.item() < 1e-14:

        print(f"训练提前停止,当前迭代:{epoch}")

        break

# 绘制训练损失图

plt.plot(loss_history)

plt.xlabel('Epoch')

plt.ylabel('Loss (MSE)')

plt.title('Training Loss History')

plt.show()

# 测试模型

with torch.no_grad():

    model.eval()  # 设置模型为评估模式

    y_test_pred_scaled = model(X_test_tensor)  # 预测

    y_test_pred = scaler_output.inverse_transform(y_test_pred_scaled.numpy())  # 反归一化

# 计算每个样本的偏差

deviation = np.sqrt(np.sum((output_test - y_test_pred) ** 2, axis=1))  # 欧几里得距离

# 绘制偏差图

plt.plot(deviation, marker='o', color='red')

plt.xlabel('Sample Index')

plt.ylabel('Deviation')

plt.title('Test Deviation')

plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2254524.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux入门攻坚——40、Linux集群系统入门-lvs(1)

Cluster&#xff0c;集群&#xff0c;为了解决某个特定问题将多台计算机组合起来形成的单个系统。 这个单个集群系统可以扩展&#xff0c;系统扩展的方式&#xff1a;scale up&#xff0c;向上扩展&#xff0c;更换更好的主机&#xff1b;scale out&#xff0c;向外扩展&…

威联通-001 手机相册备份

文章目录 前言1.Qfile Pro2.Qsync Pro总结 前言 威联通有两种数据备份手段&#xff1a;1.Qfile Pro和2.Qsync Pro&#xff0c;实践使用中存在一些区别&#xff0c;针对不同备份环境选择是不同。 1.Qfile Pro 用来备份制定目录内容的。 2.Qsync Pro 主要用来查看和操作文…

Docker单机网络:解锁本地开发环境的无限潜能

作者简介&#xff1a;我是团团儿&#xff0c;是一名专注于云计算领域的专业创作者&#xff0c;感谢大家的关注 座右铭&#xff1a; 云端筑梦&#xff0c;数据为翼&#xff0c;探索无限可能&#xff0c;引领云计算新纪元 个人主页&#xff1a;团儿.-CSDN博客 目录 前言&#…

【Linux操作系统】多线程控制(创建,等待,终止、分离)

目录 一、线程与轻量级进程的关系二、进程创建1.线程创建线程创建函数&#xff08;pthread&#xff09;查看和理解线程id主线程与其他线程之间的关系 三、线程等待&#xff08;回收&#xff09;四、线程退出线程退出情况线程退出方法 五、线程分离线程的优点线程的缺点 一、线程…

解决IDEA的easycode插件生成的mapper.xml文件字段之间逗号丢失

问题 easycode插件生成的mapper.xml文件字段之间逗号丢失&#xff0c;如图 解决办法 将easycode(在settings里面的othersettings)设置里面的Template的mapper.xml.vm和Global Config的mybatisSupport.vm的所有$velocityHasNext换成$foreach.hasNext Template的mapper.xml.vm(…

Android 实现中英文切换

在开发海外项目的时候&#xff0c;需要实现app内部的中英文切换功能&#xff0c;所有的英文都是内置的&#xff0c;整体思路为&#xff1a; 创建一个sp对象&#xff0c;存储当前系统的语言类型&#xff0c;然后在BaseActivity中对语言进行判断&#xff1b; //公共Activitypubl…

11月 | Apache DolphinScheduler月度进展总结

各位热爱 Apache DolphinScheduler 的小伙伴们&#xff0c;社区10月份月报更新啦&#xff01;这里将记录 DolphinScheduler 社区每月的重要更新&#xff0c;欢迎关注&#xff01; 月度Merge之星 感谢以下小伙伴11月份为 Apache DolphinScheduler 所做的精彩贡献&#xff08;排…

[软件开发幼稚指数评比]《软件方法》自测题解析010

第1章自测题 Part2 **9 [**单选题] 以下说法和其他三个最不类似的是: A)如果允许一次走两步&#xff0c;新手也能击败象棋大师 B)百米短跑比赛才10秒钟&#xff0c;不可能为每一秒做周密计划&#xff0c;凭感觉跑就是 C)即使是最好的足球队&#xff0c;也不能保证每…

【JavaWeb后端学习笔记】使用IDEA连接MySQL数据库

IDEA连接MySQL IDEA中集成了DataGrip&#xff0c;因此可以直接使用IDEA操作MySQL数据库。 1.创建一个新的空工程。点击右侧的数据库标志。 2.选择要连接的数据库。第一步&#xff1a;点击“”&#xff1b;第二步&#xff1a;点击 Data Source&#xff1b;第三步&#xff1a;选…

大模型分类2—按训练方式

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl根据训练方式,大模型可分为监督学习、无监督学习、自监督学习和强化学习大模型。 1. 监督学习大模型 1.1 定义与原理 监督学习大模型是一种机器学习范式,它依赖于标记数据集进行训练。这些数据…

鸿蒙特色实战2

服务卡片开发 创建服务卡片 创建一个新的工程后&#xff0c;可以通过如下方法进行创建服务卡片&#xff1a; 创建服务卡片包括如下两种方式&#xff1a; 选择模块&#xff08;如entry模块&#xff09;下的任意文件&#xff0c;单击菜单栏File > New > Service Widget创…

LCD1602液晶显示屏指令详解

文章目录 LCD1602液晶显示屏1.简介2. 液晶引脚说明3. 指令介绍3.1 清屏指令3.2 光标归位指令3.3 进入模式设置指令3.4 显示开关设置指令3.5 设定显示或光标移动方向指令3.6 功能设定指令3.7 设定CGRAM地址指令3.8 设定DDRAM地址指令3.9 读取忙或AC地址指令3.10 总图3.11 DDRAM …

Python毕业设计选题:基于大数据的旅游景区推荐系统_django

开发语言&#xff1a;Python框架&#xff1a;djangoPython版本&#xff1a;python3.7.7数据库&#xff1a;mysql 5.7数据库工具&#xff1a;Navicat11开发软件&#xff1a;PyCharm 系统展示 系统首页界面 用户注册界面 用户登录界面 景点信息界面 景点资讯界面 个人中心界面 …

引领素养教育行业,猿辅导素养课斩获“2024影响力教育品牌”奖项

近日&#xff0c;由教育界网、校长邦联合主办&#xff0c;鲸媒体、职教共创会协办的“第9届榜样教育年度盛典”评奖结果揭晓。据了解&#xff0c;此次评选共有近500家企业提交参评资料进行奖项角逐&#xff0c;历经教育界权威专家、资深教育从业者以及专业评审团队的多轮严格筛…

十七、监控与度量-Prometheus/Grafana/Actuator

文章目录 前言一、Spring Boot Actuator1. 简介2. 添加依赖2. 开启端点3. 暴露端点4. 总结 二、Prometheus1. 简介2. Prometheus客户端3. Prometheus服务端4. 总结 三、Grafana1. 简介2. Grafana安装3. Grafana配置 前言 系统监控‌ 在企业级的应用中&#xff0c;系统监控至关…

PHP语法学习(第六天)

&#x1f4a1;依照惯例&#xff0c;回顾一下昨天讲的内容 PHP语法学习(第五天)主要讲了PHP中的常量和运算符的运用。 &#x1f525; 想要学习更多PHP语法相关内容点击“PHP专栏” 今天给大家讲课的角色是&#x1f34d;菠萝吹雪&#xff0c;“我菠萝吹雪吹的不是雪&#xff0c;而…

关于遥感图像镶嵌后出现斑点情况的解决方案

把几张GF1的影像镶嵌在一起后&#xff0c;结果在Arcgis里出现了明显的斑点情况&#xff08;在ENVI里显示则不会出现&#xff09;&#xff0c;个人觉得可能是斑点噪声问题&#xff0c;遂用Arcgis的滤波工具进行滤波处理&#xff0c;但由于该工具本身没有直接设置对多波段处理方式…

【嵌套查询】.NET开源 ORM 框架 SqlSugar 系列

.NET开源 ORM 框架 SqlSugar 系列 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列【Code First】.NET开源 ORM 框架 SqlSugar 系列【数据事务…

单链表---合并两个链表

将两个升序链表合并为一个新的升序链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 struct ListNode {int val;struct ListNode* next; }; w 方法一---不使用哨兵位 我们创建一个新链表用于合并两个升序链表&#xff0c; 将两个链表中最小的结点依次尾插到…

vue聊天对话语音消息播放动态特效

vue2写法&#xff0c;vue3也能用&#xff0c;粘之即走&#xff1a; 示例&#xff1a; <template><div class"voice-hidden"><divclass"voice-play-chat":class"[className, { animate-stop: !isPlaying }]"><div class&q…