李宏毅机器学习作业6-使用GAN生成动漫人物脸

news2024/9/22 9:29:18

理论部分参考:​李宏毅机器学习——对抗生成网络(GAN)_iwill323的博客-CSDN博客

目录

任务和数据集

评价方法

FID

AFD (Anime face detection) rate

DCGAN和WGAN

代码

导包

建立数据集

显示一些图片

模型设置

生成器

判别器

权重初始化

训练函数

训练

读取数据

Set config

推断

GAN效果


任务和数据集

1. Input: 随机数,输入的维度是(batch size, 特征数)
2. Output: 动漫人物脸
3. Implementation requirement: DCGAN & WGAN & WGAN-GP
4. Target:产生1000动漫人物脸

数据来自Crypko网站,有71,314个图像

评价方法

FID

将真假图片送入另一个模型,产生对应的特征,计算真假特征的距离

AFD (Anime face detection) rate

1. To detect how many anime faces in your submission
2. The higher, the better


DCGAN和WGAN

 

代码

导包

# import module
import os
import glob
import random
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch import autograd
from torch.autograd import Variable

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import logging
from tqdm import tqdm
from d2l import torch as d2l


# seed setting
def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(2022)
workspace_dir = 'data/faces'

建立数据集

# prepare for CrypkoDataset

class CrypkoDataset(Dataset):
    def __init__(self, fnames, transform):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)

    def __getitem__(self,idx):
        fname = self.fnames[idx]
        img = Image.open(fname)
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples

def get_dataset(root):
    # glob.glob返回匹配给定通配符的文件列表
    fnames = glob.glob(os.path.join(root, '*')) # list
    transform = transforms.Compose([        
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = CrypkoDataset(fnames, transform)
    return dataset

显示一些图片

temp_dataset = get_dataset(os.path.join(workspace_dir, 'faces'))

images = [temp_dataset[i] for i in range(4)]
grid_img = torchvision.utils.make_grid(images, nrow=4)
plt.figure(figsize=(10,10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

模型设置

生成器

# Generator

class Generator(nn.Module):
    """
    Input shape: (batch, in_dim)
    Output shape: (batch, 3, 64, 64)
    """
    def __init__(self, in_dim, feature_dim=64):
        super().__init__()
    
        #input: (batch, 100)
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, feature_dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(feature_dim * 8 * 4 * 4),
            nn.ReLU()
        )
        self.l2 = nn.Sequential(
            self.dconv_bn_relu(feature_dim * 8, feature_dim * 4),               #(batch, feature_dim * 16, 8, 8)     
            self.dconv_bn_relu(feature_dim * 4, feature_dim * 2),               #(batch, feature_dim * 16, 16, 16)     
            self.dconv_bn_relu(feature_dim * 2, feature_dim),                   #(batch, feature_dim * 16, 32, 32)     
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 3, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),
            nn.Tanh()   
        )
        self.apply(weights_init)
    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),        #double height and width
            nn.BatchNorm2d(out_dim),
            nn.ReLU(True)
        )
    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2(y)
        y = self.l3(y)
        return y

判别器

# Discriminator
class Discriminator(nn.Module):
    """
    Input shape: (batch, 3, 64, 64)
    Output shape: (batch)
    """
    def __init__(self, model_type, in_dim, feature_dim=64):
        super(Discriminator, self).__init__()
            
        #input: (batch, 3, 64, 64)
        """
        Remove last sigmoid layer for WGAN
        """
        
        self.model_type = model_type
        
        self.l1 = nn.Sequential(
            nn.Conv2d(in_dim, feature_dim, kernel_size=4, stride=2, padding=1), #(batch, 3, 32, 32)
            nn.LeakyReLU(0.2),
            self.conv_bn_lrelu(feature_dim, feature_dim * 2),                   #(batch, 3, 16, 16)
            self.conv_bn_lrelu(feature_dim * 2, feature_dim * 4),               #(batch, 3, 8, 8)
            self.conv_bn_lrelu(feature_dim * 4, feature_dim * 8),               #(batch, 3, 4, 4)
            nn.Conv2d(feature_dim * 8, 1, kernel_size=4, stride=1, padding=0)            
        )        
        
        # WGAN的思路是将discriminator训练为距离函数,所以discriminator不需要最后的非线性sigmoid层
        if self.model_type == 'GAN':
            self.l1.add_module(
                'sigmoid', nn.Sigmoid() 
            )
        
        self.apply(weights_init)
        
    def conv_bn_lrelu(self, in_dim, out_dim):
        layer = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 4, 2, 1),
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(0.2),
        )
        
        if self.model_type == 'WGAN-GP':
            layer[1] = nn.InstanceNorm2d(out_dim)
        
        return layer
    
    def forward(self, x):
        y = self.l1(x)
        y = y.view(-1)
        return y

权重初始化

# setting for weight init function
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

训练函数

  1. prepare_environment: prepare the overall environment, construct the models, create directory for the log and ckpt
  2. train: train for generator and discriminator, you can try to modify the code here to construct WGAN or WGAN-GP
  3. inference: after training, you can pass the generator ckpt path into it and the function will save the result for you

WGAN-GP部分运行不来,一个很大的问题是网上找到的代码,训练到10个epoch才能看到训练效果,之前全是噪音图,调试代价太大,不想改了。

class TrainerGAN():
    def __init__(self, config, devices):
        self.config = config        
        self.model_type = self.config["model_type"]
        self.devices = devices
        
        self.G = Generator(self.config["z_dim"])
        self.D = Discriminator(self.model_type, 3)  # 3代表输入通道数
                
        self.loss = nn.BCELoss()        

        if self.model_type == 'GAN' or self.model_type == 'WGAN-GP':
            self.opt_D = torch.optim.Adam(self.D.parameters(), lr=self.config["lr"], betas=(0.5, 0.999))
            self.opt_G = torch.optim.Adam(self.G.parameters(), lr=self.config["lr"], betas=(0.5, 0.999))
        elif self.model_type == 'WGAN':
            self.opt_D = torch.optim.RMSprop(self.D.parameters(), lr=self.config["lr"])
            self.opt_G = torch.optim.RMSprop(self.G.parameters(), lr=self.config["lr"])    

        self.dataloader = None
        self.log_dir = os.path.join(self.config["save_dir"], 'logs')
        self.ckpt_dir = os.path.join(self.config["save_dir"], 'checkpoints')
        
        FORMAT = '%(asctime)s - %(levelname)s: %(message)s'
        logging.basicConfig(level=logging.INFO, 
                            format=FORMAT,
                            datefmt='%Y-%m-%d %H:%M')
        
        self.steps = 0
        self.z_samples = torch.randn(100, self.config["z_dim"], requires_grad = True).to(self.devices[0])  # 打印100个看看生成的效果
        
    def prepare_environment(self):
        """
        Use this funciton to prepare function
        """
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        
        # update dir by time
        time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        self.log_dir = os.path.join(self.log_dir, time+f'_{self.config["model_type"]}')
        self.ckpt_dir = os.path.join(self.ckpt_dir, time+f'_{self.config["model_type"]}')
        os.makedirs(self.log_dir)
        os.makedirs(self.ckpt_dir)
        
        # model preparation
        self.G = self.G.to(self.devices[0])
        self.D = self.D.to(self.devices[0])
        self.G.train()
        self.D.train()
        
    def gp(self, r_imgs, f_imgs):
        """
        Implement gradient penalty function
        """
        Tensor = torch.FloatTensor
        alpha = Tensor(np.random.random((r_imgs.size(0), 1, 1, 1))).to(devices[0])
        interpolates = (alpha*r_imgs + (1 - alpha)*f_imgs).requires_grad_(True)
        d_interpolates = self.D(interpolates)
        fake = Tensor(r_imgs.shape[0]).fill_(1.0).to(devices[0])
        fake.requires_grad = False
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(1, dim=1) - 1)**2).mean()
        return gradient_penalty
        
    def train(self, dataloader):
        """
        Use this function to train generator and discriminator
        """
        self.prepare_environment()
        
        legend = ['Gen loss', 'Dis acc']        
        animator = d2l.Animator(xlabel='epoch', xlim=[0, self.config["n_epoch"]], legend=legend)
        num_batches = len(dataloader)
        show_batch = num_batches // self.config["show_num"] # 多少batch打印一次loss
        
        for e, epoch in enumerate(range(self.config["n_epoch"])):
            progress_bar = tqdm(self.dataloader)
            progress_bar.set_description(f"Epoch {e+1}")          
            for i, data in enumerate(dataloader):
                bs = data.size(0)  # batch size

                # *********************
                # *    Train D        *
                # *********************
                z = torch.randn(bs, self.config["z_dim"]).to(self.devices[0])
                r_imgs = data.to(self.devices[0])
                f_imgs = self.G(z)

                # Discriminator forwarding
                r_logit = self.D(r_imgs)  # 判断真实图像
                f_logit = self.D(f_imgs.detach())  # 判断生成的假图像  使用detach()是为了避免对G求导
                
                # SETTING DISCRIMINATOR LOSS
                if self.model_type == 'GAN':
                    r_label = torch.ones((bs)).to(self.devices[0])
                    f_label = torch.zeros((bs)).to(self.devices[0])
                    r_loss = self.loss(r_logit, r_label)
                    f_loss = self.loss(f_logit, f_label)
                    loss_D = (r_loss + f_loss) / 2
                elif self.model_type == 'WGAN':
                    loss_D = -torch.mean(r_logit) + torch.mean(f_logit)
                elif self.model_type == 'WGAN-GP':
                    loss_D = -torch.mean(r_logit) + torch.mean(f_logit) + self.gp(r_imgs, f_imgs) # 最后一项是gradient_penalty

                # Discriminator backwarding
                self.D.zero_grad()
                loss_D.backward()
                self.opt_D.step()                
                
                # SETTING WEIGHT CLIP:
                if self.model_type == 'WGAN':
                    for p in self.D.parameters():
                         p.data.clamp_(-self.config["clip_value"], self.config["clip_value"])

                # *********************
                # *    Train G        *
                # *********************
                if self.steps % self.config["n_critic"] == 0:
                    # Generator forwarding
                    f_logit = self.D(f_imgs)  # f_imgs没必要再生成一遍,甚至可以在训练前生成一个,来回使用

                    if self.model_type == 'GAN':
                        loss_G = self.loss(f_logit, r_label)
                    elif self.model_type == 'WGAN' or self.model_type == 'WGAN-GP':
                        loss_G = -torch.mean(self.D(f_imgs))                        

                    # Generator backwarding
                    self.G.zero_grad()
                    loss_G.backward()
                    self.opt_G.step()
                    
                    loss_G_sum += loss_G.item()
                
                if self.steps % 10 == 0:
                    progress_bar.set_postfix(loss_G=loss_G.item(), loss_D=loss_D.item())
                self.steps += 1         

            self.G.eval()
            # G()最后一层是tanh(), 输出是-1到1,也就是说,G()的输出要变成0-1才是图像
            f_imgs_sample = (self.G(self.z_samples).data + 1) / 2.0 
            filename = os.path.join(self.log_dir, f'Epoch_{epoch+1:03d}.jpg')
            torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
            logging.info(f'Save some samples to {filename}.')

            # Show some images during training.
            grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
            plt.figure(figsize=(10,10))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()

            self.G.train()

            if (e+1) % 5 == 0 or e == 0:
                # Save the checkpoints.
                torch.save(self.G.state_dict(), os.path.join(self.ckpt_dir, f'G_{e}.pth'))
                torch.save(self.D.state_dict(), os.path.join(self.ckpt_dir, f'D_{e}.pth'))

        logging.info('Finish training')

    def inference(self, G_path, n_generate=1000, n_output=30, show=False):
        """
        1. G_path is the path for Generator ckpt
        2. You can use this function to generate final answer
        """

        self.G.load_state_dict(torch.load(G_path))
        self.G.to(self.devices[0])
        self.G.eval()
        z = torch.randn(n_generate, self.config["z_dim"]).to(self.devices[0])
        imgs = (self.G(z).data + 1) / 2.0
        
        os.makedirs('output', exist_ok=True)
        for i in range(n_generate):
            torchvision.utils.save_image(imgs[i], f'output/{i+1}.jpg')
        
        if show:
            row, col = n_output//10 + 1, 10
            grid_img = torchvision.utils.make_grid(imgs[:n_output].cpu(), nrow=row)
            plt.figure(figsize=(row, col))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()

训练

读取数据

devices = d2l.try_all_gpus()
print(f'DEVICE: {devices}')

# create dataset by the above function
batch_size = 512
num_workers = 4
dataset = get_dataset(os.path.join(workspace_dir, 'faces'))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last = True)
print('训练集总长度是 {:d}, batch数量是 {:.2f}'.format(len(dataset), len(dataset)/batch_size))

Set config

config = {
    "model_type": "WGAN",    
    "lr": 1e-4,
    "n_epoch": 60,
    "n_critic": 5,  # 训练一次generator,多训练几次discriminator,效果更好 n_critic=5意味着训练比是1:5
    "z_dim": 100,
    "workspace_dir": workspace_dir, # define in the environment setting
    "save_dir": workspace_dir,
    'clip_value': 1,    
    'show_num': 12
}

trainer = TrainerGAN(config, devices)
trainer.train(dataloader)

推断

# save the 1000 images into ./output folder
trainer.inference(f'{workspace_dir}/checkpoints/2022-03-31_15-59-17_GAN/G_0.pth') # you have to modify the path when running this line

GAN效果

下面是GAN产生的图片,效果挺一般。我只是大体运行了一下,再调一调能好多了。

除了效果差,训练过中可以发现到了第22个epoch,图像突然会变差,前一个还是正常的人像(下面gif中暂停的那一幅图像),下一个epoch突然变坏,根据李宏毅2022机器学习HW6解析_机器学习手艺人的博客-CSDN博客,loss_G突然增大,loss_D接近于0,这说明后续的训练discriminator相对generator表现的太好,这与GAN的训练背道而驰,GAN训练最好的结果是loss_G小,loss_D大,也就是discriminator无法分辨generator的结果。

还有一个问题是,生成的图像多样性变差,具体原因老师上课讲过了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/12827.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

火山引擎:数字化时代,如何给金融业注入“内容活水”?

数字化,已经成为中国经济的一架强劲发动机。 工业和信息化部统计显示,中国数字经济规模从2012年的11万亿元增长到2021年的超45万亿元,排名世界第二,数字经济占国内生产总值比重由21.6%提升至39.8%。 数据,是数字化的…

git可视化工具-idea插件使用

上一篇文章说了git的命令行操作,是不是还沉浸在命令行在指间跳跃的兴奋中,这一篇再说一说在idea中如何使用git,会让人更兴奋了,也许你会认为这会是最好用的方式的。我想说这只是最好用的方式之一。 1.功能入口 当我们在idea里想使…

键盘输入语句和位运算

键盘输入语句键盘输入语句案例:可以从控制台接收用户信息,【姓名,年龄,薪水】进制介绍案例:输出 二,十,八,十六进制的数据位运算原码、反码、补码位运算符java 中有 7 个位运算(&…

数字工业 弹性安全丨2022 Fortinet工业互联网安全发展峰会成功举办

随着数字化转型的持续推进,工业互联网的作用和地位日益加强。而 OT 安全作为工业互联网体系不可或缺的部分,虽然受到越来越多企业的关注,但仍然面临着多方面的挑战。11月16日,一年一度的 OT 安全盛会——2022 Fortinet工业互联网安…

算法设计与分析 SCAU11091 最优自然数分解问题(优先做)

11091 最优自然数分解问题(优先做) 时间限制:1000MS 代码长度限制:10KB 提交次数:0 通过次数:0 题型: 编程题 语言: G;GCC;VC;JAVA Description 问题描述:设n是一个正整数。 (1)现在将n分解为若干个互不相同的自然…

【毕业设计】电影评论情感分析 - GRU 深度学习

文章目录0 前言1 项目介绍2 情感分类介绍3 数据集4 实现4.1 数据预处理4.2 构建网络4.3 训练模型4.4 模型评估4.5 模型预测5 最后0 前言 🔥 Hi,大家好,这里是丹成学长的毕设系列文章! 🔥 对毕设有任何疑问都可以问学…

手机拍照模糊怎么办?拍摄低像素照片如何修复清晰?

相信有很多人在用手机拍摄照片时自认为应该非常精美,拍完后却发现它模糊不清!最终遗憾地错过了精彩的瞬间,令人非常遗憾!虽然手机不是专业的摄像机,拍摄时模糊在所难免。但是我们可以在前期尽量避免拍摄的照片模糊&…

感冒了吃抗生素有用吗?

点击蓝字 |关注我们 2023年《科学世界》杂志全年订阅现已开启。 现在订阅,立享7.5折,并赠送经典科普图书《从一到无穷大》。通过文末链接,即可登录“科学世界”微店订购。抗生素,简单地说就是杀死细菌的药物。更准确地…

从源码上看,RocketMQ 5.0 跟 RocketMQ 4.x相比增加了哪几个模块

今天来介绍一下 RocketMQ 5.0 源码上的变化。 RocketMQ 5.0 是一个里程碑式的版本,经历了近 5 年的打磨,代码变更达到 60%。 首先看一下源码中模块的变化,如下图: 从图中可以看到,RocketMQ 5.0 主要增加了 4 个模块儿…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java校园生活互助平台06qe4

对于即将毕业或者即将做课设的同学而言,由于经验的欠缺,面临的第一个难题就是选题,确定好题目之后便是开题报告,如果选题首先看自己学习那些技术,不同技术适合做不同的产品,比如自己会些简单的Java语言&…

DataScience:KNIME工具的简介、安装、使用方法之详细攻略

DataScience:KNIME工具的简介、安装、使用方法之详细攻略 目录 KNIME的简介—数据挖掘与分析工具 1、KNIME软件如何帮助您的数据分析? 1.1、Create 1.2、Productionize 2、KNIME Analytics Platform 3、KNIME Hub KNIME的安装 KNIME的使用方法 1、构建第一…

[附源码]java毕业设计领导干部听课评课管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

【深入浅出Spring6】第四期——实例化Bean和Bean的生命周期

一、获取 Bean Spring 提供了多种实例化Bean的方式:【只是表现形式不同,底层都是通过构造方法创建对象的】 通过构造方法实例化 【最简单的方式直接声明bean】通过简单工厂模式实例化 【定义一个简单模式工厂,然后通过工厂的静态方法获得Bea…

P3205 [HNOI2010]合唱队

[HNOI2010]合唱队 题目描述 为了在即将到来的晚会上有更好的演出效果,作为 AAA 合唱队负责人的小 A 需要将合唱队的人根据他们的身高排出一个队形。假定合唱队一共 nnn 个人,第 iii 个人的身高为 hih_ihi​ 米(1000≤hi≤20001000 \le h_i …

Java网络编程套接字

文章目录1、网络编程基础2、Socket套接字2.1 Java数据报套接字通信模型2.2 Java流式套接字通信模型2.3 Socket编程注意事项3、UDP数据报套接字编程4、TCP流式套接字编程1、网络编程基础 在没有网路之前,两个进程只能在同一主机上进行通信,但是无法跨距离…

【kubernetes篇】使用Nfs实现kubernetes持久化存储

引言 在kubernetes使用的过程中,有很多数据需要持久化保存。而kubernetes本身不能实现这样的功能,所以需要提供外部存储来实现。nfs网络文件系统,能良好支持pv动态创建等功能,是一个不错的持久化保存方式。今天将这一部分内容作以…

jsx代码如何变成dom

jsx代码如何变成dom一、三个问题考察对jsx的理解二、jsx的本质以及它和js之间是什么关系?2.1 jsx是什么2.2 和js的关系2.3 jsx的本质三、为什么要用jsx?不用会有什么后果四、jsx背后的功能模块是什么?这个功能模块都做了哪些事情?…

DVWA 之 SQL注入(非盲注)

文章目录SQL注入1.判断是否存在注入,注入是字符型还是数字型2.猜解SQL查询语句中的字段数3.确定显示的字段顺序4.获取当前数据库5.获取数据库中的表6.获取表中的字段名7.下载数据SQL注入 步骤: 1.判断是否存在注入,注入是字符型还是数字型 2…

数据库平滑扩容方案剖析

1. 扩容方案剖析 1.1 扩容问题 在项目初期,我们部署了三个数据库A、B、C,此时数据库的规模可以满足我们的业务需求。为了将数据做到平均分配,我们在Service服务层使用uid%3进行取模分片,从而将数据平均分配到三个数据库中。 如…

4-6 最小生成树Prim,Kruskal(贪心)

4.6最小生成树 Prim,Kruskal(贪心) 一、问题描述 设G (V,E)是无向连通带权图,即一个网络。E中每条边(u,v)的权为 c[u][v]。 如果G的子图G’是一棵包含G的所有顶点的树,则称G’为G的生成树。生成树上各边权的总和称为该生成树的耗费。 在G的所有生成树中…