GAN生成漫画脸

news2025/1/19 23:06:02

最近对对抗生成网络GAN比较感兴趣,相关知识点文章还在编辑中,以下这个是一个练手的小项目~

 (在原模型上做了,为了减少计算量让其好训练一些。)

一、导入工具包

import tensorflow as tf
from tensorflow.keras import layers

import numpy as np
import os
import time
import glob
import matplotlib.pyplot as plt
from IPython.display import clear_output
from IPython import display

1.1 设置GPU

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
gpus 
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

二、导入训练数据

链接: 点这里

fileList = glob.glob('./ani_face/*.jpg')
len(fileList)
41621

2.1 数据可视化 

# 随机显示几张图
for index,i in enumerate(fileList[:3]):
    display.display(display.Image(fileList[index]))

2.2 数据预处理

# 文件名列表
path_ds = tf.data.Dataset.from_tensor_slices(fileList)

# 预处理,归一化,缩放
def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [64, 64])
    image /= 255.0  # normalize to [0,1] range
    image = tf.reshape(image, [1, 64,64,3])
    return image

image_ds = path_ds.map(load_and_preprocess_image)
image_ds
<MapDataset shapes: (1, 64, 64, 3), types: tf.float32>
# 查看一张图片
for x in image_ds:
    plt.axis("off")
    plt.imshow((x.numpy() * 255).astype("int32")[0])
    break

三、网络构建

3.1 D网络

discriminator = keras.Sequential(
    [
        keras.Input(shape=(64, 64, 3)),
        layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ],
    name="discriminator",
)
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 32, 32, 64)        3136      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       131200    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 128)         262272    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 8192)              0         
_________________________________________________________________
dropout (Dropout)            (None, 8192)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 8193      
=================================================================
Total params: 404,801
Trainable params: 404,801
Non-trainable params: 0

3.2 G网络

latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(8 * 8 * 128),
        layers.Reshape((8, 8, 128)),
        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
    ],
    name="generator",
)
generator.summary()

3.3 重写 train_step

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        # 生成噪音
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # 生成的图片
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # 训练判别器,生成的当成0,真实的当成1 
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

3.4 设置回调函数

class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=3, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images *= 255
        generated_images.numpy()
        for i in range(self.num_img):
            img = keras.preprocessing.image.array_to_img(generated_images[i])
            display.display(img)
            img.save("gen_ani/generated_img_%03d_%d.png" % (epoch, i))

四、训练模型

epochs = 100  # In practice, use ~100 epochs

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

gan.fit(
    image_ds, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
)

五、保存模型

#保存模型
gan.generator.save('./data/ani_G_model')

生成模型文件:点这里

六、生成漫画脸

G_model =  tf.keras.models.load_model('./data/ani_G_model/',compile=False)

def randomGenerate():
    noise_seed = tf.random.normal([16, 128])
    predictions = G_model(noise_seed, training=False)
    fig = plt.figure(figsize=(8, 8))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        img = (predictions[i].numpy() * 255 ).astype('int')
        plt.imshow(img )
        plt.axis('off')
    plt.show()
count = 0
while True:
    randomGenerate()
    clear_output(wait=True)
    time.sleep(0.1)
    if count > 100:
        break
    count+=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/48744.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

tinymce富文本编辑器做评论区

今天分享一下tinymce富文本编辑器做评论区的全过程。 文章目录一、介绍1.最终效果2.功能介绍3.主要项目包版本介绍&#xff1a;二、每个功能的实现1.自定义toolbar的功能区①对应的样式以及意义②对应的代码实现【忽略了一切非实现该功能的代码】2.展示、收起评论区①对应的样式…

ctf工具之:mitmproxy实践测试

1、安装居然使用的pip pip install mitmproxy 导入证书&#xff0c;密码为空 2、启用mitmweb pause 直接可以查看方式 搜索里输入login 对于http协议 直接看到了密码原文 3、后台日志方式 录入和回放 mitmdump -w baidu.txt pause 录制结束 mitmdump -nC baidu.txt paus…

如何设计可扩展架构

架构设计复杂度模型 业务复杂度和质量复杂度是正交的 业务复杂度 业务固有的复杂度&#xff0c;主要体现为难以理解、难以扩展&#xff0c;例如服务数量多、业务流程长、业务之间关系复杂 质量复杂度 高性能、高可用、成本、安全等质量属性的要求 架构复杂度应对之道 复杂…

MySQL备份与恢复

目录 一.数据备份的重要性 二.数据库备份的分类 2.1 物理备份 2.2 逻辑备份 2.3 完全备份&#xff08;只适合第一次&#xff09; 三.常见的备份方法 四.MySQL完全备份 4.1 MySQL完全备份优缺点 4.2 数据库完全备份分类 4.2.1 物理冷备份与恢复 五.完全备份 5.1 MySQ…

YOLO家族再度升级——阿里达摩院DAMO-YOLO重磅来袭

最近看到阿里达摩院发表了他们的最新研究成果&#xff0c;在YOLO系列上推出的新的模型DAMO-YOLO&#xff0c;还没有来得及去仔细了解一下&#xff0c;这里只是简单介绍下&#xff0c;后面有时间的话再详细研究下。 官方项目在这里&#xff0c;首页截图如下所示&#xff1a; 目…

ASEMI整流桥UD4KB100,UD4KB100体积,UD4KB100大小

编辑-Z ASEMI整流桥UD4KB100参数&#xff1a; 型号&#xff1a;UD4KB100 最大重复峰值反向电压&#xff08;VRRM&#xff09;&#xff1a;1000V 最大平均正向整流输出电流&#xff08;IF&#xff09;&#xff1a;4A 峰值正向浪涌电流&#xff08;IFSM&#xff09;&#xf…

堆(C语言实现)

文章目录&#xff1a;1.堆的概念2.堆的性质3.堆的结构4.接口实现4.1初始化堆4.2销毁堆4.3打印堆内元素4.4向上调整4.5向堆中插入数据4.6向下调整4.7删除堆顶元素4.8查看堆顶元素4.9统计堆内数据个数4.10判断堆是否为空4.11堆的构建1.堆的概念 如果有一个关键码的集合&#xff0…

【Redis】缓存更新策略

1. 缓存更新策略综述 内存淘汰 不用自己维护&#xff0c;利用 Redis 自己的内存淘汰机制 &#xff08;内存不足时&#xff0c;触发策略&#xff0c;默认开启&#xff0c;可自己配置&#xff09;&#xff0c;其可在一定程度上保持数据一致性 超时剔除 给数据添加 TTL&#x…

【电力运维】浅谈电力通信与泛在电力物联网技术的应用与发展

摘要&#xff1a;随着我国社会经济的快速发展&#xff0c;我国科技实力得到了巨大的提升&#xff0c;当前互联网通信技术在社会中得到了广泛的应用。随着电力通信技术的快速发展与更新&#xff0c;泛在电力物联网建设成为电力通讯发展的重要方向。本文已泛在电力物联网系统为核…

Docker使用

xshell和xftp软件下载 链接&#xff1a;https://pan.baidu.com/s/1G7DIw14UvOmTwU9SwtYILg 提取码&#xff1a;he18 --来自百度网盘超级会员V6的分享 docker相关资料&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1VcxvuJvBIKNKnUUHPlM3MA 提取码&#xff1a;6w5e …

一些常见的项目管理 KPI

本文将介绍一些常见的项目管理kpi&#xff0c;让大家更深刻的了解其作用及所存在的问题。 一、关键绩效指标的作用 在 GPS 和其他现代导航方法出现之前&#xff0c;水手和探险家们只能通过星星找到正确的方向。特别是在北半球&#xff0c;他们利用北极星找出真正的北方方位。…

[附源码]SSM计算机毕业设计医学季节性疾病筛查系统JAVA

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Mysql高频面试题(一)

文章目录1. Mysql如何实现的索引机制&#xff1f;2. InnoDB索引与MyISAM索引实现的区别是什么&#xff1f;3. 一个表中如果没有创建索引&#xff0c;那么还会创建B树吗&#xff1f;4. B树索引实现原理&#xff08;数据结构&#xff09;5. 聚簇索引与非聚簇索引的B树实现有什么区…

Vector源码分析

Vector源码分析 1 Vector基本介绍与类图 Vector 类实现了一个动态数组。和 ArrayList 很相似,但是两者是不同的: Vector 是同步访问的。Vector 包含了许多传统的方法,这些方法不属于集合框架。Vector 主要用在事先不知道数组的大小,或者只是需要一个可以改变大小的数组的…

pytest + yaml 框架 - 1.我们发布上线了!

前言 基于 httprunner 框架的用例结构&#xff0c;我自己开发了一个pytest yaml 的框架&#xff0c;那么是不是重复造轮子呢&#xff1f; 不可否认 httprunner 框架设计非常优秀&#xff0c;但是也有缺点&#xff0c;httprunner3.x的版本虽然也是基于pytest框架设计&#xff…

Spring中JDK与Cglib动态代理的区别

靠Spring吃饭的小伙伴一定经常听说动态代理这个词&#xff0c;没错&#xff0c;Aop就是靠它来实现的。Spring提供了两种代理模式&#xff1a;JDK动态代理、Cglib动态代理&#xff0c;供我们选择&#xff0c;那他们有啥区别呢&#xff1f;Sping为啥不自己从中挑选一个作为代理模…

IB物理的费曼图怎么考?

费曼图是用来描述基本粒子间相互作用的图形化表示&#xff0c;由诺贝尔物理学奖得主、著名物理学家理查德费曼&#xff08;Richard Feynman&#xff09;提出&#xff0c;十分清晰直观。虽然真正的费曼图可以用来做更深奥的数学计算&#xff0c;但是在IB物理中&#xff0c;考纲要…

那些惊艳一时的 CSS 属性

1.position: sticky 不知道大家平时业务开发中有没有碰到像上图一样的吸顶的需求&#xff1a;标题在滚动的时候&#xff0c;会一直贴着最顶上。 这种场景实际上很多&#xff1a;比如表格的标题栏、网站的导航栏、手机通讯录的人名首字母标题等等。如果让大家自己动手做的话&…

flink学习

Flink学习之路&#xff08;一&#xff09;Flink简介 - 走看看 Flink(一)-基本概念 - 知乎 Flink架构&#xff1a; Flink整个系统包含三个部分&#xff1a; 1、Client&#xff1a; 给用户提供向Flink系统提交用户任务&#xff08;流式作业&#xff09;的能力。用户提交一个F…

大型商场借力泛微,实现内外协同招商,合同、铺位、费用统一管理

对即将开业或是面临调整改造的购物中心来说&#xff0c;用什么样的方式才能快速地达成招商目的&#xff0c;实现资产价值的保值和增值&#xff0c;成为商业操盘手们共同面临的难题…… 行业需求 • 建立充足的品牌资源储备&#xff0c;拓宽招商渠道和线索&#xff0c;提高成交…