生成式 AI:使用 Pytorch 通过 GAN 生成合成数据

news2024/10/6 8:29:29

导 读

生成对抗网络(GAN)因其生成图像的能力而变得非常受欢迎,而语言模型(例如 ChatGPT)在各个领域的使用也越来越多。这些 GAN 模型可以说是人工智能/机器学习目前主流的原因;

因为它向每个人(尤其是该领域之外的人)展示了机器学习所具有的巨大潜力。网上已经有很多关于 GAN 模型的资源,但其中大多数都集中在图像生成上。这些图像生成和语言模型需要复杂的空间或时间复杂性,这增加了额外的复杂性,使读者更难理解 GAN 的真正本质。

为了解决这个问题并使 GAN 更容易被更广泛的受众所接受,在本文的 GAN 模型示例中,我们将采取一种不同的、更实用的方法,重点关注生成数学函数的合成数据。

除了出于学习目的的简化之外,合成数据生成本身也变得越来越重要。数据不仅在业务决策中发挥着核心作用,而且数据驱动方法的用途也越来越多,比第一原理模型更受欢迎。

比如天气预报,第一个原理模型包括通过数值求解的纳维-斯托克斯方程的简化版本。然而,深度学习研究中进行天气预报的尝试在捕捉天气模式方面非常成功,并且一旦经过训练,运行起来会更容易、更快。

有需要的朋友关注公众号【小Z的科研日常】,获取更多内容

01、生成模型与判别模型

在机器学习中,理解判别模型和生成模型之间的区别非常重要,因为它们是 GAN 的关键组成部分:

判别模型:

判别模型侧重于将数据分类为预定义的类别,例如将狗和猫的图像分类为各自的类别。这些模型不是捕获整个分布,而是辨别不同类别的边界。它们输出 P(y|x)(类别概率,给定输入数据的 y,x),即它们回答给定数据点属于哪个类别的问题。

生成模型:

生成模型旨在理解数据的底层结构。与区分类别的判别模型不同,生成模型学习数据的整个分布。这些模型输出 p(x|y),即它们回答了给定指定类生成该特定数据点的可能性有多大的问题。

这两个模型之间的相互作用构成了 GAN 的基础。

02、GAN—结构和组件

GAN 的关键组件包括噪声向量、生成器和鉴别器。

生成器:生成真实数据

为了生成合成数据,生成器使用随机噪声向量作为输入。为了欺骗鉴别器,生成器的目的是学习真实数据的分布并生成无法与真实数据区分开的合成数据。这里的一个问题是,对于相同的输入,它总是会产生相同的输出(想象一个图像生成器产生真实的图像,但总是相同的图像,这不是很有用)。随机噪声向量将随机性注入到过程中,从而提供生成的输出的多样性。

鉴别器:辨别真假

鉴别器就像一位受过训练来区分真实数据和虚假数据的艺术评论家。它的作用是仔细检查收到的数据并为工作真实性分配概率分数。如果合成数据看起来与真实数据相似,则鉴别器分配高概率,否则分配低概率分数。

对抗性训练:动态决斗

生成器努力学习生成鉴别器无法与真实数据区分开的合成数据。同时,鉴别器还学习并提高区分真实与合成的能力。这种动态的训练过程促使两个模型提高技能。这两个模型总是相互竞争(因此被称为对抗性),并且通过这种竞争,两个模型都在各自的角色中变得非常出色。

03、Pytorch实现GAN

在此示例中,我们在 pytorch 中实现了一个可以生成合成数据的模型。对于训练,我们有一个具有以下形状的 6 参数数据集(所有参数都绘制为参数 1 的函数)。每个参数都经过精心选择,具有显着不同的分布和形状,以增加数据集的复杂性并模仿真实世界的数据。

定义 GAN 模型组件(生成器和判别器)

import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.init as init
import pandas as pd
import numpy as np
from torch.utils.data import Dataset


# 定义单块功能
def FC_Layer_blockGen(input_dim, output_dim):
    single_block = nn.Sequential(
        nn.Linear(input_dim, output_dim),

        nn.ReLU()
    )
    return single_block
    
# 定义 GENERATOR
class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh()  
        )

    def forward(self, x):
        return self.model(x)
        
#定义单个判别块
def FC_Layer_BlockDisc(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.ReLU(),
        nn.Dropout(0.4)
    )
    
# 定义判别器

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
        
        
#定义训练参数
batch_size = 128
num_epochs = 500
lr = 0.0002
num_features = 6
latent_dim = 20

# 模型初始化
generator = Generator(noise_dim, num_features)
discriminator = Discriminator(num_features)

# 损失函数和优化器
criterion = nn.BCELoss()
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

模型初始化和数据处理

file_path = 'SamplingData7.xlsx'
data = pd.read_excel(file_path)
X = data.values
X_normalized = torch.FloatTensor((X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) * 2 - 1)
real_data = X_normalized


class MyDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.values.astype(float)
        self.labels = dataframe.values.astype(float)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {
            'input': torch.tensor(self.data[idx]),
            'label': torch.tensor(self.labels[idx])
        }
        return sample

# 创建数据集实例
dataset = MyDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

def weights_init(m):
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)

pretrained = False
if pretrained:
    pre_dict = torch.load('pretrained_model.pth')
    generator.load_state_dict(pre_dict['generator'])
    discriminator.load_state_dict(pre_dict['discriminator'])
else:
    # 应用权重初始化
    generator = generator.apply(weights_init)
    discriminator = discriminator.apply(weights_init)

模型训练

model_save_freq = 100

latent_dim =20
for epoch in range(num_epochs):
    for batch in dataloader:
        real_data_batch = batch['input']
        real_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))
        disc_optimizer.zero_grad()
        output_real = discriminator(real_data_batch)
        loss_real = criterion(output_real, real_labels)
        loss_real.backward()

        fake_labels = torch.FloatTensor(np.random.uniform(0, 0.1, (batch_size, 1)))
        noise = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim)))
        generated_data = generator(noise)
        output_fake = discriminator(generated_data.detach())
        loss_fake = criterion(output_fake, fake_labels)
        loss_fake.backward()

        disc_optimizer.step()
 
        valid_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))
        gen_optimizer.zero_grad()
        output_g = discriminator(generated_data)
        loss_g = criterion(output_g, valid_labels)
        loss_g.backward()
        gen_optimizer.step()

    print(f"Epoch {epoch}, D Loss Real: {loss_real.item()}, D Loss Fake: {loss_fake.item()}, G Loss: {loss_g.item()}")

模型评估和可视化结果

import seaborn as sns

synthetic_data = generator(torch.FloatTensor(np.random.normal(0, 1, (real_data.shape[0], noise_dim))))

# 绘制结果
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
fig.suptitle('Real and Synthetic Data Distributions', fontsize=16)

for i in range(2):
    for j in range(3):
        sns.histplot(synthetic_data[:, i * 3 + j].detach().numpy(), bins=50, alpha=0.5, label='Synthetic Data', ax=axs[i, j], color='blue')
        sns.histplot(real_data[:, i * 3 + j].numpy(), bins=50, alpha=0.5, label='Real Data', ax=axs[i, j], color='orange')
        axs[i, j].set_title(f'Parameter {i * 3 + j + 1}', fontsize=12)
        axs[i, j].set_xlabel('Value')
        axs[i, j].set_ylabel('Frequency')
        axs[i, j].legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


#创建 2x3 网格的子绘图
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Comparison of Real and Synthetic Data', fontsize=16)

# Define parameter names
param_names = ['Parameter 1', 'Parameter 2', 'Parameter 3', 'Parameter 4', 'Parameter 5', 'Parameter 6']

# 各参数的散点图
for i in range(2):
    for j in range(3):
        param_index = i * 3 + j
        sns.scatterplot(real_data[:, 0].numpy(), real_data[:, param_index].numpy(), label='Real Data', alpha=0.5, ax=axs[i, j])
        sns.scatterplot(synthetic_data[:, 0].detach().numpy(), synthetic_data[:, param_index].detach().numpy(), label='Generated Data', alpha=0.5, ax=axs[i, j])
        axs[i, j].set_title(param_names[param_index], fontsize=12)
        axs[i, j].set_xlabel(f'Real Data - {param_names[param_index]}')
        axs[i, j].set_ylabel(f'Real Data - {param_names[param_index]}')
        axs[i, j].legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509763.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue+elementUI用户修改密码的前端验证

用户登录后修改密码,密码需要一定的验证规则。旧密码后端验证是否正确;前端验证新密码的规范性,新密码规范为:6-16位,至少含数字/字母/特殊字符中的两种;确认密码只需要验证与新密码是否一致; 弹…

数据结构——二叉树的遍历【前序、中序、后序】

💞💞 前言 hello hello~ ,这里是大耳朵土土垚~💖💖 ,欢迎大家点赞🥳🥳关注💥💥收藏🌹🌹🌹 💥个人主页&#x…

java-教师管理系统全部资料-164-(代码+说明)

转载地址: http://www.3q2008.com/soft/search.asp?keyword教师管理系统全部资料 第一章 综 述 2 1.1 背景说明 2 1.2 设计目的 2 1.3 系统目标 3 1.4 设计指导思想 3 1.5 开发技术概论 3 1.5.1 JSP技术 3 1.5.2 JAVA 4 1.5.3 JavaBeans 5 1.5.4 Servlet 5 1.5.5 Tomcat应用服…

纯前端Web网页内嵌AutoCAD,支持在线编辑DWG、dxf等文档。

随着企业信息化的发展,越来越多的企业有网页在线浏览和编辑DWG文档(AutoCad生成的文档)的需求,但是新版浏览器纷纷取消了对NPAPI插件的支持,导致之前一些可以在线在线浏览和编辑DWG文档纷纷失效,今天推荐一…

王道机试C++第 5 章 数据结构二:队列queue和21年蓝桥杯省赛选择题Day32

目录 5.2 队列 1.STL-queue 课上演示: 基本代码展示: 2. 队列的应用 例:约瑟夫问题 No. 2 题目描述: 思路提示: 代码展示: 例:猫狗收容所 题目描述: 代码表示&#xff1…

【深度学习】线性回归

Linear Regression 一个例子线性回归机器学习中的表达评价函数好坏的度量:损失(Loss)损失函数(Loss function)哪个数据集的均方误差 (MSE) 高 如何找出最优b和w?寻找最优b和w如何降低损失 (Reducing Loss)梯度下降法梯…

【毕设级项目】基于AI技术的多功能消防机器人(完整工程资料源码)

基于AI技术的多功能消防机器人演示效果 竞赛-基于AI技术的多功能消防机器人视频演示 前言 随着“自动化、智能化”成为数字时代发展的关键词,机器人逐步成为社会经济发展的重要主体之一,“机器换人”成为发展的全新趋势和时代潮流。在可预见的将来&#…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《计及台区资源聚合功率的中低压配电系统低碳优化调度方法》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

软件无线电系列——软件无线电的发展历程及体系框架

本节目录 一、软件无线电的起始 二、软件无线电SDR论坛 三、SPEAKeasy计划 四、JTRS与SCA 五、软件无线电体系框架本节内容 一、软件无线电的起始 1992年5月,美国电信会议上,Joseph Mitola III博士提出来软件无线电(Software Radio,SR)的概念。理想化的…

实现支持多选的QComboBox

Qt提供的QComboBox只支持下拉列表内容的单选,但通过QComboBox提供的setModel、setView、setLineEdit三个方法,可以对QComboBox进行改造,使其实现下拉列表选项的多选。 QComboBox可以看作两个组件的组合:一个QLineEdit和一个QList…

OpenCV开发笔记(七十七):相机标定(二):通过棋盘标定计算相机内参矩阵矫正畸变摄像头图像

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/136616551 各位读者,知识无穷而人力有穷,要么改需求,要么找专业人士,要么自己研究 红胖子(红模仿)的博…

Orange3数据预处理(预处理器组件)

1.组件介绍 Orange3 提供了一系列的数据预处理工具,这些工具可以帮助用户在数据分析之前准备好数据。以下是您请求的预处理组件的详细解释: Discretize Continuous Variables(离散化连续变量): 这个组件将连续变量转…

利用Nginx正向代理实现局域网电脑访问外网

引言 在网络环境中,有时候我们需要让局域网内的电脑访问外网,但是由于网络策略或其他原因,直接访问外网是不可行的。这时候,可以借助 Nginx 来搭建一个正向代理服务器,实现局域网内电脑通过 Nginx 转发访问外网的需求…

算法(结合算法图解)

算法简介简单查找二分查找法 选择排序内存的工作原理数组和链表数组选择排序小结 递归小梗 要想学会递归,首先要学会递归。 递归的基线条件和递归条件递归和栈小结 快速排序分而治之快速排序合并排序时间复杂度的平均情况和最糟情况小结 算法简介 算法是一组完成任…

Python3虚拟环境之virtualenv

virtualenv 在开发Python应用程序的时候,系统安装的Python3只有一个版本:3.7。所有第三方的包都会被pip安装到Python3的site-packages目录下。 如果要同时开发多个应用程序,这些应用程序都会共用一个Python,就是安装在系统的Pyt…

【算法】一维前缀和以及二维前缀和

目录 一维前缀和适用场景示例 二维前缀和适用场景一种情况另一种情况示例 一维前缀和 适用场景 求一段区间的和。 比如有一个数列 : 如果我们要求 [l,r]即某个区间内的数组和的时候,思路就是每遍历一个元素就进行求和,记录下加到al时的和…

HYBBS 表白墙网站PHP程序源码,支持封装成APP

PHP表白墙网站源码,适用于校园内或校区间使用,同时支持封装成APP。告别使用QQ空间的表白墙。 简单安装,只需PHP版本5.6以上即可。 通过上传程序进行安装,并设置账号密码,登录后台后切换模板,适配手机和PC…

Java双非大二找实习记录

先说结论:2.22→3.6线上线下面了七家,最后oc两家小公司,接了其中一个。 本人bg: 真名不经传双非一本,无绩点无竞赛无奖项无实习,23年12月开始学java。若非要说一点相关的经历,就是有java基础&…

python-0003-pycharm开发虚拟环境中的项目

前言 在虚拟环境中创建好了python项目,使用pycharm进行开发 打开项目 使用pycharm打开项目 设置虚拟环境的解释器 File–>Settings–>Project(项目名)–>Python Interpreter–>添加解释器–>添加已经存在的解释器–>选择虚拟环境的解释器 …

程序人生——Java开发中通用的方法和准则,Java进阶知识汇总

目录 引出Java开发中通用的方法和准则建议1:不要在常量和变量中出现易混淆的字母建议2:莫让常量蜕变成变量建议3:三元操作符的类型务必一致建议4:避免带有变长参数的方法重载建议5:别让null值和空值威胁到变长方法建议6:覆写变长方法也循规蹈矩建议7:警惕自增的陷阱建议…