进行生成简单数字图片

news2024/11/25 19:38:55

1.之前只能做一些图像预测,我有个大胆的想法,如果神经网络正向就是预测图片的类别,如果我只有一个类别那就可以进行生成图片,专业术语叫做gan对抗网络
在这里插入图片描述
2.训练代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import os

# 设置环境变量
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self, input_dim=100, output_dim=784):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.tanh(self.fc4(x))
        return x

# 定义判别器模型
class Discriminator(nn.Module):
    def __init__(self, input_dim=784, output_dim=1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x

# 加载 MNIST 手写数字图片数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataroot = "path_to_your_mnist_dataset"  # 替换为 MNIST 数据集的路径
dataset = dset.MNIST(root=dataroot, train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

# 创建生成器和判别器实例
input_dim = 100
output_dim = 784
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)

# 定义优化器和损失函数
lr = 0.0002
beta1 = 0.5
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
criterion = nn.BCELoss()

# 训练 GAN 模型
num_epochs = 50
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)
generator.to(device)
discriminator.to(device)
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        real_images, _ = data
        real_images = real_images.to(device)
        batch_size = real_images.size(0)  # 获取批次样本数量

        # 训练判别器
        optimizer_d.zero_grad()
        real_labels = torch.full((batch_size, 1), 1.0, device=device)
        fake_labels = torch.full((batch_size, 1), 0.0, device=device)
        noise = torch.randn(batch_size, input_dim, device=device)
        fake_images = generator(noise)
        real_outputs = discriminator(real_images.view(batch_size, -1))
        fake_outputs = discriminator(fake_images.detach())
        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # 训练生成器
        optimizer_g.zero_grad()
        noise = torch.randn(batch_size, input_dim, device=device)
        fake_images = generator(noise)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

        # 输出训练信息
        if i % 100 == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                  % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

    # 保存生成器的权重和图片示例
    if epoch % 10 == 0:
        with torch.no_grad():
            noise = torch.randn(64, input_dim, device=device)
            fake_images = generator(noise).view(64, 1, 28, 28).cpu().numpy()
            fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(12, 12), sharex=True, sharey=True)
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(fake_images[i][0], cmap='gray')
                ax.axis('off')
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            plt.savefig("epoch_%d.png" % epoch)
            plt.close()
        torch.save(generator.state_dict(), "generator_epoch_%d.pth" % epoch)

3.测试模型的代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = torch.tanh(self.fc4(x))
        return x

# 创建生成器模型
generator = Generator(input_dim=100, output_dim=784)

# 加载预训练权重
generator_weights = torch.load("generator_epoch_40.pth", map_location=torch.device('cpu'))

# 将权重加载到生成器模型
generator.load_state_dict(generator_weights)

# 生成随机噪声
noise = torch.randn(1, 100)

# 生成图像
fake_image = generator(noise).view(1, 1, 28, 28)

# 保存生成的图片
save_image(fake_image, "generated_image.png", normalize=False)

#测试结果,由于我的训练集是数字的,所以会生成各种各样的数字,下面明显的是1
在这里插入图片描述
#应该也是1
在这里插入图片描述

#再次运行,我也看不出来,不过只要我训练只有一个种类的问题就可以生成这个种类的图像
在这里插入图片描述
#搞定黑白图,那彩色图应该距离不远了,我需要改进的是把对抗网络的代码改为训练一个种类的图形,不过我感觉这种图形具有随机性,虽然通过训练我们得到了所有图像他们的规律,但是如果需要正常点的图片还是挺难的,就像是上面这张人都不一定知道他是什么东西(在没有颜色的情况下)总结就是精度不够,而且随机性太强了,现在普遍图片AI生成工具具有这个缺点(生成的物体可能会扭曲,挺阴间的),而且生成的图片速度慢,如果谁比较受益那一定是老黄(英伟达)哈哈哈
//比如下面这个图片生成视频的网站
https://app.runwayml.com/login

#每一帧看起来都没有问题,就是连起来变成视频不自然,如果有改进方法的话那可能需要引入重力/加速度/光处理 等等物理公式,来让图片更自然…
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1292240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

排序分析(Ordination analysis)及R实现

在生态学、统计学和生物学等领域,排序分析是一种用于探索和展示数据结构的多元统计技术。这种分析方法通过将多维数据集中的样本或变量映射到低维空间,以便更容易理解和可视化数据之间的关系。排序分析常用于研究物种组成、生态系统结构等生态学和生物学…

java--枚举

1.枚举 枚举是一种特殊类 2.枚举类的格式 注意: ①枚举类中的第一行,只能写一些合法的标识符(名称),多个名称用逗号隔开。 ②这些名称,本质是常量,每个常量都会记住枚举类的一个对象。 3.枚举类的特点 ①枚举类的…

Java 第21章 网络通信

网络程序设计基础 网络程序设计编写的是与其他计算机进行通信的程序。Java 已经将网络程序所需要的元素封装成不同的类,用户只要创建这些类的对象,使用相应的方法,即使不具备有关的网络支持,也可以编写出高质量的网络通信程序。 …

JPA(Java Persistence API)是什么

JPA的官网地址:https://jcp.org/en/jsr/detail?id338 当前最新的版本是2.2版本:https://jcp.org/aboutJava/communityprocess/mrel/jsr338/index.html JPA是一个Java持久化的API,通过这个API,在Java EE和Java SE 环境中管理持…

鸿蒙开发ServiceAbility基本概念

时间过长,开发者必须在Service里创建新的线程来处理(详见线程间通信),防止造成主线程阻塞,应用程序无响应。 创建Service 介绍如何创建一个Service 创建Service的代码示例如下:查看获取鸿蒙开发 (qq.com)…

HTTP 缓存机制

一、强制缓存 只要浏览器判断缓存没有过期,则直接使用浏览器的本地缓存而无需再请求服务器。 强制缓存是利用下面这两个 HTTP 响应头部(Response Header)字段实现的,它们都用来表示资源在客户端缓存的有效期: Cache…

ChatGPT学习笔记

1 ChatGPT架构图 (ChatGPT_Diagram.svg来自于【OpenA | Introducing ChatGPT】) 2 模型训练 ChatGPT在训练时使用了PPO方法;

pandas空格及网页空格符NBSP替换处理

df3[动作一课程内容]df3[动作一课程内容].str.replace( ,) df3[动作一课程内容]df3[动作一课程内容].str.replace( ,) 截图中代码为python展示代码,由于网页空格符和常规空格符看起来大致相同,但却不能用常规空格替换解决

基于个微机器人的开发

简要描述: 下载消息中的动图 请求URL: http://域名/getMsgEmoji 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization:login接口返回 参数: 参数名必选类型说明…

树莓派4B iio子系统 mpu6050

编写基于iio的mpu6050 遇到的问题,在读取数据时,读出来的数据不能直接拼接成int类型 需要先将其转换成short int,再转换成int 效果如图所示 注:驱动是使用的modprobe加载的 画的思维导图(部分,上传的…

预付费远传水表和传统水表有什么不同?

随着科技的发展,预付费远传水表作为一种新型智能水表,与传统水表相比有着许多不同之处。那么,预付费远传水表和传统水表究竟有什么不同呢? 首先,预付费远传水表具备智能化功能。与传统水表只能记录用水总量不同&#x…

代码随想录算法训练营第五十七天【动态规划part17】 | 647. 回文子串、516.最长回文子序列

647. 回文子串 题目链接 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 求解思路 动规五部曲 1.确定dp数组及其下标含义 布尔类型的dp[i][j]:表示区间范围[i,j] (注意是左闭右闭)的子串是否是回文子串&#…

springboot 整合 Spring Security 中篇(RBAC权限控制)

1.先了解RBAC 是什么 RBAC(Role-Based Access control) ,也就是基于角色的权限分配解决方案 2.数据库读取用户信息和授权信息 1.上篇用户名好授权等信息都是从内存读取实际情况都是从数据库获取; 主要设计两个类 UserDetails和UserDetailsService 看下…

linux高级篇基础理论七(Tomcat)

♥️作者:小刘在C站 ♥️个人主页: 小刘主页 ♥️不能因为人生的道路坎坷,就使自己的身躯变得弯曲;不能因为生活的历程漫长,就使求索的 脚步迟缓。 ♥️学习两年总结出的运维经验,以及思科模拟器全套网络实验教程。专栏:云计算技…

主食罐头哪个牌子好?猫主食罐头品牌盘点

养猫的这几年德罐也买了不少了,很早以前德罐给我的感觉就是,物美价廉,而且质量保障也不错,很美丽。但最近的德罐恕在下高攀不起了。 猫罐头侠登场!养猫这么久了我就把我吃的不错的猫罐头分享一下!别纠结了…

Dockerfile 指令的最佳实践

这些建议旨在帮助您创建一个高效且可维护的Dockerfile。 一、FROM 尽可能使用当前的官方镜像作为镜像的基础。Docker推荐Alpine镜像,因为它受到严格控制,体积小(目前不到6 MB),同时仍然是一个完整的Linux发行版。 FR…

【FPGA】Quartus18.1打包封装网表文件(.qxp)详细教程

当我们在做项目的过程中,编写的底层Verilog代码不想交给甲方时怎么办呢?此时可以将源代码打包封装成网表文件(.qxp)进行加密,并且在工程中进行调用。 Quartus II的.qxp文件为QuartusII Exported Partition,…

马来西亚虾皮选品工具:如何优化您的电商业务

随着电子商务的快速发展,越来越多的商家开始将目光投向在线市场。在马来西亚,虾皮(Shopee)平台成为了一个备受瞩目的电商平台,吸引了大量的商家和消费者。然而,要在这个竞争激烈的市场中脱颖而出并取得成功…

基于JavaSE+JDBC使用控制台操作的简易购物系统【源码+数据库】

1、项目简介 本项目是一套基于JavaSEJDBC使用控制台操作的简易购物系统,主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目可以直接作为bishe使用。 项目都经过严格调试&…

计算n的阶乘-递归与迭代之间的相爱相杀

n的阶乘是指从1连乘到n的结果,通常用符号n!表示。例如,3的阶乘表示为3!,计算过程如下: 3! 3 2 1 6 一般地,n的阶乘可以用递归或迭代的方式计算,公式为: n! n (n-1) (n-2) ... 2 1 …