深度学习技巧应用22-构建万能数据生成类的技巧,适用于CNN,RNN,GNN模型的调试与训练贯通

news2024/11/15 14:04:24

大家好,我是微学AI,今天给大家介绍一下深度学习技巧应用22-构建万能数据生成类的技巧,适用于CNN,RNN,GNN模型的调试与训练贯通。本文将实现了一个万能数据生成类的编写,并使用PyTorch框架训练CNN、RNN和GNN模型。

目录:
1.背景介绍
2.依赖库介绍
3.万能的数据生成器介绍
4.CNN,RNN,GNN模型搭建
5.数据生成与模型训练
6.训练结果与总结
在这里插入图片描述

1.背景介绍

在人工智能模型训练过程中,我们需要进行一些实验、测试或调试,我们可能需要一个具有特定形状和数量的数据集来验证我们的算法或模型。通过构建一个万能的数据生成器,我们可以灵活地生成各种形状和大小的数据集,无需手动制作和准备真实数据集。

其次,数据生成器还可以用于探索数据集的性质和特征。通过生成具有特定分布、特征或规律的数据集,我们可以更深入地了解数据之间的关系、特征之间的相互影响以及数据的结构等。这对于数据预处理、特征工程和模型选择都非常有帮助。

此外,数据生成器还可以用于实现数据增强技术。数据增强是指通过对原始数据进行一系列变换或扰动来生成新的训练样本,以增加训练数据的多样性和泛化能力。通过构建一个万能的数据生成器,我们可以定义各种数据增强方法,并在训练过程中动态地生成增强后的样本,从而提高模型的稳健性和可靠性。

2.依赖库介绍

首先,我们需要引入以下依赖库:

  • torch:PyTorch框架
  • torch.optim:torch.optim是PyTorch框架中的一个模块,用于优化模型的参数。它提供了各种优化算法,如随机梯度下降(SGD)、Adam、Adagrad等。通过选择适当的优化算法和调整参数,可以使模型在训练过程中更好地收敛并获得更好的性能。
  • torch.utils.data:torch.utils.data是PyTorch框架中的一个模块,用于处理数据集的工具类。它提供了一些常用的数据处理操作,如数据加载、批量处理、数据迭代和数据转换等。通过使用torch.utils.data,可以方便地将数据集加载到模型中进行训练,并且能够灵活地处理不同格式的数据。
  • numpy:numpy是一个Python库,主要用于进行数值计算和科学计算。它提供了多维数组对象(ndarray)和一系列用于操作数组的函数。numpy可以高效地进行数值运算,并且支持广播(broadcasting)和向量化操作,因此在科学计算、数据分析和机器学习等领域都得到广泛应用。在PyTorch中,numpy可以与torch.Tensor进行无缝的转换,方便进行数据的处理和转换。

3.万能的数据生成器介绍

首先我们需要定义了一个名为UniversalDataset的数据集类,用于生成具有特定形状和数量的数据和标签。

在类的初始化方法__init__中,我们传入了三个参数:data_shape表示数据的形状(一个元组),target_shape表示标签的形状(一个元组),num_samples表示数据集中样本的数量。通过这三个参数生成数据。

接着,我们实现了__len__方法,该方法返回数据集中样本的数量,即num_samples。

再定义__getitem__方法,该方法根据索引idx返回数据集中索引对应的数据和标签。在这个方法中,我们首先创建了一个与data_shape相同形状的全零张量data,以及一个与target_shape相同形状的全零张量target。

然后,我们分别计算了数据和标签的维度,即data_dims和target_dims。

本文使用torch.linspace函数在0和1之间生成长度为data_dim_size的等间隔数据范围data_range,并通过reshape方法将其重新塑形为data_shape_expanded形状的张量。然后,我们将这个塑形后的数据范围加到数据张量data上。

我们对标签也进行了类似的操作,生成了一个有规律的标签张量target。

最后,我们返回了数据张量data和标签张量target作为这个索引对应的样本。

通过这个类,我们可以根据需要生成具有指定形状和数量的数据集,并且数据和标签都是有规律的,方便进行后续的训练和评估。

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

class UniversalDataset(Dataset):
    def __init__(self, data_shape, target_shape, num_samples):
        self.data_shape = data_shape
        self.target_shape = target_shape
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 生成数据和标签
        data = torch.zeros(self.data_shape)
        target = torch.zeros(self.target_shape)

        # 计算数据和标签的维度
        data_dims = len(self.data_shape)
        target_dims = len(self.target_shape)

        # 生成有规律的数据和标签
        for dim in range(data_dims):
            data_dim_size = self.data_shape[dim]
            data_range = torch.linspace(0, 1, data_dim_size)
            data_shape_expanded = [1] * data_dims
            data_shape_expanded[dim] = data_dim_size
            data += data_range.reshape(data_shape_expanded)

        for dim in range(target_dims):
            target_dim_size = self.target_shape[dim]
            target_range = torch.linspace(0, 1, target_dim_size)
            target_shape_expanded = [1] * target_dims
            target_shape_expanded[dim] = target_dim_size
            target += target_range.reshape(target_shape_expanded)

        return data, target

4.CNN,RNN,GNN模型搭建

class CNNModel(nn.Module):
    def __init__(self, input_shape):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 16, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(16 * (input_shape[1] // 2) * (input_shape[2] // 2), 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class RNNModel(nn.Module):
    def __init__(self, input_shape):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(input_shape[1], 64, batch_first=True)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):

        _, h_n = self.rnn(x)
        x = self.fc(h_n.squeeze(0))
        return x

class GNNModel(nn.Module):
    def __init__(self, input_shape):
        super(GNNModel, self).__init__()
        self.fc1 = nn.Linear(input_shape[1], 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = torch.mean(x, dim=1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

5.数据生成与模型训练

# 定义训练函数
def train(model, dataloader, criterion, optimizer):
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        _, predicted = torch.max(outputs.data, dim=1)
        total += labels.size(0)

        correct += (predicted == labels.argmax(dim=1)).sum().item()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct / total

    return epoch_loss, epoch_acc

# 设置参数
data_shape_cnn = (3, 32, 32)  # (channels, height, width)
target_shape = (10,)
num_samples = 1000
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# 创建数据集和数据加载器
dataset = UniversalDataset(data_shape_cnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 创建CNN模型、优化器和损失函数
cnn_model = CNNModel(data_shape_cnn)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=learning_rate)
cnn_criterion = nn.CrossEntropyLoss()

print('CNN模型训练:')
# 训练CNN模型
for epoch in range(num_epochs):
    cnn_loss, cnn_acc = train(cnn_model, dataloader, cnn_criterion, cnn_optimizer)
    print(f'CNN - Epoch {epoch+1}/{num_epochs}, Loss: {cnn_loss:.4f}, Accuracy: {cnn_acc:.4f}')


# 重新创建数据集和数据加载器
data_shape_rnn = (20, 32)  # (sequence_length, input_size, hidden_size)
dataset = UniversalDataset(data_shape_rnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 创建RNN模型、优化器和损失函数
rnn_model = RNNModel(data_shape_rnn)
rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate)
rnn_criterion = nn.CrossEntropyLoss()

print('RNN模型训练:')
# 训练RNN模型
for epoch in range(num_epochs):
    rnn_loss, rnn_acc = train(rnn_model, dataloader, rnn_criterion, rnn_optimizer)
    print(f'RNN - Epoch {epoch+1}/{num_epochs}, Loss: {rnn_loss:.4f}, Accuracy: {rnn_acc:.4f}')

# 重新创建数据集和数据加载器
data_shape_gnn = (10, 100)  # (num_nodes, node_features)
dataset = UniversalDataset(data_shape_gnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 创建GNN模型、优化器和损失函数
gnn_model = GNNModel(data_shape_gnn)
gnn_optimizer = torch.optim.Adam(gnn_model.parameters(), lr=learning_rate)
gnn_criterion = nn.CrossEntropyLoss()

print('GNN模型训练:')
# 训练GNN模型
for epoch in range(num_epochs):
    gnn_loss, gnn_acc = train(gnn_model, dataloader, gnn_criterion, gnn_optimizer)
    print(f'GNN - Epoch {epoch+1}/{num_epochs}, Loss: {gnn_loss:.4f}, Accuracy: {gnn_acc:.4f}')

6.训练结果与总结

运行结果:

CNN模型训练:
CNN - Epoch 1/10, Loss: 10.4031, Accuracy: 0.5840
CNN - Epoch 2/10, Loss: 10.2561, Accuracy: 1.0000
CNN - Epoch 3/10, Loss: 10.2503, Accuracy: 1.0000
CNN - Epoch 4/10, Loss: 10.2496, Accuracy: 1.0000
CNN - Epoch 5/10, Loss: 10.2495, Accuracy: 1.0000
CNN - Epoch 6/10, Loss: 10.2494, Accuracy: 1.0000
CNN - Epoch 7/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 8/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 9/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 10/10, Loss: 10.2493, Accuracy: 1.0000
RNN模型训练:
RNN - Epoch 1/10, Loss: 10.3851, Accuracy: 0.9680
RNN - Epoch 2/10, Loss: 10.2606, Accuracy: 1.0000
RNN - Epoch 3/10, Loss: 10.2551, Accuracy: 1.0000
RNN - Epoch 4/10, Loss: 10.2531, Accuracy: 1.0000
RNN - Epoch 5/10, Loss: 10.2520, Accuracy: 1.0000
RNN - Epoch 6/10, Loss: 10.2513, Accuracy: 1.0000
RNN - Epoch 7/10, Loss: 10.2509, Accuracy: 1.0000
RNN - Epoch 8/10, Loss: 10.2506, Accuracy: 1.0000
RNN - Epoch 9/10, Loss: 10.2504, Accuracy: 1.0000
RNN - Epoch 10/10, Loss: 10.2502, Accuracy: 1.0000
GNN模型训练:
GNN - Epoch 1/10, Loss: 10.9591, Accuracy: 0.0400
GNN - Epoch 2/10, Loss: 10.3914, Accuracy: 1.0000
GNN - Epoch 3/10, Loss: 10.2818, Accuracy: 1.0000
GNN - Epoch 4/10, Loss: 10.2635, Accuracy: 1.0000
GNN - Epoch 5/10, Loss: 10.2569, Accuracy: 1.0000
GNN - Epoch 6/10, Loss: 10.2539, Accuracy: 1.0000
GNN - Epoch 7/10, Loss: 10.2524, Accuracy: 1.0000
GNN - Epoch 8/10, Loss: 10.2515, Accuracy: 1.0000
GNN - Epoch 9/10, Loss: 10.2509, Accuracy: 1.0000
GNN - Epoch 10/10, Loss: 10.2505, Accuracy: 1.0000

本文主要介绍了如何创建一个万能的数据生成类,可以根据输入的形状参数生成不同形状的数据。然后,将生成的数据和标签输入到CNN、RNN和GNN模型中进行训练,并打印出损失值和准确率。后续我们可以根据实际应用中可能需要根据具体任务做更多修改和扩展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/712359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

摄像机控制——旁轴摇移

通常摄像机进行摇移控制的时候,都是以摄像机正前方中心位置作为注视点进行环绕控制的,如果在注释点位置有物体,那么感受上是围绕着该物体进行观察。 但是最近公司的策划要求摇移时候的围绕点是鼠标点击的位置,而不是摄像机的正中心…

零基础网络安全学习路线,真的很全,建议收藏!!!

很多小伙伴在网上搜索网络安全时,会出来网络安全工程师这样一个职位,它的范围很广,只要是与网络安全挂钩的技术人员都算网络安全工程师,一些小伙伴就有疑问了,网络安全现在真的很火吗? 那么今天博主就带大…

从0实现基于Linux socket聊天室-多线程服务器模型(一)

前言Socket在实际系统程序开发当中,应用非常广泛,也非常重要。实际应用中服务器经常需要支持多个客户端连接,实现高并发服务器模型显得尤为重要。高并发服务器从简单的循环服务器模型处理少量网络并发请求,演进到解决C10K&#xf…

AntDB数据库将携创新性解决方案亮相2023可信数据库发展大会

由中国通信标准化协会指导,中国通信标准化协会大数据技术标准推进委员会(CCSA TC601)主办的“2023可信数据库发展大会”将于2023年7月4日——5日在北京国际会议中心召开。作为深耕通信行业15年的国产数据库产品,AntDB受邀参会&…

记录一下kibana启动链接报错问题(kibana server is not ready yet)

记录一下kibana启动链接报错问题(kibana server is not ready yet) 今天启动kibana出现该问题 先去看了看是否是elasticsearch连接出错 启动了容器 docker start elasticsearch docker start kibana进入了kibana容器 docker exec -it kibana bash进行了下面的操作&#xf…

No suitable driver found for

在学习Mbatis时候遇到的奇怪的问题,报错提示如图所示,提示找不到数据库驱动 检查db.properties文件,一开始认为没问题 drivercom.mysql.jdbc.Driver urljdbc:mysql://localhost:3306/mybatis?useSSLfalse&useUnicodetrue&characterEncodingUTF…

华为OD机试真题2023Q1 100分 + 2023 B卷(JavaPythonJavaScript)

目录 2023 5月 B卷 “新加题”(100分值)2023Q1 100分下面分享一道“2023Q1 200分 机器人活动区域”的解题思路一、题目描述二、输入描述三、输出描述四、解题思路五、Python算法源码六、效果展示1、输入2、输出 大家好,我是哪吒。 五月份之前…

react基础-生命周期render props模式高阶组件原理揭秘

组件生命周期(★★★) 目标 说出组件生命周期对应的钩子函数钩子函数调用的时机 概述 意义:组件的生命周期有助于理解组件的运行方式,完成更复杂的组件功能、分析组件错误原因等 组件的生命周期: 组件从被创建到挂…

通用机器人里程碑?谷歌展示全球首个多任务AI智能体

目录 两大硬核科技支撑通用机器人研发(1)自生成训练数据(2)基于多模态模型 科技巨头同台比拼 中国产业链凸显性价比优势发展初期硬件先行 运动模块价值量最高 已学会套圈、搭积木、抓水果…… 人工智能和机器人,总是不…

MES生产管理系统与ERP系统的集成以及优势

导言: 在当今数字化转型的浪潮中,企业越来越意识到整合各个部门的数据和流程的重要性。MES生产管理系统和ERP系统是两个关键的管理工具,它们在企业中发挥着不可或缺的作用。本文将探讨企业MES管理系统与ERP系统进行集成,以及这种…

它如何做到让我们持久且不感疲劳

写在前面 随着科技的进步和数字化生活的兴起,人们长时间使用显示器的需求增加,越来越多的人戴眼镜并且面临眼睛问题。显示器屏幕灯在当今社会也逐渐扮演着不可或缺的角色。 首先,显示器屏幕灯能够提供必要的亮度,确保我们在各种…

pyhton-docx表格合并单元格

合并单元格需要指定两个单元格, from docx_utils import set_table_singleBoard from docx import Documentdocument Document() table document.add_table(rows3, cols3) # 创建一个包含 3 行 3 列的表格 table.cell(0, 0).merge(table.cell(0, 1)) # 合并第一…

用正则表达式进行input框的限制输入

vue项目可以用input事件输入 1.限制input输入框只能输入大小写字母、数字、下划线的正则表达式&#xff1a; 用户名< input type"text" placeholder"只包含数字字母下划线" onkeyup"this.valuethis.value.replace(/[^\w_]/g,);"> 2.限…

linux如何修改sudoers文件,将非root用户加入到 sudoers 文件中

需求 由于在非 root 用户下执行 sudo 命令会报错 cc 不在 sudoers 文件中。此事将被报告。所以需要将 cc 这个用户加入到 sudoers 文件中进行授权 解决 要修改 sudoers 文件&#xff0c;您需要以 root 用户身份进行操作。以下是一种常见的方法&#xff1a; 1、使用 root 用…

Linux文件管理(创建 删除 复制 剪切 打包 压缩 解压缩)全总结

目录 一、Linux下文件命名规则 1、可以使用哪些字符&#xff1f; 2、文件名的长度 3、文件名的大小写 4、Linux文件扩展名 二、Linux下的文件管理 1、文件夹创建 ① mkdir创建文件夹 ② mkdir -p递归创建文件夹&#xff08;目录&#xff09; ③ 使用mkdir同时创建多个…

nacos批量信息获取-GitNacosConfig

声明&#xff1a;文中涉及到的技术和工具&#xff0c;仅供学习使用&#xff0c;禁止从事任何非法活动&#xff0c;如因此造成的直接或间接损失&#xff0c;均由使用者自行承担责任。 点点关注不迷路&#xff0c;每周不定时持续分享各种干货。 原文链接&#xff1a;众亦信安&a…

Leetcode-每日一题【234.回文链表】

题目 给你一个单链表的头节点 head &#xff0c;请你判断该链表是否为回文链表。如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,2,1]输出&#xff1a;true 示例 2&#xff1a; 输入&#xff1a;head…

【Golang | runtime】runtime.Caller和runtime.Callers的使用和区别

环境&#xff1a; go version go1.18.2 1、runtime.Caller 函数func runtime.Caller(skip int) (pc uintptr, file string, line int, ok bool) 作用 获取函数Caller的调用信息 参数 skip: 0时&#xff0c;返回调用Caller的函数A的pc(program counter)、所在文件名以及Cal…

SI12T触摸按键芯片兼容TMS12资料

Si12T 是一款具有自动灵敏度校准功能的 12 通道电容传感器&#xff0c;其工作电压 范围为 1.8 ~ 5.0 V 。 Si12T 设置 IDLE 模式来节省功耗&#xff0c;此时&#xff0c;功耗电流为 3.5 A3.3V 。 Si12T 有三种特殊功能&#xff1a;一种是通道 1 上的嵌入式电源键…

体验Kubernetes(k8s),使用minikube搭建单机k8s

文章目录 一、windows&#xff1a;使用Minikube搭建单节点K8s1、安装VirtualBox2、安装kubectl3、安装minikube4、使用minikube搭建k8s 二、centos&#xff1a;使用minikube搭建单节点k8s1、安装docker2、下载kubectl&minikube与安装3、搭建单机k8s4、体验pod 一、windows&…