人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用

news2024/11/15 17:02:07

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用。生成对抗网络(GAN)是一种强大的生成模型,在手写数字生成方面具有广泛的应用前景。通过生成逼真的手写数字图像,GAN可以用于数据增强、图像修复、风格迁移等任务,提高模型的性能和泛化能力。生成对抗网络在手写数字生成领域具有广泛的应用前景。主要应用场景包括数据增强、图像修复、风格迁移和跨领域生成。数据增强可以通过生成逼真的手写数字图像,为训练数据集提供更多的样本,提高模型的泛化能力。

一、项目背景

随着深度学习技术的不断发展,生成模型在计算机视觉、自然语言处理等领域取得了显著的成果。生成对抗网络(GAN)作为一种新兴的生成模型,近年来备受关注。在手写数字生成方面,GAN可以生成逼真的手写数字图像,为数据增强、图像修复等任务提供有力支持。

二、生成对抗网络原理

生成对抗网络(GAN)由Goodfellow等人于2014年提出,它由两个神经网络——生成器(Generator)和判别器(Discriminator)——组成。生成器的目标是生成逼真的假样本,而判别器的目标是区分真实样本和生成器生成的假样本。在训练过程中,生成器和判别器相互竞争,不断调整参数,以达到纳什均衡。
GAN的目标是最小化以下价值函数:
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
其中, G G G表示生成器, D D D表示判别器, x x x表示真实样本, z z z表示生成器的输入噪声, p data p_{\text{data}} pdata表示真实数据分布, p z p_z pz表示噪声分布。
在这里插入图片描述

三、生成对抗网络应用场景

生成对抗网络(GAN)在手写数字生成领域的应用具有广泛的前景。以下是几个主要的应用场景:
1.数据增强:通过生成逼真的手写数字图像,GAN可以为训练数据集提供更多的样本,提高模型的泛化能力。
2. 图像修复:GAN可以用于修复损坏或缺失的手写数字图像,提高图像的质量和可读性。
3. 风格迁移:GAN可以将一种手写风格转换为另一种风格,为个性化手写数字生成提供可能。
4. 跨领域生成:GAN可以实现不同手写数字数据集之间的转换,为多任务学习提供支持。

四、生成对抗网络实现手写数字生成

下面我将利用pytorch深度学习框架构建生成对抗网络的生成器模型Generator、判别器模型Discriminator。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# 超参数设置
batch_size = 128
learning_rate = 0.0002
num_epochs = 80

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 下载并加载训练数据
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x).view(x.size(0), 1, 28, 28)

# 定义判别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.model(x)

# 初始化模型
generator = Generator()
discriminator = Discriminator()

# 损失函数和优化器
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate)
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # 确保标签的大小与当前批次的数据大小一致
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # 训练判别器
        optimizerD.zero_grad()
        real_outputs = discriminator(images)
        d_loss_real = criterion(real_outputs, real_labels)
        z = torch.randn(images.size(0), 100)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizerD.step()

        # 训练生成器
        optimizerG.zero_grad()
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizerG.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

    # 保存生成器生成的图片
    save_image(fake_images.data[:25], './fake_images/fake_images-{}.png'.format(epoch+1), nrow=5, normalize=True)

# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

最后我们打开fake_images/文件夹,可以看到生成手写图片的过程:
在这里插入图片描述

五、总结

本项目利用生成对抗网络(GAN)实现了手写数字的生成。通过训练生成器和判别器,我们成功生成了逼真的手写数字图像。这些生成的图像可以应用于数据增强、图像修复、风格迁移等领域,为手写数字识别等相关任务提供有力支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1421579.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Adobe Camera Raw forMac/win:掌控原始之美的秘密武器

Adobe Camera Raw,这款由Adobe开发的插件,已经成为摄影师和设计师们的必备工具。对于那些追求完美、渴望探索更多创意可能性的专业人士来说,它不仅仅是一个插件,更是一个能够释放无尽创造力的平台。 在数字摄影时代,R…

数据结构——栈和队列(C语言)

栈种常见的数据结构,它用来解决一些数据类型的问题,那么好,我来带着大家来学习一下栈 文章目录 栈对栈的认识栈的模拟实现栈的练习方法一方法二 栈 对栈的认识 栈(stack)是限定只能在表的一端进行插入删除操作的线性…

苍穹外卖项目可以写的简历和如何优化简历

文章目录 重点写中规写添加自己个性的项目面试会问道的问题 我是一名双非大二计算机本科生,希望我的分享对你有帮助,点赞关注不迷路。 简历编写一直是很多人求职人的心病,我自己上学期有一门课程是去校内企业面试,当时我就感受出…

性能脚本设计

性能脚本设计 目标 - 性能脚本设计技巧 1. 为什么要设计性能脚本? 1.1 需求 100虚拟用户对(查询学院-所有)接口测试,以每秒启动10个用户,统计服务器平均响应时间和错误率1.2 问题 100虚拟用户请求服务器的时候,如何统计服务器响应时间和…

Unity_Timeline使用说明

Unity_Timeline使用说明 首先要找到工具吧?Unity2023.1.19f1c1打开如下: (团结引擎没找见哪儿打开,可能是引擎问题吧?有知道的同学可以告诉我在哪儿打开) Timelime使用流程: 打开之后会提示您…

18.通过telepresence调试部署在Kubernetes上的微服务

Telepresence简介 在微服务架构中,本地开发和调试往往是一项具有挑战性的任务。Telepresence 是一种强大的工具,使得开发者本地机器上开发微服务时能够与运行在 Kubernetes 集群中的其他服务无缝交互。本文将深入探讨 Telepresence 的架构、运行原理,并通过实际的案例演示其…

Python XPath解析html出现⋆解决方法 html出现#123;解决方法

前言 爬网页又遇到一个坑,老是出现乱码,查看html出现的是&#数字;这样的。 网上相关的“Python字符中出现&#的解决办法”又没有很好的解决,自己继续冲浪,费了一番功夫解决了。 这算是又加深了一下我对这些iso、Unicode编…

HarmonyOS使用Web组件加载页面

1、加载网络页面 在Web组件创建时,指定默认加载的网络页面 。在默认页面加载完成后,如果开发者需要变更此Web组件显示的网络页面,可以通过调用loadUrl()接口加载指定的网页。 默认在Web组件加载完“www.baidu.com”页面后,点击按…

flask_django基于python的城市轨道交通公交线路查询系统vue

同时,随着信息社会的快速发展,城市轨道交通线路查询系统面临着越来越多的信息,因此很难获得他们对高效信息的需求,如何使用方便快捷的方式使查询者在广阔的海洋信息中查询,存储,管理和共享信息方面有效&…

python fastapi swagger 连接超时

问题描述 运行python项目时,访问fastapi swagger出现连接超时。 https://cdn.jsdelivr.net/npm/swagger-ui-dist4/swagger-ui.css https://cdn.jsdelivr.net/npm/swagger-ui-dist4/swagger-ui-bundle.js 解决方案 第一步 下载文件 https://pan.baidu.com/s/1Ef…

微信小程序如何实现点击上传图片功能

如下所示,实际需求中常常存在需要点击上传图片的功能,上传前显示边框表面图片显示大小,上传后将图形缩放到边框大小。 实现如下: .wxml <view class="{{img_src==?blank-area:}}" style="width:100%;height:40%;display:flex;align-items: center;jus…

K8S之Pod的介绍和使用

Pod的理论和实操 pod理论说明Pod介绍Pod运行与管理Pod管理多个容器Pod网络Pod存储 Pod工作方式自主式Pod控制器管理的Pod&#xff08;常用&#xff09; 创建pod的流程 pod实操通过资源清单文件创建自主式pod通过kubectl run创建Pod&#xff08;不常用&#xff09; pod理论说明 …

springboot整合mqtt实现消息订阅和推送

前言 mica-mqtt-client-spring-boot-starter是一个基于Spring Boot的MQTT客户端启动器&#xff0c;它集成了mica-mqtt客户端&#xff0c;提供了在Spring Boot应用程序中使用MQTT协议进行消息通信的能力。以下是关于mica-mqtt-client-spring-boot-starter的简介&#xff1a; 特…

Spring-boot项目+Rancher6.3部署+Nacos配置中心+Rureka注册中心+Harbor镜像仓库+NFS存储

目录 一、项目概述二、环境三、部署流程3.1 Harbor部署3.1.1 docker安装3.1.2 docker-compose安装3.1.3 安装证书3.1.4 Harbor下载配置安装 3.2 NFS存储搭建3.3 Rancher平台配置3.3.1 NFS存储相关配置3.3.2 Harbor相关配置3.3.3 Nacos部署及相关配置3.3.4 工作负载deployment配…

在 Android 中使用 C/C++:初学者综合指南

在 Android 中使用 C/C&#xff1a;初学者综合指南 一、为什么有人在他们的 Android 项目中需要 C/C 支持&#xff1f;二、了解 C 如何集成到 Android 应用程序中三、C和Java程序的编译3.1 Java3.2 Android ART 和 DEX 字节码 四、使用 JNI 包装 C 源代码五、CMake和Android ND…

326. Power of Three(3 的幂)

题目描述 给定一个整数&#xff0c;写一个函数来判断它是否是 3 的幂次方。如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false。 整数 n 是 3 的幂次方需满足&#xff1a;存在整数 x 使得 n 3 x n 3^x n3x 问题分析 要证明一个整数是三的幂次方&#…

【计算机网络】【练习题及解答】【新加坡南洋理工大学】【Computer Control Network】

说明&#xff1a; 仅供学习使用。 一、题目描述 题目共4问&#xff0c;描述网络通信中的 帧传输时延&#xff08;Frame Delay&#xff09;、传播时延&#xff08;Propagation Delay&#xff09;&#xff0c;以及 链接利用率&#xff08;Link Utilization&#xff09; 的相关…

《游戏-03_2D-开发》

基于《游戏-02_2D-开发》&#xff0c; 继续制作游戏&#xff1a; 首先要做的时切割人物Idle空闲状态下的动画&#xff0c; 在切割之前我们需要创建一个文件夹&#xff0c;用来存放动画控制器AnimatorContoller&#xff0c; 再创建一个人物控制器文件夹用来存放人物控制器&…

当前的人工智能忽略了人类最具有灵性的心理部分

在人工智能的发展中&#xff0c;目前人工智能的侧重点主要是在物理机理与数理符号计算方面。 物理机理是指人工智能系统对现实世界的感知和交互能力。例如&#xff0c;通过传感器和摄像头等设备获取环境信息&#xff0c;以及利用机器学习和深度学习等技术进行数据分析和模式识别…

pve宿主机更改网络导致没网,pve更改ip

一、问题描述 快过年了&#xff0c;我把那台一直在用的小型服务器&#xff0c;带回去了&#xff0c;导致网络发生了变更&#xff0c;需要对网络进行调整&#xff0c;否则连不上网&#xff0c;我这里改的是宿主机&#xff0c;不是pve虚拟机中的系统。 二、解决方法 pve用的是…