使用神经网络拟合6项参数

news2024/9/20 9:27:04

使用神经网络拟合6项参数

  • 1. 数据预处理
    • 1.1 添加参数解析
    • 1.2 数据预处理逻辑
    • 1.3 数据归一化及划分
    • 1.4 数据标签处理逻辑
    • 1.5 数据转torch
  • 2. 定义model
    • 2.1 CNN_LSTM
    • 2.2 Transformer
  • 3. 定义train脚本
    • 3.1 loss和optimizer
    • 3.2 train
    • 3.3 predict

1. 数据预处理

1.1 添加参数解析

为了方便管理模型和训练等参数,统一用参数解析。

def parse_args():
    """
    解析命令行参数并返回参数对象。

    返回:
    - args (argparse.Namespace): 解析后的参数对象
    """
    parser = argparse.ArgumentParser(description='命令行参数解析示例')

    # 添加参数
    parser.add_argument('--input', type=str, default='input.csv', help='输入文件的路径')
    parser.add_argument('--output', type=str, default='output.csv', help='输出文件的路径')
    parser.add_argument('--data_group', type=int, default=401, help='数据隔多少行划分一组')
    parser.add_argument('--test_size', type=float, default=0.1, help='测试集划分比例')
    # parser.add_argument('--batch_size', type=int, default=32, help='每批次的样本数量,默认值是 32')
    parser.add_argument('--batch_size', type=int, default=64, help='每批次的样本数量,默认值是 32')
    parser.add_argument('--epochs', type=int, default=50, help='训练的轮数,默认值是 10')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='学习率,默认值是 0.001')
    parser.add_argument('--model_name', type=str, default='CNN_LSTM', help='选择模型,eg:LSTM、CNN_LSTM、Transformer')
    parser.add_argument('--predict_para', type=str, default='AAdl1', help='m1, AAdl1, AAdl2, PPa0, nn, ccd1, ccd2')

    # 解析参数
    args = parser.parse_args()

    return args

1.2 数据预处理逻辑

最近一个小项目需要搭建神经网络根据输入数据的特征去拟合对应的6项参数
输入是input.csv,一共是876987x2的两列数据,分别代表特征1:Var1,特征2:Var2;使用pandas库读取,并提取出实部(如果存在的话),然后将它转换为浮点数。
在这里插入图片描述
6项参数对应的真实值如下,维度为:2187x7,每一列分别对应一项参数的真实标签值。
在这里插入图片描述

def clean_complex_number(val):
    val = str(val)
    if 'i' in val:
        val = val.split('+')[0] if '+' in val else val.split('-')[0]
    return float(val)

input_df = pd.read_csv(args.input)
output_df = pd.read_csv(args.output)

#提取出实部(如果存在的话),然后将它转换为浮点数
input_df = input_df.applymap(clean_complex_number)
output_df = output_df.applymap(clean_complex_number)

接着,输入特征是876987x2,也就是876987行, 每401行分组,并将每组展平成一维数组,一共876987÷401=2187组,最后的输入特征处理维度为:2187x802(401x2)。full_df就是拼接上标签,就是维度:2187x809

#876987 x 2
num_features = input_df.shape[1]  # 特征数 2
#一共876987行, 每401行分组,并将每组展平成一维数组
#每组就是1x401x2 = 1x802
#一共876987÷401=2187组
grouped = input_df.groupby(input_df.index // args.data_group).apply(lambda x: x.values.ravel())
#2187x802
new_input_df = pd.DataFrame(grouped.tolist(), index=grouped.index)
new_input_df.columns = [f'feature_{i}' for i in range(args.data_group * num_features)]
#将 new_input_df 和 output_df 合并成一个完整的数据框 full_df。
#2187x809
full_df = pd.concat([new_input_df, output_df.reset_index(drop=True)], axis=1)

train_df, test_df = train_test_split(full_df, test_size=args.test_size, random_state=2024)

1.3 数据归一化及划分

接着,进行归一化处理:

scaler = StandardScaler()
train_features = train_df.iloc[:, :-output_df.shape[1]]
train_labels = train_df.iloc[:, -output_df.shape[1]:]

test_features = test_df.iloc[:, :-output_df.shape[1]]
test_labels = test_df.iloc[:, -output_df.shape[1]:]

X_train = scaler.fit_transform(train_features)
y_train = np.array(train_labels)
X_val = scaler.transform(test_features)
y_val = np.array(test_labels)

1.4 数据标签处理逻辑

由于第一参数的真实值只存在于400、420、430的这三种值,所以我们将其映射到0、1、2,这样更有利于模型的拟合。

def map_unique_values(data):
    """
    将数据中的唯一值映射到从0开始的整数。

    参数:
    data (numpy.ndarray): 输入的一维数据数组

    返回:
    numpy.ndarray: 映射后的数据数组
    """
    # 获取唯一值及其数量
    unique_values = np.unique(data)

    # 创建映射字典
    mapping = {val: idx for idx, val in enumerate(unique_values)}

    # 将数据映射到新的值
    mapped_data = np.array([mapping[val] for val in data])

    return mapped_data.reshape(-1, 1)
# 只对第一列数据进行映射
y_train[:, 0:1] = map_unique_values(y_train[:, 0])
y_val[:, 0:1] = map_unique_values(y_val[:, 0])

1.5 数据转torch

# 转换为 PyTorch 张量
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32)
# 创建数据加载器
if args.model_name != 'LSTM':
    X_train_tensor = X_train_tensor.unsqueeze(1)
    X_val_tensor = X_val_tensor.unsqueeze(1)
if args.model_name == 'Transformer':
    X_train_tensor = X_train_tensor.permute(0, 2, 1)  # 重新排列维度以适应 Transformer 输入
    X_val_tensor = X_val_tensor.permute(0, 2, 1)  # 重新排列维度以适应 Transformer 输入
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
input_size = X_train.shape[1]  # 802
# 选择推理设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if args.model_name == 'Transformer':
    model = TransformerModel().to(device)
elif args.model_name == 'CNN_LSTM':
    hidden_size = 100
    output_size = 7
    model = CNN_LSTM_Model(input_size, hidden_size, output_size).to(device)

2. 定义model

2.1 CNN_LSTM

这里我定义了1个CNN和LSTM的hybird结构。

class CNN_LSTM_Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CNN_LSTM_Model, self).__init__()

        # CNN 部分
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=1, padding=2)

        # 计算 CNN 输出的特征维度
        self.cnn_output_size = 32 * (input_size // 4)  # 输入大小 / 2^2

        # LSTM 部分
        self.lstm = nn.LSTM(self.cnn_output_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # CNN 部分
        x = self.relu(self.conv1(x))
        x = self.maxpool(x)
        x = self.relu(self.conv2(x))
        x = self.maxpool(x)

        # 展平 CNN 输出
        x = x.permute(0, 2, 1)  # (batch_size, seq_len, feature_size) 这里的 seq_len 是时间步长

        # LSTM 部分
        x = x.flatten(1)
        h_lstm, _ = self.lstm(x)
        out = self.fc(h_lstm)  # 只取最后一个时间步的输出

        return out

2.2 Transformer

class TransformerModel(nn.Module):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.input_embed = nn.Linear(1, 32)  # 嵌入到更高维度
        self.pos_encoder = nn.Parameter(torch.randn(802, 32))  # 位置编码
        encoder_layers = nn.TransformerEncoderLayer(d_model=32, nhead=4)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=1)
        self.fc = nn.Linear(32, 7)  # 输出层

    def forward(self, x):
        x = self.input_embed(x) + self.pos_encoder[:x.size(1), :]  # 加入位置编码
        x = x.permute(1, 0, 2)  # 调整维度以匹配 PyTorch Transformer 输入需求
        x = self.transformer_encoder(x)
        x = x.mean(dim=0)  # 池化层
        x = self.fc(x)
        return x

3. 定义train脚本

3.1 loss和optimizer

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

3.2 train

for epoch in range(args.epochs):
    model.train()
    for batch_X, batch_y in train_loader:
        outputs = model(batch_X.to(device))
        loss = criterion(outputs, batch_y.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{args.epochs}], Loss: {loss.item():.4f}')


model.eval()
with torch.no_grad():
    y_pred_tensor = model(X_val_tensor.to(device))
    y_pred = y_pred_tensor.cpu().numpy()
    for i in range(y_pred.shape[1]):
        y_pred[:, i:i+1] = map_to_discrete_values(y_pred[:, i], np.unique(y_val[:, i]))

3.3 predict

x_axis = range(len(y_pred))
figure_list = ['m1', 'AAdl1', 'AAdl2', 'PPa0', 'nn', 'ccd1', 'ccd2']
for i, param in enumerate(figure_list):
    plt.figure(figsize=(10, 6))
    plt.scatter(x_axis, y_pred[:, i], color='blue', alpha=0.5, label='Predicted')
    plt.scatter(x_axis, y_val[:, i], color='red', alpha=0.5, label='True Label')
    plt.xlabel('item')
    plt.ylabel(figure_list[i])
    plt.title(f'{param} - True vs Predicted')

    # 计算均方误差(MSE)
    mse = np.mean((y_val[:, i] - y_pred[:, i]) ** 2)
    # 打印 MSE 值
    print(f"{args.model_name}--MSE of {figure_list[i]}: {mse:.4f}")
    plt.legend()
    plt.show()

其中,可视化的第一个参数拟合结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2144920.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue+nodejs+express汽车配件商城销售管理系统 i9cgz

目录 技术栈具体实现截图系统设计思路技术可行性nodejs类核心代码部分展示可行性论证研究方法解决的思路Express框架介绍源码获取/联系我 技术栈 该系统将采用B/S结构模式,开发软件有很多种可以用,本次开发用到的软件是vscode,用到的数据库是…

动态分析基础

实验一 Lab03-01.exe文件中发现的恶意代码 问题: 1.找出这个恶意代码的导入函数与字符串列表? 2.这个恶意代码在主机上的感染迹象特征是什么? 3.这个恶意代码是否存在一些有用的网络特征码?如果存在,它们是什么? 解答: 1.找出这个恶意代…

上调铁矿石产量预期后,淡水河谷股价能否重振?

猛兽财经的核心观点: (1)尽管市场面临挑战,但淡水河谷(VALE)还是上调了2024年的铁矿石产量预期。 (2)第二季度业绩喜忧参半;收入减少,但铁矿石出货量却很强劲。 (3)投资者…

【渗透测试】-vulnhub源码框架漏洞-Os-hackNos-1

vulnhub源码框架漏洞中的CVE-2018-7600-Drupal 7.57 文章目录  前言 1.靶场搭建: 2.信息搜集: 主机探测: 端口扫描: 目录扫描: 3.分析: 4.步骤: 1.下载CVE-2018-7600的exp 2.执行exp: 3.写入木…

QCustomPlot笔记(一)

文章目录 简介将帮助文档添加到Qt Creator中编译共享库cmake工程编译提示ui_mainwindow.h找不到qcustomplot.h文件 环境:windowsQt Creator 10.0.1cmake 简介 QT中用于绘制曲线的第三方工具 下载地址:https://www.qcustomplot.com/index.php/download 第一个压缩…

心觉:不能成事的根本原因

很多人一直都很努力,每天都很忙 每天都学习很多东西,学习各种道,各种方法论 但是许多年过去了依然一事无成 自己的目标没有达成,梦想没有实现 为什么呢 关键是没有开悟 那么什么是开悟呢 现在很多人都在讲开悟 貌似开悟很…

Docker Registry API best practice 【Docker Registry API 最佳实践】

文章目录 1. 安装 docker2. 配置 docker4. 配置域名解析5. 部署 registry6. Registry API 管理7. 批量清理镜像8. 其他 👋 这篇文章内容:实现shell 脚本批量清理docker registry的镜像。 🔔:你可以在这里阅读:https:/…

《深度学习》—— PyTorch的神经网络模块中常用的损失函数

文章目录 前言一、回归模型中常用的损失函数1、平均绝对误差损失(L1Loss)2、均方误差损失(MSELoss也称L2Loss)3、SmoothL1Loss 二、分类模型中常用的损失函数1、负对数似然损失(NLLLoss)2、二元交叉熵损失&…

XML映射器-动态sql

01-动态sql 1.实现动态条件SQL 第一种方法在sql语句中加入where 11其他条件都加and就行,这样就可以根据if条件来判断要传递的参数可以有几个 第二种方法用where标签给if语句包起来 where标签的作用如下图 第三种方法用trim标签解释如下图 用choose也可以实现条件查询如下图,…

【数据结构与算法 | 灵神题单 | 自底向上DFS篇】力扣508, 1026, 951

1. 力扣508:出现次数最多的子树元素和 1.1 题目: 给你一个二叉树的根结点 root ,请返回出现次数最多的子树元素和。如果有多个元素出现的次数相同,返回所有出现次数最多的子树元素和(不限顺序)。 一个结…

在Ubuntu中编译含有JSON的文件出现报错

在ubuntu中进行JSON相关学习的时候,我发现了一些小问题,决定与大家进行分享,减少踩坑时候出现不必要的时间耗费 截取部分含有JSON部分的代码进行展示 char *str "{ \"title\":\"JSON Example\", \"author\&…

Web植物管理系统-下位机部分

本节主要展示上位机部分,采用BSP编程,不附带BSP中各个头文件的说明,仅仅是对main逻辑进行解释 main.c 上下位机通信 通过串口通信,有两位数据验证头(verify数组中保存对应的数据头 0xAA55) 通信格式 上位发送11字节…

保护您的企业免受网络犯罪分子侵害的四个技巧

在这个日益数字化的时代,小型企业越来越容易受到网络犯罪的威胁。网络犯罪分子不断调整策略,并使用人工智能来推动攻击。随着技术的进步,您的敏感数据面临的风险也在增加。 风险的不断增大意味着,做好基本工作比以往任何时候都更…

Java--stream流、方法引用

Stream流 - Stream流的好处 - 直接阅读代码的字面意思即可完美展示无关逻辑方式的语义 - Stream流把真正的函数式编程风格引入到Java中 - 代码简洁 - Stream流的三类方法 - 获取Stream流 - 创建一条流水线,并把数据放到流水线上准备进行操作 - 中间方法 - 流水线上的操作 - 一次…

【代码随想录训练营第42期 Day60打卡 - 图论Part10 - Bellman_ford算法系列运用

目录 一、Bellman_ford算法的应用 二、题目与题解 题目一:卡码网 94. 城市间货物运输 I 题目链接 题解:队列优化Bellman-Ford算法(SPFA) 题目二:卡码网 95. 城市间货物运输 II 题目链接 题解: 队列优…

MySQL高阶1783-大满贯数量

题目 找出每一个球员赢得大满贯比赛的次数。结果不包含没有赢得比赛的球员的ID 。 结果集 无顺序要求 。 准备数据 Create table If Not Exists Players (player_id int, player_name varchar(20)); Create table If Not Exists Championships (year int, Wimbledon int, F…

Unity 高亮插件HighlightPlus介绍

仅对官方文档进行了翻译 注意:官方文档本身就落后实际,但对入门仍很有帮助,核心并没有较大改变,有的功能有差异,以实际为准.(目前我已校正了大部分差异,后续我会继续维护该文档) 为什么为该插件做翻译?功能强大,使用简单,且还在维护. 基于此版本的内置渲染管线文档 快速开始…

C语言之预处理详解(完结撒花)

目录 前言 一、预定义符号 二、#define 定义常量 三、#define定义宏 四、宏与函数的对比 五、#和## 运算符 六、命名约定 七、#undef 八、条件编译 九、头文件的包含 总结 前言 本文为我的C语言系列的最后一篇文章,主要讲述了#define定义和宏、#和##运算符、各种条件…

9.18作业

提示并输入一个字符串&#xff0c;统计该字符串中字母、数字、空格、其他字符的个数并输出 代码展示 #include <iostream>using namespace std;int main() {string str;int countc 0; // 字母计数int countn 0; // 数字计数int count 0; // 空格计数int counto 0;…

IEEE-754 32位十六进制数 转换为十进制浮点数

要将 IEEE-754 32位十六进制数 转换为 十进制浮点数&#xff0c;可以使用LabVIEW中的 Type Cast 函数。以下是一些具体步骤&#xff0c;以及相关实例的整理&#xff1a; 实现步骤&#xff1a; 输入十六进制数&#xff1a;在LabVIEW中&#xff0c;首先需要创建一个输入控制器&am…