生成对抗网络CycleGAN

news2024/11/24 22:52:27

1.介绍

论文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

论文地址:https://arxiv.org/abs/1703.10593

什么是CycleGAN:CycleGAN主要用于图像之间的转换,假设有两个不成对的图像X和Y,算法训练去学习一个“自动相互转换”,训练时不需要成对的配对样本,只需要源域和目标域的图像。训练后网络就能实现对图像源域到目标域的迁移。CycleGAN适用于非配对的图像到图像转换,解决了模型需要成对数据进行训练的困难。

与pix2pixGAN的区别:二者都可以做图像变换,pix2pix模型必须要求成对数据(paired data),而CycleGAN利用非成对数据也能进行训练(unpaired data)。

 2.Cycle-GAN网络架构

相关工作:

GAN,DCGAN,CGAN,pix2pixGAN

CycleGAN其实就是一个 A→B 的单向 GAN 加上一个 B→A 的单向 GAN。两个 GAN 共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。一个单向 GAN 有两个 loss, 故 CycleGAN 加起来总共有四个 loss。

循环一致损失:因为网络需要保证生成的图像必须保留有原 始图像的特性,所以如果我们使用生成器GenratorA-B生 成一张假图像,那么要能够使用另外一个生成器 GenratorB-A来努力恢复成原始图像。此过程必须满足循环一致性。

identity loss:可以理解为,生成器是负责域x到域y的图像生成, 如果输入域y的图片还是应该生成域y的图片。

# 用狗的图像生成猫的图像
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from PIL import Image
import tqdm
import glob

dogs_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Dog/*.jpg') #获取数据集中的.jpg图片
cats_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Cat/*.jpg') #获取数据集中的.jpg图片
# print(cats_path[:3])
# print(dogs_path[:3])
cats_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Cat/*.jpg') #获取数据集中的.jpg图片
dogs_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Dog/*.jpg') #获取数据集中的.jpg图片

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(mean=0.5, std=0.5)]) #Normalize为转化到-1~1之间

# 定义数据读取
class SGANDataset(Dataset):
    def __init__(self, imgs_path): #初始化
        super(SGANDataset, self).__init__()
        self.imgs_path     = imgs_path #定义属性

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index): #对数据切片
        img_path        = self.imgs_path[index]

        # 从文件中读取图像
        pil_img         = Image.open(img_path)
        pil_img         = transform(pil_img)
        return pil_img

# 初始化训练集
dog_dataset = SGANDataset(dogs_path) #创建dataset
cat_dataset = SGANDataset(cats_path) #创建dataset

# 初始化测试集
dog_dataset_test = SGANDataset(dogs_path_test) #创建dataset
cat_dataset_test = SGANDataset(cats_path_test) #创建dataset

dog_dataloader = torch.utils.data.DataLoader(dog_dataset, batch_size=4, shuffle=True)
cat_dataloader = torch.utils.data.DataLoader(cat_dataset, batch_size=4, shuffle=True)

dog_dataloader_test = torch.utils.data.DataLoader(dog_dataset_test, batch_size=4)
cat_dataloader_test = torch.utils.data.DataLoader(cat_dataset_test, batch_size=4)

# cat_bath = next(iter(cat_dataloader)) #查看
# dog_bath = next(iter(dog_dataloader)) #查看
# print(dog_bath.shape) #torch.Size([4, 3, 256, 256])
# print(cat_bath.shape) #torch.Size([4, 3, 256, 256])

# 查看数据集
# plt.figure(figsize=(8, 12))
# for i, (dog, cat) in enumerate(zip(dog_bath[:3], cat_bath[:3])): #zip代表元组
#     # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
#     dog = (dog.permute(1, 2, 0).numpy() + 1) / 2
#     cat = (cat.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('dog')
#     plt.imshow(dog)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('cat')
#     plt.imshow(cat)
# plt.show()


#定义下采样模块
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为True
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x

#定义上采样模块
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为False
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x

# 定义生成器,包含6个下采样层,6个上采样层
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)     #3,256,256 -- 64,128,128
        self.down2 = Downsample(64, 128)   #64,128,128 -- 128,64,64
        self.down3 = Downsample(128, 256)  #128,64,64 -- 256,32,32
        self.down4 = Downsample(256, 512)  #256,32,32 -- 512,16,16
        self.down5 = Downsample(512, 512)  #512,16,16 -- 512,8,8
        self.down6 = Downsample(512, 512)  #512,8,8 -- 512,4,4

        self.up1 = Upsample(512, 512)      #512,4,4 -- 512,8,8
        self.up2 = Upsample(1024, 512)     #1024,8,8 -- 512,16,16
        self.up3 = Upsample(1024, 256)     #1024,16,16 -- 256,32,32
        self.up4 = Upsample(512, 128)      #512,32,32 -- 128,64,64
        self.up5 = Upsample(256, 64)       #256,64,64 -- 64,128,128
        #128,128,128 -- 3,256,256
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)

        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)

        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)

        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)

        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)

        x6 = torch.tanh(self.last(x6))

        return x6

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.last = nn.Conv2d(128, 1, 3)

    def forward(self, img):
        x = self.down1(img)
        x = self.down2(x)
        x =torch.sigmoid(self.last(x))
        return x

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 初始化两个生成器
gen_AB = Generator().to(device)
gen_BA = Generator().to(device)

# 初始化两个判别器
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# 损失函数  1.gan loss  2.cycle consistance loss  3.identity loss
bce_loss = torch.nn.BCELoss()
l1_loss = torch.nn.L1Loss()

# 初始化优化器
# 对两个生成器同时进行优化, 使用itertools.chain对二者同时进行迭代
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), lr=2e-4, betas=(0.5, 0.999))

# 对两个判别器分别进行优化
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, test_input): # model为gen_AB/gen_BA,test_input
    generate = model(test_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
    test_input = test_input.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3
    plt.figure(figsize=(10, 6))
    display_list = [test_input[0], generate[0]]
    title = ['Input image', 'Generate image']
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title[i])
        plt.imshow((display_list[i]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    plt.savefig('./image/image_at_{}.png'.format(epoch))

test_batch = next(iter(dog_dataloader_test)) #batch_size,3,256,256
# 测试输入:选取test_batch中的第一张图片,并添加一个batch_size维度  3,256,256--1,3,256,256
test_input = torch.unsqueeze(test_batch[0], 0).to(device)

# cycleGAN训练
D_loss = []
G_loss = []
epochs = 50
for epoch in range(epochs):
    d_epoch_loss = 0
    g_epoch_loss = 0
    for step, (real_A, real_B) in enumerate(zip(dog_dataloader, cat_dataloader)): #取出真实的狗,猫图片
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        #--------------------begin--------------------#
        # 生成器训练
        gen_optimizer.zero_grad() #训练之前梯度清0
        # identity loss
        same_B = gen_AB(real_B) #真实的B经过生成器gen_AB还是要得到真实的B
        identity_B_loss = l1_loss(same_B, real_B)
        same_A = gen_AB(real_A) #真实的A经过生成器gen_BA还是要得到真实的A
        identity_A_loss = l1_loss(same_A, real_A)
        # 对抗损失 gan loss
        fake_B = gen_AB(real_A) #真实A通过生成器生成了B,此时生成器希望判别器将其判别为真
        D_pred_fake_B = dis_B(fake_B)
        gen_loss_AB = bce_loss(D_pred_fake_B, torch.ones_like(D_pred_fake_B, device=device))
        fake_A = gen_BA(real_B) #真实B通过生成器生成了A,此时生成器希望判别器将其判别为真
        D_pred_fake_A = dis_A(fake_A)
        gen_loss_BA = bce_loss(D_pred_fake_A, torch.ones_like(D_pred_fake_A, device=device))
        # 循环一致损失
        recovered_A = gen_BA(fake_B)
        cycle_loss_ABA = l1_loss(recovered_A, real_A)

        recovered_B = gen_AB(fake_A)
        cycle_loss_BAB = l1_loss(recovered_B, real_B)

        # 生成器总的损失
        g_loss = identity_A_loss + identity_B_loss + gen_loss_AB + gen_loss_BA +cycle_loss_ABA + cycle_loss_BAB

        g_loss.backward()
        gen_optimizer.step()
        # --------------------end--------------------#

        # --------------------begin--------------------#
        # 判别器训练
        # dis_A训练
        dis_A_optimizer.zero_grad()
        dis_A_real_output = dis_A(real_A) #输入为真,期望判定为真
        dis_A_real_loss = bce_loss(dis_A_real_output, torch.ones_like(dis_A_real_output, device=device))

        dis_A_fake_output = dis_A(fake_A.detach())  #输入为假,期望判定为假,梯度截断
        dis_A_fake_loss = bce_loss(dis_A_fake_output, torch.zeros_like(dis_A_fake_output, device=device))

        dis_A_loss = dis_A_real_loss + dis_A_fake_loss #生成器A的总损失
        dis_A_loss.backward()
        dis_A_optimizer.step()

        # dis_B训练
        dis_B_optimizer.zero_grad()
        dis_B_real_output = dis_B(real_B)  #输入为真,期望判定为真
        dis_B_real_loss = bce_loss(dis_B_real_output, torch.ones_like(dis_B_real_output, device=device))

        dis_B_fake_output = dis_B(fake_B.detach())  #输入为假,期望判定为假,梯度截断
        dis_B_fake_loss = bce_loss(dis_B_fake_output, torch.zeros_like(dis_B_fake_output, device=device))

        dis_B_loss = dis_B_real_loss + dis_B_fake_loss #生成器B的总损失
        dis_B_loss.backward()
        dis_B_optimizer.step()
        # --------------------end--------------------#

        with torch.no_grad():
            g_epoch_loss += g_loss.item() #将每一个批次的loss累加
            d_epoch_loss += (dis_A_loss + dis_B_loss).item()  # 将每一个批次的loss累加

    with torch.no_grad():
        g_epoch_loss /= (step + 1) #求得每一轮的平均loss
        d_epoch_loss /= (step + 1) #求得每一轮的平均loss
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)
        gen_img_plot(gen_AB, epoch, test_input)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/467264.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c++11 标准模板(STL)(std::priority_queue)(三)

适配一个容器以提供优先级队列 std::priority_queue 定义于头文件 <queue> template< class T, class Container std::vector<T>, class Compare std::less<typename Container::value_type> > class priority_queue; priority_queu…

chatgpt接入ROS2控制小海龟

chatgpt接入ROS2控制小海龟 0.前言一、使用功能测试&#xff1a; 二、运行结果三、总结 0.前言 在小破站看到的案例&#xff0c;感觉很有趣就自己尝试复现了一下。需要一个OpenAI API Key、ubuntu以及安装ROS2环境。 一、使用 代码仓库在这里&#xff0c;示例操作可以参考B站视…

DATAX hdfsreader orc格式读取数据丢失问题

最近做一个数据同步任务&#xff0c;从hive仓库同步数据到pg&#xff0c;Hive有4000w多条数据&#xff0c;但datax只同步了280w就结束了&#xff0c;也没有任何报错。 看了下datax源码&#xff0c;找到HdfsReader模块DFSUtil核心实现源码读取orc格式的文件方法&#xff1a; pu…

应用运行环境实时洞察,亚马逊云科技Cisco AppDynamics展优势

Cisco AppDynamics(APM)产品&#xff0c;现已正式上线亚马逊云科技Marketplace&#xff08;中国区域&#xff09;。可以通过亚马逊云科技Marketplace&#xff08;中国区域&#xff09;网站&#xff0c;灵活便捷地部署该解决方案&#xff0c;以便充分利用云原生APM(应用性能管理…

(上)苹果有开源,但又怎样呢?

苹果&#xff08;Apple Inc.&#xff09;有多伟大&#xff0c;我相信已经无需赘述了。但是&#xff0c;这里的伟大是指用产品和理念对行业进行的革命性颠覆&#xff0c;而不是对开源而言。 相反&#xff0c;在某种程度上&#xff0c;苹果几乎就是开源的反义词。这种骨子里的 “…

8个Wireshark使用技巧

一&#xff1a;数据包过滤 过滤需要的IP地址 ip.addr 在数据包过滤的基础上过滤协议ip.addrxxx.xxx.xxx.xxx and tcp 过滤端口ip.addrxxx.xxx.xxx.xxx and http and tcp.port80 指定源地址 目的地址ip.srcxxx.xxx.xxx.xxx and ip.dstxxx.xxx.xxx.xxx SEQ字段&#xff08;序列号…

浅谈 git 底层工作原理

浅谈 git 底层工作原理 系统复习到这里也快差不多了&#xff0c;大概就剩下两三个 sections&#xff0c;这里学习一下 git 的 hashing 和对象。 当然&#xff0c;跳过问题也不大。 config 文件 这里还是会用 redux 的项目&#xff0c;先看一下基本信息&#xff1a; ➜ re…

短视频矩阵系统---开发技术源码能力

短视频矩阵系统开发涉及到多个领域的技术&#xff0c;包括视频编解码技术、大数据处理技术、音视频传输技术、电子商务及支付技术等。因此&#xff0c;短视频矩阵系统开发人员需要具备扎实的计算机基础知识、出色的编程能力、熟练掌握多种开发工具和框架&#xff0c;并掌握音视…

制冷暖通工业互联网平台孵化

制冷暖通工业互联网平台孵化可以帮助初创企业或者创新项目快速建立和推广制冷暖通工业互联网平台。以下是一些常见的制冷暖通工业互联网平台孵化服务&#xff1a; 创业辅导&#xff1a;孵化器提供创业辅导服务&#xff0c;帮助企业或者项目找到合适的市场和商业模式&#xff0c…

sd卡中病毒的表现及sd文件消失后的恢复方法

sd卡在日常使用中十分常见&#xff0c;但有时也会发生一些意外情况。例如&#xff0c;不小心意外感染病毒&#xff0c;导致sd卡中存储的文件消失。那么对于丢失的文件&#xff0c;我们该如何恢复呢&#xff1f;下面将带您了解sd卡中病毒的表现以及sd卡文件消失怎么恢复的方法。…

【C语言】学习路线大纲思维导图

思维导图下载地址&#xff1a;点击跳转   配套专栏&#xff1a;【C语言】基础语法 思维导图 1. 基础语法1.1 变量和数据类型1.2 运算符和表达式1.3 控制流程结构1.4 函数和递归1.5 数组和指针1.6 字符串和字符处理1.7 文件操作 2. 高级特性标准库和常用函数动态内存分配多文件…

理解龙格库塔法基本C程序

先学习龙格-库塔法&#xff1b; 龙格-库塔&#xff0c;Runge-Kutta&#xff0c;该方法用于数值求解微分方程&#xff1b; 其中包括著名的欧拉法&#xff1b; 经典四阶法 该方法主要是在已知方程导数和初值信息&#xff0c;利用计算机仿真时应用&#xff0c;省去求解微分方…

【LeetCode】213. 打家劫舍 II

213. 打家劫舍 II&#xff08;中等&#xff09; 思路 这道题是 198.打家劫舍 的拓展版&#xff0c;区别在于&#xff1a;本题的房间是环形排列&#xff0c;而198.题中的房间是单排排列。 将房间环形排列&#xff0c;意味着第一间房间和最后一间房间不能同时盗窃&#xff0c;因…

虹科分享|不再受支持的Windows系统如何免受攻击?| 自动移动目标防御

传统的微软操作系统(OS)可能会一直伴随着我们&#xff0c;操作系统使用统计数据显示&#xff0c;传统操作系统的总市场份额仍在10%以上。Windows的总安装基数为13亿&#xff0c;大约有1.5亿个终端仍在运行旧版操作系统。 数十万组织的终端和服务器采用不受支持的操作系统。如果…

curl方式调用电商API接口示例 详细介绍

cURL是一个利用URL语法在命令行下工作的文件传输工具&#xff0c;1997年首次发行。它支持文件上传和下载&#xff0c;所以是综合传输工具&#xff0c;但按传统&#xff0c;习惯称cURL为下载工具。cURL还包含了用于程序开发的libcurl。 cURL支持的通信协议有FTP、FTPS、HTTP、H…

数字化工厂:虹科Vuzix AR眼镜在工业制造中的革新应用

随着现代科学技术和新兴需求的快速增长&#xff0c;增强现实(AR)、各种“现实”产品与技术不断涌入创新市场&#xff0c;新兴用例数量正在快速增长&#xff0c;可以肯定&#xff0c;在可预见的未来&#xff0c;AR技术将成为各行各业的生产与工作主流。 增强现实&#xff08;AR&…

应用scrapy爬虫框架

Scrapy是一个基于Python的开源网络爬虫框架&#xff0c;它可以帮助我们快速、高效地抓取网页数据&#xff0c;并支持数据的自动化处理、存储和导出。Scrapy提供了丰富的扩展机制&#xff0c;可以轻松地实现各种自定义需求。 Scrapy的基本使用流程&#xff1a; 1、安装Scrapy框…

服务(第十五篇)HAproxy负载+高可用

HAProxy负载均衡的调度算法&#xff08;策略&#xff09;&#xff1a; &#xff08;1&#xff09;roundrobin&#xff0c;表示简单的轮询 &#xff08;2&#xff09;static-rr&#xff0c;表示根据权重 &#xff08;3&#xff09;leastconn&#xff0c;表示最少连接者先处理 &…

RestTemplate使用不当引发的504及连接池耗尽问题分析

背景 系统&#xff1a; SpringBoot开发的Web应用&#xff1b;ORM: JPA(Hibernate)接口功能简述&#xff1a; 根据实体类ID到数据库中查询实体信息&#xff0c;然后使用RestTemplate调用外部系统接口获取数据。 问题现象 浏览器页面有时报504 GateWay Timeout错误&#xff0c…

C语言函数大全-- r 开头的函数

C语言函数大全 本篇介绍C语言函数大全-- r 开头的函数 1. raise 1.1 函数说明 函数声明函数功能int raise(int sig);用于向当前进程发送指定的信号。 参数&#xff1a; sig &#xff1a; 指定要发送的信号编号 返回值&#xff1a; 如果调用成功&#xff0c;raise() 函数将返…