神经网络拟合离散标签值

news2024/9/20 12:56:42

神经网络拟合离散标签值

  • 1. 数据预处理
    • 1.1 添加参数解析
    • 1.2 数据预处理逻辑
    • 1.3 标签处理逻辑
    • 1.4 构建特征和标签
    • 1.5 数据归一化、转torch
    • 1.6 实现Dataset类
  • 2. 定义model
  • 3. 定义train脚本
    • 3.1 loss和optimizer
    • 3.2 train
    • 3.3 predict

1. 数据预处理

1.1 添加参数解析

为了方便管理模型和训练等参数,统一用参数解析。

def parse_args():
    """
    解析命令行参数并返回参数对象。

    返回:
    - args (argparse.Namespace): 解析后的参数对象
    """
    parser = argparse.ArgumentParser(description='命令行参数解析示例')

    # 添加参数
    #代表输入文件的名称:g0-25%.xlsx、l0.4-2.6.xlsx、s0.4-1.6.xlsx、y0.8-4.0.xlsx
    parser.add_argument('--input', type=str, default='g0-25%.xlsx', help='数据文件路径')
    #代表预测excel中哪一列的参数值,比如要预测g0-25%.xlsx中g取0的这一列
    parser.add_argument('--predict_para', type=float, default=12.5, help='代表预测excel中哪一列的参数值')
    parser.add_argument('--batch_size', type=int, default=4, help='每批次的样本数量,默认值是 4')
    parser.add_argument('--epochs', type=int, default=100, help='训练的轮数,默认值是100')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='学习率,默认值是 0.001')

    # 解析参数
    args = parser.parse_args()

    return args

1.2 数据预处理逻辑

最近一个小项目需要搭建神经网络根据输入数据的特征去拟合对应的标签值
输入是excel的.xlsx文件,一共是1400x12的12列数据,第一列是时间,第2-第12列都是对应的数据,第一行是对应的标签值;使用pandas库读取。
在这里插入图片描述

args = parse_args()

file_path = args.input  # 请替换为实际的数据文件路径

# 读取 Excel 文件
df = pd.read_excel(file_path)
data = df.to_numpy()[:, 1:]

# 标签就是第一行对应的离散标签值
labels = np.array(df.columns)[1:].astype(np.float32)
labels = np.round(labels, 2)

1.3 标签处理逻辑

从第一行的标签值来看,每个标签真实值都是离散值,我们可以将其映射为0、1、2…,把回归问题当成分类问题来,更有利于模型的拟合。

scaler = StandardScaler()
train_features = train_df.iloc[:, :-output_df.shape[1]]
train_labels = train_df.iloc[:, -output_df.shape[1]:]

test_features = test_df.iloc[:, :-output_df.shape[1]]
test_labels = test_df.iloc[:, -output_df.shape[1]:]

X_train = scaler.fit_transform(train_features)
y_train = np.array(train_labels)
X_val = scaler.transform(test_features)
y_val = np.array(test_labels)

1.4 构建特征和标签

这里,data是1400x11维度的数据,将其转置``,则表示成11组(1400, 1)的数据,刚好对应11个离散标签值。标签y则是(11, 1),对应11个映射好的标签值

# 构建特征和标签
X: object = data.T  # [样本数, 特征数]
y = label_indices  # 标签

1.5 数据归一化、转torch

# 数据预处理:标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 创建 Dataset 对象
dataset = CustomDataset(X_scaled, label_indices)

# 创建 DataLoader 对象
batch_size = args.batch_size
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#实例化网络模型,output对应预测的标签值的数量
model = CNNClassifier(fc1_channels=44544, output=11)

1.6 实现Dataset类

这里,对data添加维度,为(11, 1, 1400),labels 转int型,最后都转tensor
经过上述的转置后,这里的__len__,就是return的11,正好是11组数据,__getitem__正好根据索引idx返回对应的每一组数据。

# 自定义 Dataset 类
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
        self.labels = torch.tensor(labels, dtype=torch.int64)  # Long for classification

    def __len__(self):
        return self.data.size(0)  # Number of samples

    def __getitem__(self, idx):
        # return self.data[:, idx, :], self.labels[idx]
        return self.data[idx], self.labels[idx]

2. 定义model

这里我定义了1个CNN结合FC的网络结构。

# 定义 CNN 模型
class CNNClassifier(nn.Module):
    def __init__(self, fc1_channels, output):
        super(CNNClassifier, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=5)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5)
        self.fc1 = nn.Linear(fc1_channels, 128)  # Flattened feature size is 32 * 1390
        self.fc2 = nn.Linear(128, output)  #  classes

    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)
        return x

3. 定义train脚本

3.1 loss和optimizer

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

3.2 train

# 训练模型
num_epochs = args.epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_data, batch_labels in dataloader:
        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * batch_data.size(0)
    epoch_loss = running_loss / len(dataset)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

3.3 predict

最后在预测的时候别忘了根据上述映射关系,将预测的标签值映射回去

# 进行预测
model.eval()
with torch.no_grad():
    # 选择一个样本进行预测
    # sample_idx = 10
    sample_labels_list = [round(label, 2) for label in labels.tolist()]
    sample_idx = sample_labels_list.index(args.predict_para)
    eval_data = torch.tensor(X_scaled[sample_idx], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    predictions = model(eval_data)
    _, predicted_label_idx = torch.max(predictions, 1)
    predicted_label = labels[predicted_label_idx.item()]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2148950.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第二证券:金价涨了!创一历史之最!

当地时间周四,金融商场进一步消化美联储大幅降息50个基点的利率抉择,认为这是为了完结美国经济“软着陆”的一次防备式降息,而非紧急应对阑珊风险的降息,加之当天公布的上星期初度申请赋闲救助人数低于预期,投资者对美…

B站前端错误监控实践

前言 从23年开始,我们团队开始前端错误监控方向的开发。经历了一些列的迭代和发展,从监控SDK、上报、数据治理、看板集成、APM自研可视化初步完成了一条完整且适合B站前端监控。 截止目前(2024.08.01),前端监控在B站85%以上的业务线&#xf…

Linux运维培训班靠谱吗?如何判断一个培训班的教学质量?

当我们下定决心打算系统培训Linux运维时,哪家机构比较好成为了我们面临的最大难题。之前分享过很多培训机构的个人感受,但授人以鱼不如授人以渔,说到底那些都是我个人的看法,因此今天打算帮助大家学会如何判断一个培训班的好坏。 …

使用 SSCB 保护现代高压直流系统的优势

在各种应用中,系统效率和功率密度不断提高,这导致了更高的直流系统电压。然而,传统的电路保护解决方案不足以在保持高可靠性和安全性的同时有效保护这些高压配电系统。 固态断路器 (SSCB) 和电熔断器具有众多优点&…

GitLab 迁移并推送代码仓库

迁移并推送代码仓库到 GitLab 可以有多种方法,以下是一些常见的步骤: 一、创建空仓库 在 Gitlab 上创建一个空仓库 方式一:点击左上角“+”号,选择新建项目/仓库 方式二:进入“项目”界面,点击右上角“新建项目”按钮 选择“创建空白项目” 填写项目信息并点击“新…

从黎巴嫩电子通信设备爆炸看如何防范网络电子袭击

引言: 在当今数字化时代,电子通信设备已成为我们日常生活中不可或缺的一部分。然而,近期黎巴嫩发生的电子设备爆炸事件提醒我们,这些设备也可能成为危险的武器。本文将深入探讨电子袭击的原理、防范措施,以及网络智能…

SpinalHDL之结构(四)

本文作为SpinalHDL学习笔记第六十四篇,介绍SpinalHDL的时钟域(Clock domains)。 目录: 1.简介(Introduction) 2.例化(Instantiation) ⼀、简介(Introduction) 在SpinalHDL中, 时钟和复位信号能结合起来构成时钟域(clock domain)。时钟域可以应⽤于设计的某些区域中, 例化在…

在线安全干货|如何更改IP地址?

更改IP地址是一个常见的需求,无论是为了保护个人隐私、绕过地理限制还是进行商业数据分析。不同的IP更改方法适用于不同的需求和环境。但请注意,更改IP地址应在合法场景下进行,无论使用什么方法,都需要在符合当地网络安全法律法规…

开源链动 2+1 模式 S2B2C 商城小程序中的产品为王理念

摘要:本文深入探讨了在社交电商领域中,开源链动 21 模式 S2B2C 商城小程序如何践行“产品为王”的理念。分析了在社交电商野蛮生长时期好产品的稀缺性以及选对产品的重要性,同时阐述了因产品问题导致的不良后果,并强调了在该小程序…

Spring Boot 整合 MyBatis 的详细步骤(两种方式)

1. Spring Boot 配置 MyBatis 的详细步骤 1、首先,我们创建相关测试的数据库,数据表。如下: CREATE DATABASE springboot_mybatis USE springboot_mybatisCREATE TABLE monster ( id int not null auto_increment, age int not null, birthda…

一篇文章读懂什么事 LLM 训练:从预训练到微调【大模型应用入门系列】

自然语言处理(NLP)是人工智能领域中一项重要的研究方向,涉及机器对人类语言进行理解和生成。然而,语言的复杂性和多样性使得处理自然语言任务成为一项极具挑战性的任务。在这个领域中,LLM Training 扮演着至关重要的角…

Visual Studio配置opencv环境

(1)打开属性页面(鼠标放在解决方案上,点击右键会有一个属性选项弹出) (2)配置opencv的include和opencv2路径,具体路径和版本根据自己电脑配置 (3)配置opencv…

2017年国赛高教杯数学建模A题CT系统参数标定及成像解题全过程文档及程序

2017年国赛高教杯数学建模 A题 CT系统参数标定及成像 CT(Computed Tomography)可以在不破坏样品的情况下,利用样品对射线能量的吸收特性对生物组织和工程材料的样品进行断层成像,由此获取样品内部的结构信息。一种典型的二维CT系统如图1所示&#xff0c…

品牌网站建设如何做

品牌网站建设是一项复杂而关键的任务,它直接影响着企业在线形象和市场竞争力。一个成功的品牌网站不仅仅是一个展示产品或服务的平台,更是一个能够吸引、保留用户并传递品牌价值的载体。下面是一些关键步骤,以及在品牌网站建设中需要考虑的一…

12 - TCPServer实验

在上一章节中,我们学习了TCPClient通信测试的相关知识。接下来,本章节将以此为基础,构建一个基础性的TCPServer连接机制,该机制将利用之前所建立的WIFI网络连接。为方便演示,我们将借助网络调试助手工具进行数据的发送…

金砖软件测试赛项之Jmeter如何录制脚本!

一、简介 Apache JMeter 是一款开源的性能测试工具,用于测试各种服务的负载能力,包括Web应用、数据库、FTP服务器等。它可以模拟多种用户行为,生成负载以评估系统的性能和稳定性。 JMeter 的主要特点: 图形用户界面:…

基于CNN的10种物体识别项目

一:数据导入和处理 1.导入相关包: import numpy as np import pandas as pd import matplotlib.pyplot as plt import tensorflow as tf2.下载数据 (x_train_all, y_train_all), (x_test, y_test) tf.keras.datasets.cifar10.load_data()# x_valid:测…

使用Rust直接编译单个的Solidity合约

这里写自定义目录标题 使用Rust直接编译单个的Solidity合约前言预备知识准备工作示例 使用Rust直接编译单个的Solidity合约 前言 我们知道,我们平常开发Solidity智能合约时一般使用Hardhat框架,但是如果你是一个Rustacean (这是由 “Rust” 和 “crust…

Cloudera安装不再复杂:基础环境设置详解

Cloudera Manager是CDH市场领先的管理平台。它以其强大的数据管理和分析能力,帮助企业能够轻松驾驭海量数据,实现数据的实时分析与洞察。 作为业界第一的端到端 Apache Hadoop 的管理应用,Cloudera Manager对CDH的每个部件都提供了细粒度的可…

windows10 ipv4设置(多个)网段同时连接

注意另一个网段的测试设备必须插在你现在用的电脑上 如果没用那就换几个网口试试,换几个转接器试试,理论是可以的,如果不行那就是硬件坏了 二、如果还不行那就这样 注意:pcie是网线接在主机上,usb是转接器的网络 把你…