神经网络解决回归问题(更新ing)

news2024/12/26 0:19:17

神经网络应用于回归问题

神经网络是处理回归问题的强大工具,它们能够学习输入数据和输出之间的复杂关系。

神经网络提供了一种灵活且强大的框架,用于建模和预测回归问题。通过 适当的 网络结构训练策略正则化技术,可以有效地从数据中学习并做出准确的预测。

在实际应用中,选择合适的网络架构参数对于构建一个高效的回归模型至关重要

所以说,虽然神经网络是处理回归问题的强大工具,但是也存在很多问题,需要我们掌握很多方法技巧才能建立一个高效准确的回归模型:

  • 正则化(Regularization): 为了防止过拟合,可以在损失函数中添加正则化项,如L1或L2正则化。
  • Dropout: 这是一种技术,可以在训练过程中随机地丢弃一些神经元的激活,以减少模型对特定神经元的依赖。
  • 批量归一化(Batch Normalization): 通过对每一层的输入进行归一化处理,可以加速训练过程并提高模型的稳定性。
  • 早停(Early Stopping): 当验证集上的性能不再提升时,停止训练以避免过拟合。
  • 超参数调整(Hyperparameter Tuning): 通过调整网络结构(如层数每层的神经元数量)和学习率等超参数,可以优化模型的性能。

生成数据集:

输入数据:
X 1 = 100 × N ( 1 , 1 ) X_{1} = 100 \times \mathcal{N}(1, 1) X1=100×N(1,1)
X 2 = N ( 1 , 1 ) 10 X_{2} = \frac{\mathcal{N}(1, 1) }{10} X2=10N(1,1)
X 3 = 10000 × N ( 1 , 1 ) X_{3} = 10000 \times \mathcal{N}(1, 1) X3=10000×N(1,1)
输出数据 Y Y Y Y 1 Y_1 Y1:
Y = 6 X 1 − 3 X 2 + X 3 2 + ϵ Y = 6X_{1} - 3X_2 + X_3^2 + \epsilon Y=6X13X2+X32+ϵ

Y 1 = X 1 ⋅ X 2 − X 1 X 3 + X 3 X 2 + ϵ 1 Y_1 = X_1 \cdot X_2 - \frac{X_1}{X_3} + \frac{X_3}{X_2} + \epsilon_1 Y1=X1X2X3X1+X2X3+ϵ1
其中, ϵ 1 \epsilon_1 ϵ1 是均值为0,方差为0.1的正态分布噪声。

请注意,这里的 N ( μ , σ 2 ) {N}(\mu, \sigma^2) N(μ,σ2) 表示均值为 μ \mu μ ,方差为 σ 2 \sigma^2 σ2的正态分布。

下面是生成数据集的代码:

# 生成测试数据
import numpy as np
import pandas as pd
# 训练集和验证集样本总个数
sample = 2000
train_data_path = 'train.csv'
validate_data_path = 'validate.csv'
predict_data_path = 'test.csv'

# 构造生成数据的模型
X1 = np.zeros((sample, 1))
X1[:, 0] = np.random.normal(1, 1, sample) * 100
X2 = np.zeros((sample, 1))
X2[:, 0] = np.random.normal(2, 1, sample) / 10
X3 = np.zeros((sample, 1))
X3[:, 0] = np.random.normal(3, 1, sample) * 10000

# 模型
Y = 6 * X1 - 3 * X2 + X3 * X3 + np.random.normal(0, 0.1, [sample, 1])
Y1 = X1 * X2 - X1 / X3 + X3 / X2 + np.random.normal(0, 0.1, [sample, 1])

# 将所有生成的数据放到data里面
data = np.zeros((sample, 5))
data[:, 0] = X1[:, 0]
data[:, 1] = X2[:, 0]
data[:, 2] = X3[:, 0]
data[:, 3] = Y[:, 0]
data[:, 4] = Y1[:, 0]

# 将data分成测试集和训练集
num_traindata = int(0.8*sample)

# 将训练数据保存
traindata = pd.DataFrame(data[0:num_traindata, :], columns=['x1', 'x2', 'x3', 'y', 'y1'])
traindata.to_csv(train_data_path, index=False)
print('训练数据保存在: ', train_data_path)

# 将验证数据保存
validate_data = pd.DataFrame(data[num_traindata:, :], columns=['x1', 'x2', 'x3', 'y', 'y1'])
validate_data.to_csv(validate_data_path, index=False)
print('验证数据保存在: ', validate_data_path)

# 将预测数据保存
predict_data = pd.DataFrame(data[num_traindata:, 0:-2], columns=['x1', 'x2', 'x3'])
predict_data.to_csv(predict_data_path, index=False)
print('预测数据保存在: ', predict_data_path)

通用神经网络拟合函数

要根据生成的数据集建立回归模型应该如何实现呢?对于这样包含非线性的方程,直接应用通用的神经网络模型可能效果并不好,就像这样:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

class FNN(nn.Module):
    def __init__(self,Arc,func,device):
        super(FNN, self).__init__()  # 调用父类的构造函数
        self.func = func # 定义激活函数
        self.Arc = Arc # 定义网络架构
        self.device = device
        self.model = self.create_model().to(self.device)
        # print(self.model)

    def create_model(self):
        layers = []
        for ii in range(len(self.Arc) - 2):  # 遍历除最后一层外的所有层
            layers.append(nn.Linear(self.Arc[ii], self.Arc[ii + 1], bias=True))
            layers.append(self.func)  # 添加激活函数
            if ii < len(self.Arc) - 3:  # 如果不是倒数第二层,添加 Dropout 层
                layers.append(nn.Dropout(p=0.1))
        layers.append(nn.Linear(self.Arc[-2], self.Arc[-1], bias=True))  # 添加最后一层
        return nn.Sequential(*layers)

    def forward(self,x):
        out = self.model(x)
        return out

if __name__ == "__main__":
    # 定义网络架构和激活函数
    Arc = [3, 10, 20, 20, 20, 10, 2]
    func = nn.ReLU()  # 选择ReLU激活函数
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 根据是否有GPU来选择设备

    # 创建FNN模型实例
    model = FNN(Arc, func, device)
    # 定义损失函数和优化器
    criterion = nn.MSELoss()  # 均方误差损失函数
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用Adam优化器
    # 训练数据
    train_data_path = 'train.csv'
    train_data = pd.read_csv(train_data_path)
    features = np.array(train_data.iloc[:, :-2])

    labels = np.array(train_data.iloc[:, -2:])
    #转换成张量
    inputs_tensor = torch.from_numpy(features).float().to(device)  # 转换为浮点张量
    labels_tensor = torch.from_numpy(labels).float().to(device)  # 如果标签是数值型数

    loss_history = []
    # 训练模型
    for epoch in range(20000):
        optimizer.zero_grad()  # 清空之前的梯度
        outputs = model(inputs_tensor)  # 前向传播
        loss = criterion(outputs, labels_tensor)  # 计算损失
        loss_history.append(loss.item())  # 将损失值保存在列表中
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重
        if epoch % 1000 == 0:
            print('epoch is', epoch, 'loss is', loss.item(), )

    import matplotlib.pyplot as plt
    loss_history = np.array(loss_history)
    plt.plot(loss_history)
    plt.xlabel = ('epoch')
    plt.ylabel = ('loss')
    plt.show()

    torch.save(model, 'model\entire_model.pth')

应用这个代码得到的损失随迭代次数变化曲线如图:
在这里插入图片描述
这损失值也太大了!!!
那么应该如何修改神经网络模型使其损失函数降低呢?

————————————————

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1581408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【CPA考试】2024注册会计师报名照片尺寸要求解读及手机拍照方法

随着2024年注册会计师考试的临近&#xff0c;众多会计专业人士和学生都开始准备报名参加这一行业的重要考试&#xff0c;报名时间为4月8日至4月30日。报名过程中&#xff0c;一张符合要求的证件照是必不可少的。本文将为您详细解读2024年注册会计师考试报名照片的尺寸要求&…

Pytorch导出FP16 ONNX模型

一般Pytorch导出ONNX时默认都是用的FP32&#xff0c;但有时需要导出FP16的ONNX模型&#xff0c;这样在部署时能够方便的将计算以及IO改成FP16&#xff0c;并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单&#xff0c;一般情况下只需要在导出FP32 ONNX的基础上调用下…

LINUX系统触摸工业显示器芯片应用方案--Model4(简称M4芯片)

背景介绍&#xff1a; 触摸工业显示器传统的还是以WINDOWS为主&#xff0c;但近年来&#xff0c;安卓紧随其后&#xff0c;但一直市场应用情况不够理想&#xff0c;反而是LINUX系统的触摸工业显示器大受追捧呢&#xff1f; 触摸工业显示器传统是以Windows系统为主&#xff0c…

微信小程序用户登录授权指定(旧版本)

配置旧版本基础库2.12.3 实现效果 点击登录按钮即可直接登录&#xff0c;获取用户昵称和头像 点击获取头像昵称按钮则需要授权&#xff0c;才能成功登录 代码实现 my.xml <!-- 登录页面,调试基础库为2.20.2库 --> <view class"mylogin"><block w…

权威报道 | 百分点科技:《突发事件应急预案管理办法》解读

近日&#xff0c;百分点科技CTO刘译璟作为唯一企业界代表&#xff0c;接受应急领域权威期刊——《中国应急管理》杂志邀请&#xff0c;与中国安全生产科学研究院、中央党校、中国政法大学等单位的专家一起&#xff0c;就《突发事件应急预案管理办法》&#xff08;以下简称《办法…

三支冲突分析介绍

Pawlak最早通过观察一组智能体对一组问题的意见&#xff0c;提出了冲突分析模型。U表示对象集&#xff0c;V表示属性集&#xff0c;R表示对象集和属性集之间的二元关系&#xff0c;这样一个刻画冲突分析的信息系统通过三元组&#xff08;U&#xff0c;V&#xff0c;R&#xff0…

hive管理之ctl方式

hive管理之ctl方式 hivehive --service clictl命令行的命令 #清屏 Ctrl L #或者 &#xff01; clear #查看数据仓库中的表 show tabls; #查看数据仓库中的内置函数 show functions;#查看表的结构 desc表名 #查看hdfs上的文件 dfs -ls 目录 #执行操作系统的命令 &#xff01;命令…

如何用西门子PLC手工做一个电闸监控控制系统

//S-H4CK13Maptnh// 项目地址:https://github.com/MartinxMax/S7-200_Power_monitoring 欢迎三连支持…电闸监控与控制 注意 注意安全,220V!!! 应用场景 一些工业自动化企业中需要对电闸的监控与控制。 原理图 组装 上面一行检测跳闸情况 下面一行控制当前一路电源 控制…

小剧场短剧剧集收费短剧小程序APP

1. 内容展现 付费、免费、任务解锁&#xff1a;用户可以通过付费直接观看短剧&#xff0c;也可以通过完成平台任务&#xff08;如签到、分享等&#xff09;获得免费观看的机会。这种灵活的解锁方式既满足了用户的多种需求&#xff0c;也促进了平台的活跃度。主流展现形式&…

InsectMamba:基于状态空间模型的害虫分类

InsectMamba&#xff1a;基于状态空间模型的害虫分类 摘要IntroductionRelated WorkImage ClassificationInsect Pest Classification PreliminariesInsectMambaOverall Architecture InsectMamba: Insect Pest Classification with State Space Model 摘要 害虫分类是农业技术…

【黑马头条】-day07APP端文章搜索-ES-mongoDB

文章目录 今日内容1 搭建es环境1.1 拉取es镜像1.2 创建容器1.3 配置中文分词器ik1.4 测试 2 app文章搜索2.1 需求说明2.2 思路分析2.3 创建索引和映射2.3.1 PUT请求添加映射2.3.2 其他操作 2.4 初始化索引库数据2.4.1 导入es-init2.4.2 es-init配置2.4.3 导入数据2.4.4 查询已导…

Docker容器嵌入式开发:Docker Ubuntu18.04配置mysql数据库

在 Ubuntu 18.04 操作系统中安装 MySQL 数据库的过程。下面是安装过程的详细描述: 首先,使用以下命令安装 MySQL 服务器: sudo apt install mysql-server系统会提示是否继续安装,按下 Y 键确认。 安装过程中,系统会下载并安装 MySQL 相关的软件包,包括 libaio1、mysql…

ChromeDriver / Selenium-server

一、简介 ChromeDriver 是一个 WebDriver 的实现&#xff0c;专门用于自动化控制 Google Chrome 浏览器。以下是关于 ChromeDriver 的详细说明&#xff1a; 定义与作用&#xff1a; ChromeDriver 是一个独立的服务器程序&#xff0c;作为客户端库与 Google Chrome 浏览…

STM32H7通用定时器计数功能的使用

目录 概述 1 STM32定时器介绍 1.1 认识通用定时器 1.2 通用定时器的特征 1.3 递增计数模式 1.4 时钟选择 2 STM32Cube配置定时器时钟 2.1 配置定时器参数 2.2 配置定时器时钟 3 STM32H7定时器使用 3.1 认识定时器的数据结构 3.2 计数功能实现 4 测试案例 4.1 代码…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之七 简单指定视频某片段快放效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之七 简单指定视频某片段快放效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之七 简单指定视频某片段快放效果 一、简单介绍 二、简单指定视频某片段快放效果实现原理…

vue3使用jsQR解析二维码

1.了解jsQR jsQR是一个纯javascript脚本实现的二维码识别库&#xff0c;不仅可以在浏览器端使用&#xff0c;而且支持后端node.js环境。jsQR使用较为简单&#xff0c;有着不错的识别率。 2.效果图 3.二维码 4.下载jsqr包 npm i -d jsqr5.代码 <script setup> import …

STM32F407+FreeRTOS+LWIP UDP组播

开发环境介绍&#xff1a; MCU&#xff1a;STM32F407ZET6 网卡&#xff1a;LAN8720A LWIP版本&#xff1a;V1.1.0 FreeRTOS 版本&#xff1a;V10.2.1 LAN8720A硬件原理图&#xff1a; 硬件连接说明&#xff1a; MII_RX_CLK/RMII_REF_CLK ------>PA1 …

[lesson15]类与封装的概念

类与封装的概念 类的封装 类通常分为以下两个部分 类的实现细节类的使用方式 当使用类时&#xff0c;不需要关心其实现细节 当创建类时&#xff0c;才需要考虑其内部实现细节 封装的基本概念 根据经验&#xff1a;并不是类的每个属性都是对外公开的 如&#xff1a;女孩子不…

hive-3.1.2分布式搭建与hive的三种交互方式

hive-3.1.2分布式搭建&#xff1a; 一、上传解压配置环境变量 在官网或者镜像站下载驱动包 华为云镜像站地址&#xff1a; hive&#xff1a;Index of apache-local/hive/hive-3.1.2 mysql驱动包&#xff1a;Index of mysql-local/Downloads/Connector-J # 1、解压 tar -zx…

gpt科普1 GPT与搜索引擎的对比

GPT&#xff08;Generative Pre-trained Transformer&#xff09;是一种基于Transformer架构的自然语言处理模型。它通过大规模的无监督学习来预训练模型&#xff0c;在完成这个阶段后&#xff0c;可以用于各种NLP任务&#xff0c;如文本生成、机器翻译、文本分类等。 以下是关…