图像生成—使用GANs给出代码示例

news2024/11/16 3:22:27

文章目录

  • 图像生成简单介绍—使用GANs给出代码示例
    • 1. 什么是生成对抗网络(GANs)
    • 2. 准备数据集
    • 3. 构建生成器和判别器
    • 4. 训练GAN模型
    • 5. 生成新图像
    • 6. 总结

图像生成简单介绍—使用GANs给出代码示例

图像生成是指使用计算机算法生成图像的过程。这些图像可以是真实的照片、绘画、3D渲染或者是完全想象的图像。图像生成技术涵盖了一系列算法,包括基于规则的方法、基于统计学的方法、深度学习等。

基于规则的方法通常是通过手动设计规则来生成图像。例如,计算机图形学中的几何建模就是一种基于规则的方法,通过定义几何形状、光照、材质等参数来生成图像。

基于统计学的方法则是通过对大量图像数据进行分析,学习数据中的规律,然后使用这些规律来生成新的图像。这些方法包括基于纹理的方法、基于样式的方法等。

深度学习方法则是最近几年兴起的一种生成图像的方法,它利用神经网络模型进行训练,以学习输入图像和输出图像之间的映射关系。这些模型包括生成对抗网络(GAN)、变分自编码器(VAE)等,能够生成高质量、逼真的图像。

图像生成技术在许多领域都有应用,例如计算机游戏、电影制作、虚拟现实、视觉特效等。同时,它也在艺术创作、产品设计、医学图像处理等领域得到广泛应用。

图像生成是一种涉及生成新图像样本的技术,通常基于深度学习模型。在这份教程中,我们将介绍如何使用生成对抗网络(GANs)生成图像。

1. 什么是生成对抗网络(GANs)

生成对抗网络(GANs)是一种深度学习技术,由两个独立的神经网络组成:生成器(Generator)和判别器(Discriminator)。生成器的任务是生成与真实图像类似的图像,而判别器的任务是区分生成的图像是否为真实图像。这两个网络相互竞争,生成器试图生成越来越真实的图像,而判别器试图越来越准确地识别生成的图像。

2. 准备数据集

首先,我们需要一个用于训练的图像数据集。这里,我们以CIFAR-10数据集为例进行说明。CIFAR-10包含10个类别的60000张32x32彩色图像。我们将使用PyTorch框架,首先需要安装并导入相应的库:

!pip install torch torchvision

import torch
import torchvision
import torchvision.transforms as transforms

接下来,加载和预处理数据:

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

3. 构建生成器和判别器

接下来,我们需要构建生成器和判别器网络。这里,我们使用卷积层和反卷积层构建网络。生成器的输入是随机噪声,输出是生成的图像;判别器的输入是图像,输出是它判断图像是否为真实图像的概率。

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.main(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main(x)
        return x

generator = Generator()
discriminator = Discriminator()

4. 训练GAN模型

为了训练模型,我们需要定义损失函数和优化器。这里我们使用二元交叉熵损失(Binary Cross Entropy Loss)和Adam优化器。

criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

现在我们可以开始训练模型。在每个训练循环中,我们首先训练判别器,然后训练生成器。

save_every = 10  # 保存模型的频率,每训练十次保存一次
start_time = time.time()  # 获取开始时间

log_file = open("training_log.txt", "a")

num_epochs = 100
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(trainloader):
        # Train discriminator
        discriminator.zero_grad()
        real_images = images.to(device)
        batch_size = real_images.size(0)
        label = torch.full((batch_size,), 1, device=device, dtype=torch.float)
        output = discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_images = generator(noise).to(device)
        label.fill_(0.0)
        output = discriminator(fake_images.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizer_d.step()

        # Train generator
        generator.zero_grad()
        label.fill_(1)
        output = discriminator(fake_images).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizer_g.step()

    # 保存模型
    if (epoch+1) % save_every == 0:
        generator_name = f'generator_{epoch+1}.pth'
        discriminator_name = f'discriminator_{epoch+1}.pth'
        torch.save(generator, generator_name)
        torch.save(discriminator, discriminator_name)


    # 将输出打印到终端并保存到log文件中
    log_str = f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} Time: {time.time()-start_time:.2f}s\n'
    print(log_str)
    log_file.write(log_str)
    log_file.flush()  # 刷新缓冲区

log_file.close()

代码开始运行

在这里插入图片描述

5. 生成新图像

训练完成后,我们可以使用生成器来生成新的图像。

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.main(x)
        return x

# generator = Generator()
# 加载保存的模型文件,'model/generator_82.pth'填写自己的生成模型文件路径
generator = torch.load('model/generator_82.pth', map_location=device)
generator.to(device)

with torch.no_grad():
   noise = torch.randn(1, 100, 1, 1, device=device)
   generated_image = generator(noise)


def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    np_img = img.cpu().numpy()
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()

imshow(torchvision.utils.make_grid(generated_image.cpu()))

通过上述代码通过加载生成器模型,可以生成图片,这个模型训练的次数一般越多越好,我训练的82次,也就图一乐,虽说图片啥也不是,但是勉强有图片轮廓。
在这里插入图片描述

如果不调用生成模型,将上述代码修改为如下

generator = Generator()

# 加载保存的模型文件,'model/generator_82.pth'填写自己的生成模型文件路径
# generator = torch.load('model/generator_82.pth', map_location=device)
generator.to(device)

重新运行代码,可得到下面随机噪声的图片,说明我们生成器模型是有点作用的,刚才的图片并不是随机噪声,随机噪声是下面这种图片。

在这里插入图片描述

6. 总结

在本教程中,我们介绍了如何使用生成对抗网络(GANs)生成图像。我们以 CIFAR-10 数据集为例,构建了生成器和判别器网络,并进行了训练。最后,我们使用训练好的生成器生成了新的图像。

GANs 是一种非常强大的图像生成技术,但训练过程可能具有挑战性。为了获得高质量的生成图像,可能需要调整网络结构、损失函数和训练参数。此外,还有许多 GANs 的变体可供尝试,如 Deep Convolutional GANs(DCGANs)、Wasserstein GANs(WGANs)等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/671216.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ | 多线程使用vector

多线程使用vector 文章目录 多线程使用vector场景描述原因分析解决代码测试不扩容和提前扩容 size 与 capacity 变化欢迎关注公众号【三戒纪元】 场景描述 最近在看代码优化&#xff0c;看到有这样的代码&#xff1a; std::vector<int> valid_indices;void SimbaSegmen…

APP测试面试题快问快答(五)

21. App自动化你用的什么工具&#xff1f; 框架&#xff1a;Appium 编译环境和工具&#xff1a;python3.7和PyCharm 环境&#xff1a;Android sdk 第三方模拟器&#xff1a;夜神、蓝叠等模拟器 定位工具&#xff1a;uiautomatorviewer 实时日志查看&#xff1a;ddms 22.…

Tapdata 重磅更新已就绪!全托管云服务上线,应用场景再扩展

继 5 月举办的 「连接 1 次孤岛&#xff0c;服务 N 个场景」主题产品发布会后&#xff0c;Tapdata Live Data Platform 现已实现功能特性的全面升级&#xff0c;并基于自身产品能力积极探索在应用场景层面的落地实践及无限可能。 在去年 6 月的 Tapdata 2.0 发布会上&#xff…

Threejs实现数字人3D粽子

个人主页&#xff1a; 左本Web3D&#xff0c;更多案例预览请点击》 在线案例 个人简介&#xff1a;专注Web3D使用ThreeJS实现3D效果技巧和学习案例 &#x1f495; &#x1f495;积跬步以至千里&#xff0c;致敬每个爱学习的你。喜欢的话请三连&#xff0c;有问题请私信或者加微…

美国访问学者的父母如何申请探亲签证?

对于美国访问学者的父母来说&#xff0c;申请探亲签证是能够让他们在美国与子女团聚的重要途径。下面是知识人网小编整理的一些关于如何申请探亲签证的基本步骤和要点&#xff0c;希望对您有所帮助。 第一步&#xff1a;了解签证类型 在开始申请探亲签证之前&#xff0c;父母需…

【增值税发票识别 OCR】如何实现自动化发票管理

导言 在现代商业环境中&#xff0c;管理和处理大量的增值税发票数据是一项繁琐而重要的任务。传统的手动处理方法既费时又容易出错&#xff0c;而使用增值税发票识别OCR API可以实现自动化的发票管理&#xff0c;大大减少人工处理的工作量。本文将介绍如何利用增值税发票识别O…

Hadoop --- HDFS介绍

HDFS 全称是Hadoop Distributed File System hadoop分布式&#xff08;cluser&#xff09;文件存储系统。适合一次写入&#xff0c;多次读出的场景。 HDFS不需要单独安装&#xff0c;安装Hadoop的时候带了HDFS系统。 Hadoop安装可以参考&#xff1a; 有基础的&#xff0c;已…

轻松了解OPC:实时数据通信领域的必备神器!

OPC简介 OPC&#xff08;OLE for Process Control&#xff0c;进程控制对象连接&#xff09;是一种在工业自动化领域中被广泛使用的技术&#xff0c;它允许不同厂商的自动化设备之间进行通信和数据交换。 OPC技术最早是由美国的软件公司OPC Foundation推出的&#xff0c;它通…

【jsDelivr】jsDelivr - 一个免费、快速、可靠的为JS和开源项目服务的CDN

文章目录 jsDelivr 简介jsDelivr 工作原理 jsDelivr加速域名如下cdn.jsdelivr.net 2023/06/21 域名解析结果fastly.jsdelivr.net 2023/06/21 域名解析结果gcore.jsdelivr.net 2023/06/21 域名解析结果test1.jsdelivr.net 2023/06/21 域名解析结果 仓库拓展 - 其他CDNnpmESMGitH…

nginx nginx-module-vts 监控模块

nginx nginx-module-vts 监控模块 大纲 nginx-module-vts 安装nginx-module-vts 配置监控字段总结配置参数总结vhost_traffic_status_filter_by_host 使用vhost_traffic_status_filter_by_set_key 使用 nginx-module-vts 安装 nginx-module-vts 可以实现对nginx 各个虚拟主…

接口测试是什么?如何测试?

扫盲内容&#xff1a; 1.什么是接口&#xff1f; 2.接口都有哪些类型&#xff1f; 3.接口的本质是什么&#xff1f; 4.什么是接口测试&#xff1f; 5.问什么要做接口测试&#xff1f; 6.怎样做接口测试&#xff1f; 7.接口测测试点是什么&#xff1f; 8.接口测试都要掌…

在Linux系统实现服务器端和客户端的多线程并发通信

先导知识&#xff1a; 在Linux系统实现服务器端和客户端的套接字通信_小梁今天敲代码了吗的博客-CSDN博客 线程同步&#xff08;一&#xff09;_小梁今天敲代码了吗的博客-CSDN博客 线程同步&#xff08;二&#xff09;_小梁今天敲代码了吗的博客-CSDN博客 线程同步&#x…

通付盾入围《2023年度中国数字安全能力图谱(行业版)》

近日&#xff0c;数世咨询发布《2023年度中国数字安全能力图谱&#xff08;行业版&#xff09;》。通付盾作为以分布式数字身份和大数据决策智能技术为核心的数字化高端软件与服务提供商&#xff0c;凭借在数字安全领域的实力和影响力&#xff0c;入选政府、互联网两大行业细分…

实战react+ts+antd遇见的问题之自定义树形结构

目录 自定义编辑树搜索树形结构搜索算法原理 实时更改数据界面不随之发生变化 自定义编辑树 需求要求在每个节点的后面加上新增&#xff0c;编辑&#xff0c;删除按钮&#xff0c;并且能够点击编辑title的显示变成input输入框&#xff0c;antd的案例中没有这种情况&#xff0c…

逍遥自在学C语言 | 指针函数与函数指针

前言 在C语言中&#xff0c;指针函数和函数指针是强大且常用的工具。它们允许我们以更灵活的方式处理函数和数据&#xff0c;进而扩展程序的功能。 本文将介绍指针函数和函数指针的概念&#xff0c;并讲解一些常见的应用示例。 一、人物简介 第一位闪亮登场&#xff0c;有请…

金士顿U盘无法识别的修复软件,方便好用

一、PD V1.16 先打开“PDx16.exe”这个软件&#xff0c;插入U盘。就会在“DEVICE 1”那里检测到U盘&#xff08;如果没有&#xff0c;就用另外的软件&#xff09;。然后按“全部开始”。当完成好&#xff0c;再重新插入U盘。 二、2090&2090E_V1.6.9_普通版070628 1、插入…

工业机器人运动学与Matlab正逆解算法学习笔记(用心总结一文全会)(二)

文章目录 机器人逆运动学※ 代数解、几何解&#xff0c;解析解&#xff08;封闭解&#xff09;、数值解的含义与联系○ 代数解求 θ 1 \theta_1 θ1​、 θ 2 \theta_2 θ2​、 θ 3 \theta_3 θ3​※参考资料 求解 θ 1 \theta_1 θ1​ 求解 θ 3 \theta_3 θ3​ 求解 θ 2 \t…

JUC高级-0620

8. CAS 原子类&#xff1a;Atomic没有CAS之前&#xff1a;多线程环境不使用原子类保证线程安全i&#xff08;基本数据类型&#xff09;&#xff0c;可以使用synchronized&#xff0c;但是很重有CAS之后&#xff1a; 使用AtomicInteger.getAndIncrement这样的API&#xff0c;保…

ARM的半主机模式(Semihosting)

本文介绍ARM的半主机模式&#xff0c;并介绍在MCU进行调试时其他的调试方法和手段。 1.ARM半主机模式(Semihosting) ARM Semihosting是ARM平台的一个独特功能&#xff0c;它允许使用主机上的输入和输出函数&#xff0c;通过硬件调试器转发到微控制器&#xff0c;通过挂接到I/…

网络解析----faster rcnn

Faster R-CNN&#xff08;Region-based Convolutional Neural Network&#xff09;是一种基于区域的卷积神经网络用于目标检测任务的模型。它是一种两阶段的目标检测方法&#xff0c;主要包含以下几个步骤&#xff1a; Region Proposal Network&#xff08;RPN&#xff09;: F…