生成对抗网络pix2pixGAN

news2025/1/10 11:52:37

1.介绍

论文:Image-to-Image Translation with Conditional Adversarial Networks

论文地址:https://arxiv.org/abs/1611.07004

图像处理的很多问题都是将一张输入的图片转变为一张对应的 输出图片,比如灰度图、彩色图之间的转换、图像自动上色等。

什么是 pix2pixGAN:pix2pixGAN主要用于图像之间的转换,又称图像翻译。作者证明了这种方法在从标签图合成照片(synthesizing photos from label map)、从边缘图重建对象(reconstructing objects from edge maps)以及给图像上色(colorizing images)等多种任务中是有效的。

与普通GAN的区别:普通GAN的生成器G输入的是随机向量(噪声),输出是图像; 判别器D接收的输入是图像(生成的或是真实的),输出是对或者错 。这样G和D联手就能输出真实的图像。Pix2pixGAN本质上是一个cGAN,图片x作为此cGAN的条件, 输入到生成器G中。G的输出是生成的图片G(x)。 D则需要分辨出{x,G(x)}和{x, y}。其中x是需要转换的图片,y是x对应的真实图片。

2.生成器与判别器的设计

生成器G的设计:生成器G采用了Encoder-Decoder模型,参考U-Net的结构。

判别器D的设计:D中要输入成对的图像。判别器D的输入与cGAN中的不同,因为除了要生成真实图像之外,还要保证生成的图像和输入图像是匹配的。Pix2Pix论文中将判别器D实现为Patch-D,所谓Patch,是指无论生成的图像有多大,将其切分为多个固定大小的Patch输入进D去判断。这样设计的好处是:D的输入变小,计算量小,训练速度快。

3.损失函数

D网络损失函数(使用二元交叉熵损失BCELoss)

输入真实的成对图像希望判定为1,即{x, y};输入原图与生成图像希望判定为0,即{x,G(x)}。

G网络损失函数(使用二元交叉熵损失BCELoss和L1loss)

L1loss保证输入和输出之间的一致性;

输入原图与生成图像希望判定为1,即{x,G(x)}。

4.模型搭建 

import torch
from PIL import Image
import os
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data.dataset import Dataset
import tqdm
import glob


imgs_path = glob.glob('D:\cnn\All_Classfication/base_data/train/*.jpg') #获取训练集中的.jpg图片
annos_path = glob.glob('D:\cnn\All_Classfication/base_data/train/*.png') #获取训练集中的.png图片

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(mean=0.5, std=0.5)]) #Normalize为转化到-1~1之间

# 定义数据读取
class GANDataset(Dataset):
    def __init__(self, imgs_path, annos_path): #初始化
        super(GANDataset, self).__init__()
        self.imgs_path     = imgs_path #定义属性
        self.annos_path   = annos_path#定义属性

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index): #对数据切片
        img_path        = self.imgs_path[index]
        anno_path = self.annos_path[index]

        # 从文件中读取图像
        jpg         = Image.open(img_path)
        jpg         = transform(jpg)

        png         = Image.open(anno_path)
        png         = png.convert('RGB') #因为anno_path为单通道图片,使用convert方法还原回三通道
        png         = transform(png)
        return jpg, png

train_dataset = GANDataset(imgs_path, annos_path) #创建dataset
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

jpg_batch, png_batch = next(iter(dataloader)) #查看,返回一个批次的训练数据
# print(jpg_bath.shape)
# print(png_bath.shape)

# 查看训练集
# plt.figure(figsize=(8, 12))
# for i, (anno, img) in enumerate(zip(png_batch[:3], jpg_batch[:3])): #zip代表元组
#     # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
#     anno = (anno.permute(1, 2, 0).numpy() + 1) / 2
#     img = (img.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('input_img')
#     plt.imshow(anno)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('output_img')
#     plt.imshow(img)
# plt.show()

#定义下采样模块
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为True
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x

#定义上采样模块
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为False
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x

# 定义生成器,包含6个下采样层,6个上采样层
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)     #3,256,256 -- 64,128,128
        self.down2 = Downsample(64, 128)   #64,128,128 -- 128,64,64
        self.down3 = Downsample(128, 256)  #128,64,64 -- 256,32,32
        self.down4 = Downsample(256, 512)  #256,32,32 -- 512,16,16
        self.down5 = Downsample(512, 512)  #512,16,16 -- 512,8,8
        self.down6 = Downsample(512, 512)  #512,8,8 -- 512,4,4

        self.up1 = Upsample(512, 512)      #512,4,4 -- 512,8,8
        self.up2 = Upsample(1024, 512)     #1024,8,8 -- 512,16,16
        self.up3 = Upsample(1024, 256)     #1024,16,16 -- 256,32,32
        self.up4 = Upsample(512, 128)      #512,32,32 -- 128,64,64
        self.up5 = Upsample(256, 64)       #256,64,64 -- 64,128,128
        #128,128,128 -- 3,256,256
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)

        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)

        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)

        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)

        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)

        x6 = torch.tanh(self.last(x6))

        return x6

# 定义判别器   将条件(anno)与图片(生成的或真实的)同时输入到判别器中进行判定  concat
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.conv1 = nn.Conv2d(128, 256, 3)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)

    # 判别器的输入为成对的图片,anno为结构图,img为真实的或生成的图片
    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1) #batch_size,6,256,256
        x = self.down1(x, is_bn=False) #batch_size,64,128,128
        x = self.down2(x) #batch_size,128,64,64
        x = F.dropout2d(self.bn(F.leaky_relu(self.conv1(x)))) #batch_size,256,62,62
        x = torch.sigmoid(self.last(x)) #batch_size,1,60,60
        return x


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
dis = Discriminator().to(device)

# 判别器优化器
d_optimizer = torch.optim.Adam(dis.parameters(), lr=1e-4, betas=(0.5, 0.999)) #通过减小判别器的学习率降低其能力
# 生成器优化器
g_optimizer = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.999))

# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, test_anno, test_real): # model为Generator,test_anno为结构图,test_real为真实图片
    generate = model(test_anno).permute(0, 2, 3, 1).detach().cpu().numpy() #detach()截断梯度,将通道维度放在最后
    test_anno = test_anno.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3
    test_real = test_real.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3
    plt.figure(figsize=(10, 10))
    title = ['Input image', 'Ground truth', 'Generate image']
    display_list0 = [test_anno[0], test_real[0], generate[0]]
    for i in range(3):
        plt.subplot(3, 3, i + 1)
        plt.title(title[i])
        plt.imshow((display_list0[i]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    display_list1 = [test_anno[1], test_real[1], generate[1]]
    for i in range(3,6):
        plt.subplot(3, 3, i + 1)
        # plt.title(title[i])
        plt.imshow((display_list1[i-3]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    display_list2 = [test_anno[2], test_real[2], generate[2]]
    for i in range(6,9):
        plt.subplot(3, 3, i + 1)
        # plt.title(title[i])
        plt.imshow((display_list2[i-6]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    # plt.show()
    plt.savefig('./imageP2P/image_at_{}.png'.format(epoch))

test_imgs_path = glob.glob('D:\cnn\All_Classfication/base_data/val/*.jpg') #获取验证集中的.jpg图片
test_annos_path = glob.glob('D:\cnn\All_Classfication/base_data/val/*.png') #获取验证集中的.png图片

test_dataset = GANDataset(test_imgs_path, test_annos_path) #创建dataset
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True)

imgs_batch, annos_batch = next(iter(test_dataloader)) #查看,返回一个批次的测试数据
# print(jpg_bath.shape)
# print(png_bath.shape)

# 查看测试集
# plt.figure(figsize=(8, 12))
# for i, (anno, img) in enumerate(zip(annos_batch[:3], imgs_batch[:3])): #zip代表元组
#     # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
#     anno = (anno.permute(1, 2, 0).numpy() + 1) / 2
#     img = (img.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('input_img')
#     plt.imshow(anno)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('output_img')
#     plt.imshow(img)
# plt.show()

annos_batch, imgs_batch = annos_batch.to(device), imgs_batch.to(device)

# 定义cGAN损失
loss_fn = torch.nn.BCELoss() # 二元交叉熵损失
LAMBDA = 7 #L1损失的权重

# pix2pixGAN训练
D_loss = []
G_loss = []

for epoch in range(100):
    D_epoch_loss = 0 #记录判别器每个epoch损失
    G_epoch_loss = 0 #记录生成器每个epoch损失
    count = len(dataloader) #len(dataloader)返回批次数
    count1 = len(train_dataset) #len(train_dataset)返回样本数
    for step, (imgs, annos) in enumerate(tqdm.tqdm(dataloader)): #注意dataloader输出的图片和标签的顺序
        annos = annos.to(device)
        imgs = imgs.to(device)

        #-------------------------------------#
        # 判别器损失
        d_optimizer.zero_grad()
        disc_real_output = dis(annos, imgs) #输入真实的成对图像希望判定为1,即{x, y}
        d_real_loss = loss_fn(disc_real_output, torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()  # 反向传播

        gen_output = gen(annos) #结构图通过生成器生成图片
        disc_gen_output = dis(annos, gen_output.detach()) #输入原图与生成图像希望判定为0,即{x,G(x)}
        d_fake_loss = loss_fn(disc_gen_output, torch.zeros_like(disc_gen_output, device=device))
        d_fake_loss.backward()  # 反向传播

        # 判别器总损失
        disc_loss = d_real_loss + d_fake_loss
        d_optimizer.step() #优化
        # -------------------------------------#

        # -------------------------------------#
        # 生成器损失
        g_optimizer.zero_grad()
        disc_gen_out = dis(annos, gen_output) #输入原图与生成图像希望判定为1,即{x,G(x)}
        gen_loss_celoss = loss_fn(disc_gen_out, torch.ones_like(disc_gen_out, device=device))

        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs)) #L1loss度量生成图像与原结构图之间的距离
        # 生成器总损失
        gen_loss = gen_loss_celoss + LAMBDA*gen_l1_loss
        gen_loss.backward()  #反向传播
        g_optimizer.step() #优化
        # -------------------------------------#

        with torch.no_grad():
            D_epoch_loss += disc_loss.item()  # 将每一个批次的loss累加
            G_epoch_loss += gen_loss.item()  # 将每一个批次的loss累加

    with torch.no_grad():
        D_epoch_loss /= count  # 求得每一轮的平均loss
        G_epoch_loss /= count  # 求得每一轮的平均loss
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        print('epoch:', epoch)
        gen_img_plot(gen, epoch, annos_batch, imgs_batch)

        plt.figure(figsize=(10, 10))
        plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
        plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
        plt.xlabel('epoch')  # 横轴名称
        plt.legend()
        plt.savefig('./imageP2P/loss.png')  # 保存图片


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/461155.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaEE】SpringMVC_day02

今日内容 完成SSM的整合开发能够理解并实现统一结果封装与统一异常处理能够完成前后台功能整合开发掌握拦截器的编写 1,SSM整合 前面我们已经把Mybatis、Spring和SpringMVC三个框架进行了学习,今天主要的内容就是把这三个框架整合在一起完成我们的业务功…

网络基础-IP和端口号以及认识传输层协议

概念回顾 MAC地址仅需要在同一个局域网下唯一,就可以保证不会出现通讯问题。 通信的目的是两台机器上的应用软件要通信。即客户端进程和服务端进程要获取这个数据,借助主机来完成通信。故将数据在主机间转发仅仅是手段,机器收到后&#xff…

为什么别的测试工程师年薪30W,而你做不到?

最近收到一位同学的私信: “看到了这个岗位想去应聘,但任职要求熟悉Shell、Python、Java其中的一种语言。软件测试工程师不是对编程代码要求不高吗?我如果学习应该选择Java还是Python?” 对于刚入行的测试新人来说,在求…

深入理解Linux内核(第三版)- 进程切换

为了控制进程的执行,内核必须有能力挂起正在CPU上运行的进程,并恢复以前挂起的某个进程的执行。这种行为被称为进程切换(process switch)、任务切换(task switch)或上下文切换(context switch&a…

手把手教你Java实现栈和队列

目录 一、栈(Stack) 1、概念 2、栈的使用 3、栈的模拟实现 4、栈的应用场景 2. 队列(Queue) 1、概念 2、队列的使用 3、队列模拟实现 4、循环队列 三、双端队列 (Deque) 五、栈和队列的互相实现 用队列实现栈: 用栈实现队列: 一、栈(St…

【剑指offer】(2)

系列文章目录 剑指offer系列是一本非常著名的面试题目集,旨在帮助求职者提升编程能力和应对面试的能力。 文章目录 系列文章目录[TOC](文章目录) 前言一、 用两个栈实现队列🔥 思路🌈代码 二、青蛙跳台阶问题🔥 思路&#x1f308…

git从入门到卸载

git是什么? 从git的官网Git可以找到: Git is a free and open source distributed version control system designed to handle everything from small to very large projects with speed and efficiency. Git is easy to learn and has a tiny footpr…

SANGFOR防火墙如何查看现网运行参数

环境: 防火墙 8.0.48 AF-1000BB1510 问题描述: 公司防火墙设备使用2年多了 AF-2000-FH2130B-SC;性能参数:网络层吞吐量:20G,应用层吞吐量:9G,防病毒吞吐量:1.5G,IPS吞…

python基础实战4-python基础语法

1、注释(Comments) 注释用来向用户提示或解释某些代码的作用和功能,它可以出现在代码中的任何位置。 Python解释器在执行代码时会忽略注释,不做任何处理,就好像它不存在一样。 1.1 代码注释介绍 注释就是对代码的解…

计算机组成原理 指令系统(1)

本文是HIT计算机组成原理上课笔记,由于唐书有些内容讲的比较抽象,添加了一些王道的图片加以补充。 回忆计算机的工作过程 代码被编译器翻译成与之对等的机器指令,除了指令之外还会有一些数据同时被放到主存里 机器指令 指令格式 一条指令是…

第十四章 代理模式

文章目录 前言一、静态代理完整代码接口 ITeacherDao (代理类和被代理类都需要实现这个接口)被代理类 TeacherDao代理类 TeacherDaoProxy测试类 Client 二、JDK动态代理完整代码接口 ITeacher实现类TeacherDao代理工厂 ProxyFacyoryclient 测试 三、Cgli…

Java阶段二Day09

Java阶段二Day09 文章目录 Java阶段二Day09DQLSELECT基础查询全部查询WHERE子句连接多个条件ORDER BY子句分页查询在SELECT子句中使用函数在WHERE中使用表达式别名聚合函数 教师总结DQL语言-数据查询语言语法基础查询语法例 WHERE子句例连接多个条件例AND的优先级高于OR IN(列表…

vue使用原生bootstrap-fileinput无效(未解决)

这篇只记录一下踩到的坑,由于时间关系,此问题未解决 起因:要求替换项目框架,原先jq要替换成vue。之前bootstrap中自带的文件上传插件自带很多功能,上传进度条、上传内容预览等非常方便(如图)&a…

Netty核心源码分析(四)心跳检测源码分析

文章目录 系列文章目录一、心跳检测案例二、源码分析1、Netty心跳的三个Handler2、IdleStateHandler源码(1)四个关键属性(2)handlerAdded方法(3)四个内部类 3、读事件的run方法——ReaderIdleTimeoutTask4、…

easyrecovery16最新数据恢复软件密钥使用方法教程

easyrecovery是一款专业的数据恢复软件,其最新版本为easyrecovery2023将于2022年底发布。总之,easyrecovery是一款功能齐全、性能稳定的专业数据恢复软件,无论删除文件、格式化分区或磁盘故障,它都可以提供最高的恢复成功率。值得个人用户选用。此版本在功能和性能上有较大提升…

支持中英双语和多种插件的开源对话语言模型,160亿参数

一、开源项目简介 MOSS是一个支持中英双语和多种插件的开源对话语言模型,moss-moon系列模型具有160亿参数,在FP16精度下可在单张A100/A800或两张3090显卡运行,在INT4/8精度下可在单张3090显卡运行。MOSS基座语言模型在约七千亿中英文以及代码…

HTB靶机-Lame-WP

Lame 简介: Lame is a beginner level machine, requiring only one exploit to obtain root access. It was the first machine published on Hack The Box and was often the first machine for new users prior to its retirement Tags: Injection, C…

Midjourney 注册 12 步流程教学

原文: https://bysocket.com/midjourney-register/ 先推荐一个 PromptHero 中文官网 https://promptheroes.cn/ :Prompt Heroes 官网是提供 AI 绘画相关提示词中文网站,包括 Midjourney(MJ)、 Stable Diffusion、DALL…

printf,echo,cat指令与输出重定向>,输入重定向<与追加重定向>>等

printf指令的功能(输出/追加重定向) 语法:printf “格式化数据” (>/>>重定向)功能:格式化输出(默认往显示器文件且不带换行符) 实例演示 echo指令的功能(输出/追加重定向) 语法&am…

使用chatgpt分析 too many open files 问题-未验证

java.io.IOException: Too many open files 怎么能定位到时哪行代码出的问题 ? 2023/4/25 19:46:33 当出现类似 "java.io.IOException: Too many open files" 的错误时,通常是因为程序打开了过多的文件句柄(File Handles&#xff…