【深度学习】生成对抗网络(GAN)生成海贼王图像 + PyTorch代码实现

news2025/1/11 0:13:21

文章目录

  • 一、数据集介绍
    • 1.1 lufei(路飞)
    • 1.2 luobin(罗宾)
    • 1.3 namei(娜美)
    • 1.4 qiaoba(乔巴)
    • 1.5 shanzhi(山智)
    • 1.6 suolong(索隆)
    • 1.7 wusuopu(乌索普)
  • 二、项目结构
  • 三、代码实现
    • 3.1 DataSet.py
    • 3.2 Model.py
    • 3.3 Run_Test.py
    • 3.4 Run_Train.py
    • 3.5 TrainFunction.py
    • 3.6 TestFunction.py
    • 3.7 Util.py
  • 四、效果展示
  • 五、RuntimeError: DataLoader worker (pid(s) 26192, 29160) exited unexpectedly


关于生成对抗网络(GAN)的介绍可以参考链接:【深度学习】李宏毅2021/2022春深度学习课程笔记 - Generative Adversarial Network 生成式对抗网络(GAN)

一、数据集介绍

海贼王图像数据集下载链接:hzw-photos.rar

在这里插入图片描述

1.1 lufei(路飞)

在这里插入图片描述

1.2 luobin(罗宾)

在这里插入图片描述

1.3 namei(娜美)

在这里插入图片描述

1.4 qiaoba(乔巴)

在这里插入图片描述

1.5 shanzhi(山智)

在这里插入图片描述

1.6 suolong(索隆)

在这里插入图片描述

1.7 wusuopu(乌索普)

在这里插入图片描述


二、项目结构

其中 data 为数据集的根目录;outputs目录是自动生成的,不用手动创建。其余py文件会在下一章进行介绍

在这里插入图片描述


三、代码实现

3.1 DataSet.py

用来获取数据集,并返回继承于Dataset的自定义的MyDataset对象

import glob
import torchvision.transforms as transforms
import os
import torchvision
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, fnames, transform):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = torchvision.io.read_image(fname)
        resize_transform = transforms.Compose([transforms.Resize((64, 64))])
        img = resize_transform(img)
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples


def get_dataset(root):
    fnames = []
    for sub_dir in os.listdir(root):
        fnames.extend(glob.glob(os.path.join(os.path.join(root, sub_dir), '*')))
    compose = [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    transform = transforms.Compose(compose)
    dataset = MyDataset(fnames, transform)
    return dataset

3.2 Model.py

用来定义 Generator 生成器模型和 Discriminator 判别器模型

import torch.nn as nn

# setting for weight init function
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    """
    Input shape: (batch, in_dim)
    Output shape: (batch, 3, 64, 64)
    """

    def __init__(self, in_dim, feature_dim=64):
        super().__init__()

        # input: (batch, 100)
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, feature_dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(feature_dim * 8 * 4 * 4),
            nn.ReLU()
        )
        self.l2 = nn.Sequential(
            self.dconv_bn_relu(feature_dim * 8, feature_dim * 4),  # (batch, feature_dim * 16, 8, 8)
            self.dconv_bn_relu(feature_dim * 4, feature_dim * 2),  # (batch, feature_dim * 16, 16, 16)
            self.dconv_bn_relu(feature_dim * 2, feature_dim),  # (batch, feature_dim * 16, 32, 32)
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 3, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),
            nn.Tanh()
        )
        self.apply(weights_init)

    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),  # double height and width
            nn.BatchNorm2d(out_dim),
            nn.ReLU(True)
        )

    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2(y)
        y = self.l3(y)
        return y


class Discriminator(nn.Module):
    """
    Input shape: (batch, 3, 64, 64)
    Output shape: (batch)
    """

    def __init__(self, in_dim, feature_dim=64):
        super(Discriminator, self).__init__()

        # input: (batch, 3, 64, 64)
        """
        NOTE FOR SETTING DISCRIMINATOR:

        Remove last sigmoid layer for WGAN
        """
        self.l1 = nn.Sequential(
            nn.Conv2d(in_dim, feature_dim, kernel_size=4, stride=2, padding=1),  # (batch, 3, 32, 32)
            nn.LeakyReLU(0.2),
            self.conv_bn_lrelu(feature_dim, feature_dim * 2),  # (batch, 3, 16, 16)
            self.conv_bn_lrelu(feature_dim * 2, feature_dim * 4),  # (batch, 3, 8, 8)
            self.conv_bn_lrelu(feature_dim * 4, feature_dim * 8),  # (batch, 3, 4, 4)
            nn.Conv2d(feature_dim * 8, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def conv_bn_lrelu(self, in_dim, out_dim):
        """
        NOTE FOR SETTING DISCRIMINATOR:

        You can't use nn.Batchnorm for WGAN-GP
        Use nn.InstanceNorm2d instead
        """

        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 4, 2, 1),
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        y = self.l1(x)
        y = y.view(-1)
        return y

3.3 Run_Test.py

对训练好的模型进行测试

import json
import pickle
from TestFunction import *
from Model import *
from Util import *
from pathlib import Path

if __name__ == '__main__':
    # 防止报错 OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

    # load and print config
    with open("./outputs/models/config", "rb") as file:
        config = pickle.load(file)
    print("config:")
    print(json.dumps(config, indent=4, ensure_ascii=False, sort_keys=False, separators=(',', ':')))

    # Set seed for reproducibility
    # same_seed(config['seed'])

    # create test_save_dir
    Path(config['test_save_dir']).mkdir(parents=True, exist_ok=True)

    # init Model and load saved model's parameter
    G = Generator(100).cuda()

    # Load your best model
    ckpt = torch.load(config['model_save_dir'] + "G_0.pth")
    G.load_state_dict(ckpt)

    # test process
    test(G, config)

3.4 Run_Train.py

利用数据集进行训练

import pickle
from torch.utils.data import DataLoader
from TrainFunction import *
from Model import *
from Util import *
from pathlib import Path

if __name__ == '__main__':
    # 防止报错 OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

    # data path
    data_dir = './data'

    # define config
    config = {
        'seed': 929,  # Your seed number, you can pick your lucky number. :)
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',

        "model_type": "GAN",
        "batch_size": 64,
        "lr": 1e-4,
        "n_epoch": 1000,
        "n_critic": 1,
        "z_dim": 100,
        "workspace_dir": data_dir,  # define in the environment setting

        'model_save_dir': './outputs/models/',  # Your model will be saved here.
        'train_save_dir': './outputs/train/',  # Your model train pred valid data and learning curve will be saved here.
        'test_save_dir': './outputs/test/',  # Your model pred test data will be saved here.
        'log_dir': './outputs/log/',
    }
    print("device:", config['device'])

    # Set seed for reproducibility
    same_seed(config['seed'])

    # create save dir
    Path(config['model_save_dir']).mkdir(parents=True, exist_ok=True)
    Path(config['train_save_dir']).mkdir(parents=True, exist_ok=True)
    Path(config['log_dir']).mkdir(parents=True, exist_ok=True)

    # Get Data
    dataset = get_dataset(os.path.join(config["workspace_dir"]))
    train_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True, num_workers=2)

    # init Model (Construct model and move to device)
    G = Generator(100).cuda()
    D = Discriminator(3).cuda()

    # train process
    loss_record = train(train_loader, config, G, D)

    # plot and save learning curve and valid_pred image
    plot_GAN_learning_curve(loss_record, save_dir=config['train_save_dir'])

    # save config
    with open(config['model_save_dir'] + "config", "wb") as file:
        pickle.dump(config, file)

3.5 TrainFunction.py

训练过程的函数

import logging
import torch
from matplotlib import pyplot as plt
from torch.autograd import Variable
from tqdm import tqdm
from DataSet import *
from Model import *


def train(train_loader, config, G, D):
    # Setup optimizer
    opt_D = torch.optim.Adam(D.parameters(), lr=config["lr"], betas=(0.5, 0.999))
    opt_G = torch.optim.Adam(G.parameters(), lr=config["lr"], betas=(0.5, 0.999))

    # loss function
    loss = nn.BCELoss()

    FORMAT = '%(asctime)s - %(levelname)s: %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=FORMAT,
                        datefmt='%Y-%m-%d %H:%M')

    z_samples = Variable(torch.randn(100, config["z_dim"])).cuda()

    loss_record = {"G": [], "D": []}
    steps = 0
    for e, epoch in enumerate(range(config["n_epoch"])):
        progress_bar = tqdm(train_loader)
        progress_bar.set_description(f"Epoch {e + 1}")
        for i, data in enumerate(progress_bar):
            imgs = data.cuda()
            bs = imgs.size(0)

            # *********************
            # *    Train D        *
            # *********************
            z = Variable(torch.randn(bs, config["z_dim"])).cuda()
            r_imgs = Variable(imgs).cuda()
            f_imgs = G(z)
            r_label = torch.ones((bs)).cuda()
            f_label = torch.zeros((bs)).cuda()

            # Discriminator forwarding
            r_logit = D(r_imgs)
            f_logit = D(f_imgs)

            """
            NOTE FOR SETTING DISCRIMINATOR LOSS:

            GAN: 
                loss_D = (r_loss + f_loss)/2
            WGAN: 
                loss_D = -torch.mean(r_logit) + torch.mean(f_logit)
            WGAN-GP: 
                gradient_penalty = gp(r_imgs, f_imgs)
                loss_D = -torch.mean(r_logit) + torch.mean(f_logit) + gradient_penalty
            """
            # Loss for discriminator
            r_loss = loss(r_logit, r_label)
            f_loss = loss(f_logit, f_label)
            loss_D = (r_loss + f_loss) / 2

            # Discriminator backwarding
            D.zero_grad()
            loss_D.backward()
            opt_D.step()

            """
            NOTE FOR SETTING WEIGHT CLIP:

            WGAN: below code
            """
            # for p in D.parameters():
            #     p.data.clamp_(-config["clip_value"], config["clip_value"])

            # *********************
            # *    Train G        *
            # *********************
            if steps % config["n_critic"] == 0:
                # Generate some fake images.
                z = Variable(torch.randn(bs, config["z_dim"])).cuda()
                f_imgs = G(z)

                # Generator forwarding
                f_logit = D(f_imgs)

                """
                NOTE FOR SETTING LOSS FOR GENERATOR:

                GAN: loss_G = loss(f_logit, r_label)
                WGAN: loss_G = -torch.mean(D(f_imgs))
                WGAN-GP: loss_G = -torch.mean(D(f_imgs))
                """
                # Loss for the generator.
                loss_G = loss(f_logit, r_label)

                # Generator backwarding
                G.zero_grad()
                loss_G.backward()
                opt_G.step()

            loss_record["G"].append(loss_G.item())
            loss_record["D"].append(loss_D.item())

            if steps % 10 == 0:
                progress_bar.set_postfix(loss_G=loss_G.item(), loss_D=loss_D.item())
            steps += 1

        # 验证生成器效果,输出图片
        G.eval()
        f_imgs_sample = (G(z_samples).data + 1) / 2.0
        filename = os.path.join(config['log_dir'], f'Epoch_{epoch + 1:03d}.jpg')
        torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
        logging.info(f'Save some samples to {filename}.')

        # Show some images during training.
        grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()

        G.train()

        # Save the checkpoints.
        torch.save(G.state_dict(), os.path.join(config['model_save_dir'], f'G_{e}.pth'))
        torch.save(D.state_dict(), os.path.join(config['model_save_dir'], f'D_{e}.pth'))

    logging.info('Finish training')
    return loss_record

3.6 TestFunction.py

测试过程的代码

import logging
import os
import torch
import torchvision
from matplotlib import pyplot as plt
from torch.autograd import Variable


def test(G, config):
    G.eval()
    z_samples = Variable(torch.randn(100, config["z_dim"])).cuda()
    FORMAT = '%(asctime)s - %(levelname)s: %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=FORMAT,
                        datefmt='%Y-%m-%d %H:%M')

    f_imgs_sample = (G(z_samples).data + 1) / 2.0
    filename = os.path.join(config['test_save_dir'], f'test.jpg')
    torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
    logging.info(f'Save some samples to {filename}.')
    # Show some images during training.
    grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

3.7 Util.py

工具类

import numpy as np
import torch
import random
import matplotlib.pyplot as plt


def same_seed(seed):
    """
    Fixes random number generator seeds for reproducibility
    固定时间种子。由于cuDNN会自动从几种算法中寻找最适合当前配置的算法,为了使选择的算法固定,所以固定时间种子
    :param seed: 时间种子
    :return: None
    """
    torch.backends.cudnn.deterministic = True  # 解决算法本身的不确定性,设置为True 保证每次结果是一致的
    torch.backends.cudnn.benchmark = False  # 解决了算法选择的不确定性,方便复现,提升训练速度
    np.random.seed(seed)  # 按顺序产生固定的数组,如果使用相同的seed,则生成的随机数相同, 注意每次生成都要调用一次
    torch.manual_seed(seed)  # 手动设置torch的随机种子,使每次运行的随机数都一致
    random.seed(seed)
    if torch.cuda.is_available():
        # 为GPU设置唯一的时间种子
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

def plot_GAN_learning_curve(loss_record, save_dir=None):
    x = range(len(loss_record['G']))
    # Generator
    plt.figure()
    plt.plot(x, loss_record['G'], c='tab:red')
    plt.xlabel('Generator Updates')
    plt.ylabel('Loss')
    plt.title('Learning curve of Generator')
    plt.grid(False)
    if save_dir is not None:
        plt.savefig(save_dir + "generator_learning_curve.svg")
    plt.show()

    # Discriminator
    plt.figure()
    plt.plot(x, loss_record['D'], c='tab:red')
    plt.xlabel('Discriminator Updates')
    plt.ylabel('Loss')
    plt.title('Learning curve of Discriminator')
    plt.grid(False)
    if save_dir is not None:
        plt.savefig(save_dir + "discriminator_learning_curve.svg")
    plt.show()

四、效果展示

由于训练到一半内存不够报错了,我只训练了 506 个epoch

第 1 个epoch

在这里插入图片描述

第 10 个epoch

在这里插入图片描述

第 50 个epoch

在这里插入图片描述

第 100 个epoch

在这里插入图片描述

第 200 个epoch

在这里插入图片描述

第 300 个epoch

在这里插入图片描述

第 400 个epoch

在这里插入图片描述

第 506 个epoch

在这里插入图片描述


五、RuntimeError: DataLoader worker (pid(s) 26192, 29160) exited unexpectedly

下图所示就是我运行时报的错,查阅网上的回答,说最可能的原因是:cuda 虚拟环境的共享内存不足,解决办法是改成更小的 batch_size

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/134638.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【内核笔记】2022回顾-2023规划

文章目录一、2022博客总结二、2023博客规划三、改进沉淀、分享、成长,让自己和他人都能有所收获!😄 一、2022博客总结 2022年重点优化了🚀RK3399平台开发入门到精通系列专栏,订阅量突破2100个订阅2022年浏览量破190万2…

《梁启超家书》笔记三——交友取益,或读书取益,也要方面稍多,才有接谈交换,或开卷引进的机会

目录 1、学科选择建议 2、学习 3、所学不可太专向(不知道学科外其他方向的了) 4、建议欧游实地开开眼界 5、建议结婚前分开两地住 6、毕业后工作看法(分析利弊) 7、处事态度与人生观 1、学科选择建议 我很想你以生物学为主科,因为它是…

Vue(五)

1. 绑定class和style样式 dom里自己拿到节点,改它上面的class样式 shift()移除第一个class属性 push想加哪个class属性都行 59行div知道自己只能用两个样式,四种情况 上面这种不合法,就算把样式加进div也得 样式对象:里面的key绿…

【ElasticSearch7.X】学习笔记(一)

【ElasticSearch7.X】学习笔记一、介绍1.1、ElasticSearch1.2、RESTful1.3、数据格式二、下载安装(单机部署)2.1、解压文件2.2、修改配置配置2.3、创建用户并授权2.5、启动三、增删查改3.1、 索引操作3.1.1、增加索引3.1.2、查看索引3.1.2.1、查看所有索…

定时任务介绍及X-Job和E-Job简单对比

下面对elastic-Job和xxl-Job做个简单的描述和对比 一、区别 1、XJob环境依赖于mysql,不依赖ZooKeeper, 2、EJob依赖ZooKeeper1、X-Job 侧重的业务实现的简单和管理的方便,学习成本简单, 失败策略和路由策略丰富。推荐使用在“用户…

对elasticsearch的实际应用

有业务需求,要做一个搜索接口,从千万的数据中根据筛选项查到需要的数据 筛选项有 1.学段。高中,初中等 2.学科。数学,语文等 3.专区。同步专区,高考专区等 4.试卷类型。月考卷,期中考试卷等 5.年级。高一&a…

MyBatis-Plus常用注解

常用注解 TableName 在使用MyBatis-Plus实现基本的CRUD时,我们并没有指定要操作的表,只是在Mapper接口继承BaseMapper时,设置了泛型User,而操作的表为user表 由此得出结论,MyBatis-Plus在确定操作的表时,…

基于 Tensorflow 2.x 从零训练花卉图像识别模型

一、数据集准备 本篇文章使用数千张花卉照片作为数据集,共分为5个分类:雏菊(daisy)、蒲公英(dandelion)、玫瑰(roses)、向日葵(sunflowers)、郁金香&#xf…

ClickHouse中的数据类型

ClickHouse中的数据类型数据类型说明数值类型IntFloatDecimal字符串类型StringFixedStringUUID时间类型DateDateTimeDateTIme64复杂类型ArrayEnumTupleNested特殊类型IPV4GEOBoolean 和 Nullable[参考文章]: ClickHouse | 数据类型介绍 数据类型说明 …

Docker | docker安装portainer详细步骤

portainer是一款容器管理可视化界面,不想在虚拟中使用命令管理容器的小伙伴,可以选择安装portainer对容器进行管理,查看日志、启动、停止容器等非常方便。 1. 搜索portainer镜像 docker search portainer 2. 拉取portainer镜像 docker pull…

[torch]日志记录之SummaryWriter(持更)

参考博文: 《tensorboardX的简单使用》《详解PyTorch项目使用TensorboardX进行训练可视化》《使用TensorBoard检查监控深度学习模型》 一、使用方法 class SummaryWriter(builtins.object)| SummaryWriter(log_dirNone, comment, purge_stepNone, max_queue10, f…

Springboot扩展点之ApplicationContextInitializer

前言 一提到Spring、Springoboot,很多人马上就会想到依赖注入、控制反转、自动装配、约定大于配置、使开发变得简单等等。但是如果仅仅会使用Springboot、SpringMVC完成一些增删改查,解决一些bug,那么实际上你并没有真的懂Spring、Sp…

每天5分钟机器学习:线性判别分析LDA算法

本文重点 我们前面学习了PCA算法,本文我们将学习一种新的算法,这种算法和PCA非常相似,这个算法叫做线性判别分析,简称为LDA,也称为Fisher线性判别(Fisher Linear Discriminant,FLD),是模式识别的经典算法,在1996年由Belhumeur引入模式识别和人工智能领域。 LDA原理 …

毒鸡汤 | 云服务器部署项目

文章目录前言环境介绍安装PHP安装Apache安装MySQL上传项目访问网站前言 新年的第一篇文章,今年继续努力加油,继续产出高质量的文章。 前面用PHPstudy搭建了【毒鸡汤】项目,但是这样只能自己访问,别人是访问不了的,因此…

Redis 持久化机制

个人博客地址: 文章目录Redis 持久化机制1. 概述2. RDB2.1 手动快照2.2 自动快照3. AOF3.1 AOF机制3.2 AOF文件的重写4. 混合型持久化5. 总结Redis 持久化机制 1. 概述 Redis​为了保证性能,会将所有数据放在内存中,那么万一Redis宕机&…

AcWing 756. 蛇形矩阵

目录问题描述思路与代码1. 个人解法2. 官方题解问题描述 原题链接🔗:756. 蛇形矩阵 输入两个整数 nnn 和 mmm,输出一个 nnn 行 mmm 列的矩阵,将数字 111 到 nmnmnm 按照回字蛇形填充至矩阵中。 具体矩阵形式可参考样例。 输入格…

synchronized底层原理

synchronized是JVM内置锁,基于monitor机制实现,依赖底层操作系统的互斥源语Mutex(互斥量),它是一个重量级锁,性能较低。 当然,JVM内置锁在1.5之后版本做了重大优化,如锁粗化&#xf…

[红明谷CTF 2021]JavaWeb

0x01 好久没打过ctf了,最近也在学Java 就看下java的题吧 WP 进入环境就提示访问 /login ,访问之后 提示的 /json ,试着访问一下 给了 jessid,继续访问 /json 的话 又会跳转至 /login 应该是要传点username password 果然是了&#…

web制作网页

HTML(人的骨架): 标签 1.超文本标记性语言 2.当前最新版HTML5 3.URL统一资源定位器(网址) 4.网站首页名称index.html,default.htm,main.heml 基本结构: 注意:1.标签都是小写 2.标签成对写 3.代码要层次缩进…

JavaScript内存管理

JavaScript 是一个弱类型的、动态语言,在执行一段 JS 代码时,需要经历编译、执行、内存回收阶段。 一、编译阶段 JS 代码执行时,会首先创建全局执行上下文、以及函数执行上下文,上下文的执行顺序按照栈的方式进行调用&#xff0c…