pytorch生成对抗网络

news2025/2/3 16:23:10

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗过程共同训练,从而使生成器能够生成越来越真实的假数据。

GAN的基本工作原理:

  1. 生成器(G):它的任务是生成与真实数据相似的假数据。生成器通常从一个随机噪声(例如,均匀分布或高斯分布的噪声)开始,经过多层神经网络的处理,输出伪造的数据样本。

  2. 判别器(D):它的任务是区分输入数据是来自真实数据分布,还是生成器伪造的假数据。判别器通常是一个二分类器,其输出是一个表示“真实”或“假”的概率值。

训练过程:

  • 对抗过程:生成器和判别器相互博弈。生成器希望生成尽可能像真的数据,以骗过判别器;而判别器希望准确区分真假数据。最终,生成器会通过优化损失函数,使得生成的数据与真实数据尽可能相似,判别器的性能则被提升到一个极限,使得它不能再轻易地区分真假数据。
  • 数学公式:

  • 判别器的目标是最大化其输出的正确分类概率,即区分真假数据。
  • 生成器的目标是最小化其输出的“假数据”被判定为假的概率。

常见的GAN变种:

  1. DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来增强生成器和判别器的表现。
  2. WGAN(Wasserstein GAN):引入了Wasserstein距离,改进了训练稳定性。
  3. CycleGAN:能够在没有成对样本的情况下进行图像到图像的转换,例如将马变成斑马。

以下是一个简化的PyTorch GAN实现的框架,生成一个语音的梅尔频谱(假设已经处理了音频并提取了梅尔频谱特征)

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt


# 生成器(Generator)
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 80),  # 80表示梅尔频谱的时间步(例如:80个梅尔频率)
            nn.Tanh()  # 生成梅尔频谱,范围在[-1, 1]之间
        )

    def forward(self, z):
        return self.fc(z)


# 判别器(Discriminator)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(80, 512),  # 输入为梅尔频谱的时间步
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出判定是“真”还是“假”
        )

    def forward(self, x):
        return self.fc(x)


# 初始化生成器和判别器
z_dim = 100
generator = Generator(z_dim)
discriminator = Discriminator()

# 优化器
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 损失函数
criterion = nn.BCELoss()


# 加载数据(假设已经提取了梅尔频谱特征,取一个示例)
def load_example_mel_spectrogram():
    # 假设这是一个真实梅尔频谱的示例,实际数据应从音频文件中提取
    mel = torch.rand((80))  # 生成一个假的梅尔频谱数据
    return mel.unsqueeze(0)  # 扩展维度以适应网络


# 训练GAN
num_epochs = 1000
for epoch in range(num_epochs):
    # 真实数据
    real_data = load_example_mel_spectrogram()
    real_labels = torch.ones(real_data.size(0), 1)  # 标签为1表示真实数据

    # 假数据
    z = torch.randn(real_data.size(0), z_dim)  # 随机噪声
    fake_data = generator(z)
    fake_labels = torch.zeros(real_data.size(0), 1)  # 标签为0表示假数据

    # 训练判别器
    discriminator.zero_grad()
    real_loss = criterion(discriminator(real_data), real_labels)
    fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    d_optimizer.step()

    # 训练生成器
    generator.zero_grad()
    g_loss = criterion(discriminator(fake_data), real_labels)  # 生成器希望判别器判定为真实
    g_loss.backward()
    g_optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    # 可视化生成的梅尔频谱(只显示最后一次生成的结果)
    if epoch == num_epochs - 1:
        plt.figure(figsize=(10, 4))
        plt.imshow(fake_data.detach().numpy(), aspect='auto', origin='lower')
        plt.title(f"Generated Mel Spectrogram - Epoch {epoch}")
        plt.colorbar()
        plt.show()

# 测试阶段:使用训练好的生成器进行语音生成
z_test = torch.randn(1, z_dim)  # 创建一个新的随机噪声向量
generated_mel_spectrogram = generator(z_test)

# 可视化生成的梅尔频谱
plt.figure(figsize=(10, 4))
plt.imshow(generated_mel_spectrogram.detach().numpy(), aspect='auto', origin='lower')
plt.title("Generated Mel Spectrogram from Test Data")
plt.colorbar()
plt.show()

解释:

  1. 测试阶段

    • 在训练完成后,我们使用一个新的随机噪声向量z_test来生成一个新的梅尔频谱。
    • generated_mel_spectrogram = generator(z_test)是生成梅尔频谱的过程。
  2. 可视化

    • 使用plt.imshow()来可视化生成的梅尔频谱图,origin='lower'是确保频谱图正确显示。
    • plt.colorbar()添加颜色条,以便更清晰地理解梅尔频谱的数值范围。

结果:

  • 在训练过程中,你会看到每个epoch的损失值,并在最后一次epoch时显示生成的梅尔频谱。
  • 在测试阶段,生成器会基于随机噪声生成一个新的梅尔频谱并进行可视化,帮助你观察最终模型生成的语音特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291349.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Baklib在企业知识管理领域的领先地位与三款竞品的深度剖析

内容概要 在现代企业中,知识管理已成为提高工作效率和推动创新的重要手段。Baklib作为一款领先的知识中台,以其集成化和智能化的特性,帮助企业在这一领域取得了显著成就。该平台具备强大的知识收集、整理、存储和共享功能,通过构…

2 MapReduce

2 MapReduce 1. MapReduce 介绍1.1 MapReduce 设计构思 2. MapReduce 编程规范3. Mapper以及Reducer抽象类介绍1.Mapper抽象类的基本介绍2.Reducer抽象类基本介绍 4. WordCount示例编写5. MapReduce程序运行模式6. MapReduce的运行机制详解6.1 MapTask 工作机制6.2 ReduceTask …

测压表压力表计量表针头针尾检测数据集VOC+YOLO格式4862张4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):4862 标注数量(xml文件个数):4862 标注数量(txt文件个数):4862 …

吴恩达深度学习——优化神经网络

本文来自https://www.bilibili.com/video/BV1FT4y1E74V,仅为本人学习所用。 文章目录 优化样本大小mini-batch 优化梯度下降法动量梯度下降法指数加权平均概念偏差纠正 动量梯度下降法 RMSpropAdam优化算法 优化学习率局部最优问题(了解) 优…

揭秘算法 课程导读

目录 一、老师介绍 二、课程目标 三、课程安排 一、老师介绍 学问小小谢 我是一个热爱分享知识的人,我深信知识的力量能够启迪思考,丰富生活。 欢迎每一位对知识有渴望的朋友,如果你对我的创作感兴趣,或者我们有着共同的兴趣点&…

17.[前端开发]Day17-形变-动画-vertical-align

1 transform CSS属性 - transform transform的用法 表示一个或者多个 不用记住全部的函数&#xff0c;只用掌握这四个常用的函数即可 位移 - translate <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta ht…

Python的那些事第五篇:数据结构的艺术与应用

新月人物传记&#xff1a;人物传记之新月篇-CSDN博客 目录 一、列表&#xff08;List&#xff09;&#xff1a;动态的容器 二、元组&#xff08;Tuple&#xff09;&#xff1a;不可变的序列 三、字典&#xff08;Dict&#xff09;&#xff1a;键值对的集合 四、集合&#xf…

Linux:线程池和单例模式

一、普通线程池 1.1 线程池概念 线程池&#xff1a;一种线程使用模式。线程过多会带来调度开销&#xff0c;进而影响缓存局部性和整体性能。而线程池维护着多个线程&#xff0c;等待着监督管理者分配可并发执行的任务。这避免了在处理短时间任务时创建与销毁线程的代价&…

【算法-位运算】位运算遍历 LogTick 算法

文章目录 1. 引入2. LogTick 优化遍历过程3. 题目3.1 LeetCode3097 或值至少为 K 的最短子数组 II3.2 LeetCode2411 按位或最大的最小子数组长度3.3 LeetCode3209 子数组按位与值为 K 的数目3.4 LeetCode3171 找到按位或最接近 K 的子数组3.5 LeetCode1521 找到最接近目标值的函…

【memgpt】letta 课程4:基于latta框架构建MemGpt代理并与之交互

Lab 3: Building Agents with memory 基于latta框架构建MemGpt代理并与之交互理解代理状态,例如作为系统提示符、工具和agent的内存查看和编辑代理存档内存MemGPT 代理是有状态的 agents的设计思路 每个步骤都要定义代理行为 Letta agents persist information over time and…

Python的那些事第九篇:从单继承到多继承的奇妙之旅

Python 继承&#xff1a;从单继承到多继承的奇妙之旅 目录 Python 继承&#xff1a;从单继承到多继承的奇妙之旅 一、引言 二、继承的概念与语法 三、单继承 四、多继承 五、综合代码示例 六、总结 一、引言 在编程的世界里&#xff0c;继承就像是一场神奇的魔法&#…

pandas(三)Series使用

一、Series基础使用 import pandasd {x:100,y:200,z:300} s1 pandas.Series(d) #将dict转化为Series print(s1)print("") l1 [1, 2, 3] l2 [a, b, c] s2 pandas.Series(l1, indexl2) #list转为Series print(s2)print("") s3 pandas.Series([11…

Windows电脑本地部署运行DeepSeek R1大模型(基于Ollama和Chatbox)

文章目录 一、环境准备二、安装Ollama2.1 访问Ollama官方网站2.2 下载适用于Windows的安装包2.3 安装Ollama安装包2.4 指定Ollama安装目录2.5 指定Ollama的大模型的存储目录 三、选择DeepSeek R1模型四、下载并运行DeepSeek R1模型五、使用Chatbox进行交互5.1 下载Chatbox安装包…

如何用微信小程序写春联

​ 生活没有模板,只需心灯一盏。 如果笑能让你释然,那就开怀一笑;如果哭能让你减压,那就让泪水流下来。如果沉默是金,那就不用解释;如果放下能更好地前行,就别再扛着。 一、引入 Vant UI 1、通过 npm 安装 npm i @vant/weapp -S --production​​ 2、修改 app.json …

2025最新在线模型转换工具onnx转换ncnn,mnn,tengine等

文章目录 引言最新网址地点一、模型转换1. 框架转换全景图2. 安全的模型转换3. 网站全景图 二、转换说明三、模型转换流程图四、感谢 引言 在yolov5&#xff0c;yolov8&#xff0c;yolov11等等模型转换的领域中&#xff0c;时间成本常常是开发者头疼的问题。最近发现一个超棒的…

算法每日双题精讲 —— 前缀和(【模板】一维前缀和,【模板】二维前缀和)

在算法竞赛与日常编程中&#xff0c;前缀和是一种极为实用的预处理技巧&#xff0c;能显著提升处理区间和问题的效率。今天&#xff0c;我们就来深入剖析一维前缀和与二维前缀和这两个经典模板。 一、【模板】一维前缀和 题目描述 给定一个长度为 n n n 的整数数组 a a a&…

记8(高级API实现手写数字识别

目录 1、Keras&#xff1a;2、Sequential模型&#xff1a;2.1、建立Sequential模型&#xff1a;modeltf.keras.Sequential()2.2、添加层&#xff1a;model.add(tf.keras.layers.层)2.3、查看摘要&#xff1a;model.summary()2.4、配置训练方法&#xff1a;model.compile(loss,o…

88.[4]攻防世界 web php_rce

之前做过&#xff0c;回顾&#xff08;看了眼之前的wp,跟没做过一样&#xff09; 属于远程命令执行漏洞 在 PHP 里&#xff0c;system()、exec()、shell_exec()、反引号&#xff08;&#xff09;等都可用于执行系统命令。 直接访问index.php没效果 index.php?sindex/think\a…

23.Word:小王-制作公司战略规划文档❗【5】

目录 NO1.2.3.4 NO5.6​ NO7.8.9​ NO10.11​ NO12​ NO13.14 NO1.2.3.4 布局→页面设置对话框→纸张&#xff1a;纸张大小&#xff1a;宽度/高度→页边距&#xff1a;上下左右→版式&#xff1a;页眉页脚→文档网格&#xff1a;勾选只指定行网格✔→ 每页&#xff1a;…

数据结构 树1

目录 前言 一&#xff0c;树的引论 二&#xff0c;二叉树 三&#xff0c;二叉树的详细理解 四&#xff0c;二叉搜索树 五&#xff0c;二分法与二叉搜索树的效率 六&#xff0c;二叉搜索树的实现 七&#xff0c;查找最大值和最小值 指针传递 vs 传引用 为什么指针按值传递不会修…