【GAN 图像生成】

news2024/9/30 17:06:18

理论知识学习:

PART 1:

生成对抗网络GAN 深度学习模型,用于生成数据

对抗式训练,生成器v判别器

DCGAN>WGAN>StyleGAN技术不断进化

GAN在艺术创作。数据增强领域应用越来越广泛

应用:

GAN在图像合成,数据增强,虚拟现实等领域有着广泛的应用。

StyleGAN2能够生成逼真的人脸图像,推动了计算机视觉和图形学的发展。

GAN也被用于生成式医学图像,帮助医生进行更准确的诊断。

PART2:

生成器Generator:生成数据

判别器Discriminator:负责区分数据和生成的数据

两者在训练中互相竞争,生成器努力生成愈来愈真实的数据,判别器不断提高其分辨能力。

损失函数&训练过程:

GAN:训练过程涉及到最小化的一个特定的损失函数,生成器和判别器的组合。

生成器的损失函数:生成的数据被判别器错误分类的概率

判别器的损失函数:正确分类真实和生成数据的概率

网络架构优化:

GAN的网络架构非常复杂,包括卷积神经网络,循环神经网络

网络优化:

训练GAN要仔细选择优化算法和学习率,以避免模式崩溃等问题。

PART3:GAN的高级概念

cGAN;

允许生成过程加入变量条件,是的生成的数据具有特定的属性。

可以生成特定风格的图像或者具有特定特征的人脸。

CycleGAN:循环对抗网络

CycleGAN能够在没有成对训练数据的情况下,实现不同域之间的图像转换。

通过循环一致性来保持转换过程中的原始结构信息。

变分自编码器VAE与GAN

VAE是一种生成模型,它通过编码器和解码器生成数据

GAN与VAE在生成质量和多样性上有所不同,两者可以互相补充。

PART4:

GAN在训练过程中容易出现不稳定,导致生成器和判别器之间的不平衡。

通过改进的优化算法和正则化技术,可以提高训练的稳定性。

问题:

模式崩溃,生成器开始生成非常相似或者重复的数据。

解决方案:

通过引入多样化和正则化和改进的网络架构来解决这样一问题。

PART5实操:MindSpore实现GAN图像生成

操作步骤:

实操

代码:

%%capture captured_output
# 实验环境已经预装了mindspore==2.3.0,如需更换mindspore版本,可更改下面 MINDSPORE_VERSION 变量
!pip uninstall mindspore -y
%env MINDSPORE_VERSION=2.3.0
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple


# 查看当前 mindspore 版本
!pip show mindspore



# 数据下载
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)


import numpy as np
import mindspore.dataset as ds

batch_size = 64
latent_size = 100  # 隐码的长度

train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')

def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False, num_samples=10000)
    # 数据增强
    mnist_ds = dataset1.map(
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])

    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)

    return mnist_ds

mnist_ds = data_load(train_dataset)

iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)



import matplotlib.pyplot as plt

data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
    image = data_iter['image'][idx]
    figure.add_subplot(rows, cols, idx)
    plt.axis("off")
    plt.imshow(image.squeeze(), cmap="gray")
plt.show()



import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype

# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)


from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')



 # 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')


lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')




import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小

# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'

checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径



%%time
# 生成器计算损失过程
def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g

# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    # 计算判别器损失和梯度
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)

    return loss_d, loss_g

# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

# 储存生成器和判别器loss
losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 10:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")

    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))

    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)

    # 根据epoch保存模型权重文件
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))



plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()


import cv2
import matplotlib.animation as animation

# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):
    image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):
    plt.axis("off")
    show_list.append([plt.imshow(image_list[epoch], cmap='gray')])

ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)



import mindspore as ms

test_ckpt = './result/checkpoints/Generator11.ckpt'

parameter = ms.load_checkpoint(test_ckpt)
ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):
    fig.add_subplot(5, 5, i + 1)
    plt.axis("off")
    plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()


from datetime import datetime
import pytz
beijing_tz=pytz.timezone('Asia/Shanghai')
current_beijing_time=datetime.now(beijing_tz)
formatted_time=current_beijing_time.strftime('%Y-%m-%d %H:%M:%S')
print("当前北京时间:",formatted_time,'name')





 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2180187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

生信初学者教程(十二):数据汇总

文章目录 介绍加载R包导入数据汇总表格输出结果总结介绍 在本教程中,汇总了三个肝细胞癌(HCC)的转录组数据集,分别是LIRI-JP,LIHC-US/TCGA-LIHC和GSE14520,以及一个HCC的单细胞数据集GSE149614的临床表型信息。这些数据集为科研人员提供了丰富的基因表达数据和相关的临床…

SqlAlchemy使用教程(七) 异步访问数据库

SqlAlchemy使用教程(一) 原理与环境搭建SqlAlchemy使用教程(二) 入门示例及编程步骤SqlAlchemy使用教程(三) CoreAPI访问与操作数据库详解SqlAlchemy使用教程(四) MetaData 与 SQL Express Language 的使用SqlAlchemy使用教程(五) ORM API 编程入门SqlAlchemy使用教程(六) – O…

Find My储物盒|苹果Find My技术与储物盒结合,智能防丢,全球定位

储物盒是用来存储,收藏东西的器具。储物盒可以帮助用户合理利用有限的空间,通过分类归置物品,避免浪费和混乱。储物盒能够有效地保护存放的物品,防止它们受到灰尘、污渍、损坏和潮湿的影响。储物盒还可以增加空间利用率、方便搬家…

Windows环境下使用Docker配置MySQL数据库

用Docker配置数据库,无论是做开发,还是做生产部署,都非常的方便 它不需要单独安装数据库,也不用担心出现各种环境的配置问题。 本文将分享用Docker配置数据库的步骤,这里用MySQL举例。 其他的数据库如MSSQL&#xf…

全球IP归属地查询-IP地址查询-IP城市查询-IP地址归属地-IP地址解析-IP位置查询-IP地址查询API接口

IP地址城市版查询接口 API是指能够根据IP地址查询其所在城市等地理位置信息的API接口。这类接口在网络安全、数据分析、广告投放等多个领域有广泛应用。以下是一些可用的IP地址城市版查询接口API及其简要介绍 1. 快证 IP归属地查询API 特点:支持IPv4 提供高精版、…

Scalefit:有效避免工作场所运动损伤的解决方案

在当今快节奏的工作环境中,运动损伤已成为一个不容忽视的问题。长时间的久坐、重复性动作以及缺乏适当的运动,都可能导致肌肉骨骼损伤、关节疼痛等问题。作为一款专注于运动健康管理的平台,Scalefit Industrial Athlete通过科学的方法和个性化…

天坑!Spark+Hive+Paimon+Dolphinscheduler

背景: 数据中台项目使用Spark+Hive+Paimon做湖仓底层,调度任务使用的是基于Dolphinscheduler进行二开。在做离线脚本任务开发时,在Paimon库下执行非查询类SQL报错。 INSERT报错 DELETE报错 现状: 原始逻辑为数据中台中选择的Paimon数据源,实际上在Dolphinscheduler中是…

生成靶标图像代码——C语言代码实现

1. 生成左右相机拍摄的3个彩色靶标的图像 两个相机在x轴方向上平移 // 生成左右相机拍摄3个靶标时的图像 生成彩色靶标 #include <stdio.h> #include <stdlib.h> #include <math.h>// 图像尺寸 #define WIDTH 1920 #define HEIGHT 1080// BMP头信息 #pra…

掌握自动化测试必要的几种技能?

1.自动化测试员技能——编程语言 当我开始担任手动测试人员时&#xff0c;我不喜欢编码。但是&#xff0c;当我逐渐进入自动化领域时&#xff0c;对我来说很清楚&#xff0c;如果没有对编程语言的一些基本了解&#xff0c;就无法编写逻辑自动化测试脚本。 对编程有一点了解&a…

短视频矩阵源码oem/矩阵系统搭建/源码开发注意事项知识分享

短视频矩阵系统的源码框架主要涵盖Spring、Struts与Hibernate三种。Spring是一款全栈式Java应用开发框架&#xff0c;集成了IOC容器、AOP以及事务管理等关键功能。Struts则基于MVC架构设计&#xff0c;用于Web应用程序的开发&#xff0c;有效分离数据模型、用户界面及控制器逻辑…

全面指南:探索并实施解决Windows系统中“mfc140u.dll丢失”的解决方法

当你的电脑出现mfc140u.dll丢失的问题是什么情况呢&#xff1f;mfc140u.dll文件依赖了什么&#xff1f;mfc140u.dll丢失会导致电脑出现什么情况&#xff1f;今天这篇文章就和大家聊聊mfc140u.dll丢失的解决办法。希望能够有效的帮助你解决这问题。 哪些程序依赖mfc140u.dll文件…

深圳市软件行业协会领导到访开源网安,共筑大湾区数字经济安全未来

近日&#xff0c;深圳市软件行业协会会长邓爱国、秘书长郑飞等一行人到访开源网安进行参观交流。双方以网信行业技能培训、软件安全开发能力评价和智能网联汽车安全测试等方面为探讨方向&#xff0c;对未来的合作进行了深入交流。 在参观过后&#xff0c;深圳市软件行业协会相关…

查找满足条件的行序号

有 2022 年 1 月的日销售额统计表如下所示&#xff1a; 找出日销售额大于 1000 的日子&#xff1a; spl("E(?1).pselecta(Sales>1000)",A1:B32)pselecta()返回所有满足条件的记录序号&#xff0c;pselect() 则只返回第一个满足条件的行序号 免费课程、免费软件下…

Elasticsearch要点简记

Elasticsearch要点简记 1、ES概述2、基础概念&#xff08;1&#xff09;索引、文档、字段&#xff08;2&#xff09;映射&#xff08;3&#xff09;DSL 3、架构原理4、索引字段的数据类型5、ES的三种分页方式&#xff08;1&#xff09;深度分页&#xff08;fromsize&#xff09…

码随想录算法训练营第60天|卡码网:94. 城市间货物运输 I、95. 城市间货物运输 II、96. 城市间货物运输 III

1.卡码网&#xff1a;94. 城市间货物运输 I 题目链接&#xff1a;https://kamacoder.com/problempage.php?pid1152 文章链接&#xff1a;https://www.programmercarl.com/kamacoder/0094.城市间货物运输I-SPFA.html 思路&#xff1a; 只对 上一次松弛的时候更新过的节点作为出…

智能餐饮:Spring Boot 点餐系统

第四章 系统设计 4.1 系统体系结构 网上点餐系统的结构图4-1所示&#xff1a; 图4-1 系统结构 模块包括主界面&#xff0c;首页、个人中心、用户管理、美食店管理、美食分类管理、美食信息管理、美食订单管理、美食评价管理、系统管理等进行相应的操作。 登录系统结构图&…

2025年3月PMP考试《PMBOK®指南》第六版不再作为参考资料

大家都知道《PMBOK指南》第六版是PMP认证考试的必备教材&#xff0c;由项目管理协会&#xff08;PMI&#xff09;指定。本书详细介绍了项目管理的5个过程组&#xff0c;并对项目管理的10个知识领域进行了阐述。 就在2024.9.30昨天的时候中国国际人才交流基金会公布了&#xff…

崖山数据库的共享集群机制初探

本文作者&#xff1a;YashanDB高级服务工程师周国超 YashanDB共享集群是崖⼭数据库系统&#xff08;YashanDB&#xff09;的⼀个关键特性&#xff0c;它是⼀个单库多实例的多活数据库系统。⽤⼾可以连接到任意实例访问同⼀个数据库&#xff0c;多个数据库实例能够并发读写同⼀…

QT对QBytearray的data()指针进行结构体转换时会自动字节对齐填充

1、测试代码 #include <QCoreApplication>#pragma pack(push, 1) typedef struct {int a;float b;char c;int *d; }testStruct; #pragma pack(pop)#include <QByteArray> #include <QDebug>int main() {testStruct structA;structA.a 1;structA.b 2;struc…

正点原子阿波罗STM32F429IGT6移植zephyr rtos(二)---使用I2C驱动MPU6050

硬件平台&#xff1a;正点原子阿波罗STM32F429IGT6 zephyr版本&#xff1a;Zephyr version 3.7.99 开发环境&#xff1a;ubuntu 24.4 zephyr驱动开发与之前接触到的开发方式可能都不一样&#xff0c;更像是linux驱动开发&#xff0c;zephyr源码里边其实已经有写好的I2C和MPU60…