山东大学软件学院ai导论实验之生成对抗网络

news2025/2/27 10:15:06

目录

实验目的

实验代码

实验内容

实验结果


实验目的

基于Pytorch搭建一个生成对抗网络,使用MNIST数据集。

实验代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os

# 设置环境变量
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 创建保存生成图像的文件夹
output_path = r"xxxxxxxxxxxxxxxxxx"
os.makedirs(output_path, exist_ok=True)


# 生成器网络
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.network(z)
        return img.view(img.size(0), 1, 28, 28)


# 判别器网络
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.network(img.view(img.size(0), -1))


def generate_and_save_images(generator, test_input, epoch, img_path):
    with torch.no_grad():
        generated_images = generator(test_input).cpu().numpy()

    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for i, ax in enumerate(axes.flat):
        # 将图像从形状 (1, 28, 28) 转换为 (28, 28),去除通道维度
        ax.imshow(np.squeeze(generated_images[i]), cmap='gray')
        ax.axis('off')

    img_filename = os.path.join(img_path, f"generated_epoch_{epoch}.png")
    plt.tight_layout()
    plt.savefig(img_filename)
    plt.close()


# 设置设备(使用GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 2000

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./MNIST_data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 测试数据:随机噪声作为输入
test_data = torch.randn(batch_size, latent_dim).to(device)

# 初始化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 记录损失
D_losses = []
G_losses = []

# 训练过程
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # 判别器训练
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # 计算损失
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 生成器训练
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # 记录损失
        D_losses.append(d_loss.item())
        G_losses.append(g_loss.item())

        # 打印每2000个步骤的迭代信息
        if (epoch * len(train_loader) + i) % 2000 == 0:
            print(f"Iter: {epoch * len(train_loader) + i}")
            print(f"D_loss: {d_loss.item():.4f}")
            print(f"G_loss: {g_loss.item():.4f}")
    # 每个epoch保存生成的图像
    generate_and_save_images(generator, test_data, epoch, output_path)

    # 保存生成器和判别器的模型
    torch.save(generator.state_dict(), "Generator_mnist.pth")
    torch.save(discriminator.state_dict(), "Discriminator_mnist.pth")

# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.plot(D_losses, label='Discriminator Loss')
plt.plot(G_losses, label='Generator Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.savefig('loss_curve.png')  # 保存图像
plt.show()  # 显示图像

实验内容

1. 数据集加载

与前几次实验一样,本实验仍然使用MNIST数据集作为输入数据集通过torchvision库进行加载并标准化处理,使得图像像素值在[-1, 1]范围内,以适应生成对抗网络的训练要求。

2. 生成器与判别器网络

生成器:生成器网络的任务是生成伪造的图像,以欺骗判别器。输入是一个随机噪声向量(latent vector),输出是一个28x28像素的图像。生成器使用多个全连接层,每个层后面都跟着一个LeakyReLU激活函数,最终输出通过Tanh激活函数确保生成的图像像素值在[-1, 1]范围内。

判别器:判别器网络的任务是区分输入的图像是“真实的”还是“伪造的”。它将图像输入后,通过多个全连接层,最后输出一个介于0和1之间的值,表示图像的真实性。

3. 训练过程

判别器训练:判别器的目标是最大化其准确性,即正确分类真实和伪造的图像。在每次训练中,先计算真实图像的损失,然后计算生成图像的损失,最后将两个损失加权平均得到判别器的总损失。

生成器训练:生成器的目标是最小化判别器对其生成图像的判断错误率。即通过调整其权重,使得生成的图像越来越像真实图像,以此欺骗判别器。生成器的损失函数是判别器对生成图像的输出,标签为“真实”(即1)。

模型优化:使用Adam优化器分别优化生成器和判别器的参数。学习率为0.0001。

  1. 改变隐藏层数

生成器的结构由原来的4个隐藏层缩减为2个隐藏层:

5.生成图像并保存

在每个epoch结束时,使用生成器生成一些图像,并将图像保存为PNG格式文件。每个epoch的图像被保存到指定的文件夹中,以便可视化生成图像的变化。

6. 绘制损失曲线

训练过程中记录并绘制判别器和生成器的损失曲线,以便观察模型的训练进展。

实验结果

迭代得到的训练结果为:

改变隐藏层数得到的部分结果为:

刚开始生成的初始图像为:

运行一段时间后,得到的图像为:

可以明显的看到,随着迭代不断增加,数字越来越清晰,数字识别成功

损失曲线为:

初始:

慢慢的趋于平稳:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2306836.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

入门网络安全工程师要学习哪些内容【2025年寒假最新学习计划】

🤟 基于入门网络安全/黑客打造的:👉黑客&网络安全入门&进阶学习资源包 大家都知道网络安全行业很火,这个行业因为国家政策趋势正在大力发展,大有可为!但很多人对网络安全工程师还是不了解,不知道网…

【论文解读】《C-Pack: Packed Resources For General Chinese Embeddings》

论文链接:https://arxiv.org/pdf/2309.07597 本论文旨在构建一套通用中文文本嵌入的完整资源包——C-Pack,解决当前中文文本嵌入研究中数据、模型、训练策略与评测基准缺失的问题。论文主要贡献体现在以下几个方面: 大规模训练数据&#xf…

【操作系统、数学】什么是排队论?如何理解排队论?排队论有什么用处?Queueing Theory?什么是 Little’s Law?

排队论(Queueing Theory)是研究系统中排队现象的数学理论,旨在分析资源分配、服务效率及等待时间等问题。它广泛应用于计算机科学、通信网络、交通规划、工业工程等领域。 【下文会通过搜集的资料,从各方面了解排队论&#xff0c…

DeepSeek赋能大模型内容安全,网易易盾AIGC内容风控解决方案三大升级

在近两年由AI引发的生产力革命的背后,一场关乎数字世界秩序的攻防战正在上演:AI生成的深度伪造视频导致企业品牌声誉损失日均超千万,批量生成的侵权内容使版权纠纷量与日俱增,黑灰产利用AI技术持续发起欺诈攻击。 与此同时&#…

(0)阿里云大模型ACP-考试回忆

这两天通过了阿里云大模型ACP考试,由于之前在网上没有找到真题,导致第一次考试没有过,后面又重新学习了一遍文档才顺利通过考试,这两次考试内容感觉考试题目90%内容是覆盖的,后面准备分享一下每一章的考题,…

0.【深度学习YOLOV11项目实战-项目安装教程】(图文教程,超级详细)

目录 前言一、安装Pycharm(安装过Pycharm的跳过这一步)1.1 点击下述链接直接跳转到教程页面进行安装 二、安装Anaconda(安装过Anaconda的跳过这一步)2.1 点击下述链接直接跳转到教程页面进行安装 三、后续安装教程(有N…

Docker 部署 Jenkins持续集成(CI)工具

[TOC](Docker 部署 Jenkins持续集成(CI)工具) 前言 Jenkins 是一个流行的开源自动化工具,广泛应用于持续集成(CI)和持续交付(CD)的环境中。通过 Docker 部署 Jenkins,可以简化安装和配置过程,并…

布署elfk-准备工作

建议申请5台机器部署elfk: filebeat(每台app)--> logstash(2台keepalived)--> elasticsearch(3台)--> kibana(部署es上)采集输出 处理转发 分布式存储 展示 ELK中文社区: 搜索客,搜索人自己的社区 官方…

微软推出Office免费版,限制诸多,只能编辑不能保存到本地

易采游戏网2月25日独家消息:微软宣布推出一款免费的Office版本,允许用户进行基础文档编辑操作,但限制颇多,其中最引人关注的是用户无法将文件保存到本地。这一举措引发了广泛讨论,业界人士对其背后的商业策略和用户体验…

《ArkTS鸿蒙应用开发入门到实战》—新手小白学习鸿蒙的推荐工具书!

《ArkTS鸿蒙应用开发入门到实战》—新手小白学习鸿蒙的推荐工具书! 在科技日新月异的今天,鸿蒙操作系统(HarmonyOS)作为华为推出的全新操作系统,正迅速进入越来越多的智能设备,成为物联网和智能硬件领域的…

DeepSeek 提示词:高效的提示词设计

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…

html中的css

css (cascading style sheets,串联样式表,也叫层叠样式表) css规范一般约定: 1.存放CSS样式文件的目录一般命名为style或css。 2.在项目初期,会把不同类别的样式放于不同的CSS文件,是为了CSS编…

JAVA面试常见题_基础部分_Dubbo面试题(上)

Dubbo 支持哪些协议,每种协议的应用场景,优缺点? • dubbo: 单一长连接和 NIO 异步通讯,适合大并发小数据量的服务调用,以及消费者远大于提供者。传输协议 TCP,异步,Hessian 序列化…

Binder通信协议

目录 一,整体架构 二,Binder通信协议 一,整体架构 二,Binder通信协议

解决应用程序 0xc00000142 错误:完整修复指南

💥 0xc00000142 错误出现的场景 你是不是遇到这样的情况: 🔹 点击某个软件,突然弹出“应用程序无法正确启动(0xc00000142)” ? 🔹 明明安装了所有必要组件,软件却始终打不开? &…

游戏引擎学习第125天

仓库:https://gitee.com/mrxiao_com/2d_game_3 回顾并为今天的内容做准备。 昨天,当我们离开时,工作队列已经完成了基本的功能。这个队列虽然简单,但它能够执行任务,并且我们已经为各种操作编写了测试。字符串也能够正常推送到队…

DeepSeek R1满血+火山引擎详细教程

DeepSeek R1满血火山引擎详细教程 一、安装Cherry Studio。 Cherry Studio AI 是一款强大的多模型 AI 助手,支持 iOS、macOS 和 Windows 平台。可以快速切换多个先进的 LLM 模型,提升工作学习效率。下载地址 https://cherry-ai.com/ 认准官网,无强制注册。 这…

前端依赖nrm镜像管理工具

npm 默认镜像 :https://registry.npmjs.org/ 1、安装 nrm npm install nrm --global2、查看镜像源列表 nrm ls3、测试当前环境下,哪个镜像源速度最快。 nrm test4、 切换镜像源 npm config get registry # 查看当前镜像源 nrm use taobao # 等价于 npm…

ES的简单讲解

功能 : 文档存储 与 文档搜索 特点:比如有一个文档名 “你好” 可以用‘你‘,好,你好都可以搜索到这个文档 ES核心概念 类似于数据库中表的概念,在表的概念下又对数据集合进行了细分 ​ ES_Client查询接口 cpr::R…

进程间通信(一)

1.进程间通信介绍 数组传输:一个进程需要将它的数据发送给另一个进程 资源共享:多个进程之间共享同样的资源 通知事件:一个进程需要向另一个或者一组进程发送信息,通知发送了某种事件(如进程终止时要通知父进程) 进程控制&…