迁移学习——CycleGAN

news2024/11/24 2:45:26

CycleGAN

    • 1.导入需要的包
    • 2.数据加载
      • (1)to_img 函数
      • (2)数据加载
      • (3)图像转换
    • 3.随机读取图像进行预处理
      • (1)函数参数
      • (2)数据路径
      • (3)读取文件列表
      • (4)初始化结果列表
      • (5)随机采样
      • (6)读取和预处理图像
      • (7)返回结果
    • 4.残差网络块
      • (1)构造函数
      • (2)残差块层
      • (3)跳跃连接
    • 5.生成器网络
      • (1)构造函数
      • (2)编码器部分
      • (3)残差块部分
      • (4)解码器部分
      • (5)输出层
      • (6)模型初始化
      • (7)前向传播
    • 6.判别器网络
      • (1)构造函数
      • (2)判别器层
      • (3)全卷积网络部分
      • (4)输出
    • 7.缓存生成器
      • (1)构造函数
      • (2)push_and_pop 方法
    • 8.训练生成对抗网络(GAN)
    • 9.优化器
    • 10.训练循环的迭代次数
    • 11.训练循环
    • 12.训练生成器
    • 13.训练判别器
    • 14.损失打印,存储伪造图片
    • 全部代码

CycleGAN(循环一致性对抗网络),用于实现两个域(例如,风格或主题不同的图像)之间的无监督图像到图像转换。
CycleGAN的核心思想是使用生成器(Generator)和判别器(Discriminator)来学习从源域(source
domain)到目标域(target domain)的映射,同时保持循环一致性,即从目标域映射回源域应该尽可能接近原始源域图像。

1.导入需要的包

from random import randint: 从Python的random模块中导入randint函数,用于生成随机整数。

import numpy as np: 导入Numpy库,并将其重命名为np,以便在代码中使用。
import torch:导入PyTorch库。
torch.set_default_tensor_type(torch.FloatTensor):设置PyTorch的默认Tensor类型为torch.FloatTensor。
import torch.nn as nn:导入PyTorch的神经网络模块,并将其重命名为nn。
import torch.optim as optim:导入PyTorch的优化器模块,并将其重命名为optim。
import torchvision.datasets as datasets: 导入PyTorch的图像数据集模块,并将其重命名为datasets。
import torchvision.transforms as transforms:导入PyTorch的图像变换模块,并将其重命名为transforms。
import os:导入Python的操作系统模块,用于处理文件和目录。
import matplotlib.pyplot as plt:导入matplotlib的Pyplot模块,用于绘图。
import torch.nn.functional as F:导入PyTorch的函数模块,并将其重命名为F。
from torch.autograd import Variable:从PyTorch的自动求导模块中导入Variable类。
from torchvision.utils import save_image: 从PyTorch的图像处理模块中导入save_image函数。
import shutil:导入Python的文件操作模块,用于删除文件和目录。
import cv2: 导入OpenCV库,用于图像处理和计算机视觉。
import random: 导入Python的随机模块。
from PIL import Image:从Pillow库中导入Image类。
import itertools: 导入Python的迭代工具模块。

from random import randint
import numpy as np 
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools

2.数据加载

(1)to_img 函数

out = 0.5 * (x + 1): 将输入张量 x 的值从 [-1, 1] 范围转换到 [0, 1] 范围。这是因为在训练过程中,图像通常会被归一化到 [-1, 1] 范围,而显示图像时需要将其转换回 [0, 1] 范围。
out = out.clamp(0, 1): 确保所有像素值都在 [0, 1] 范围内。clamp 函数将小于0的值设为0,大于1的值设为1。
out = out.view(-1, 3, 256, 256): 将张量 out 的形状重新调整为批次的形状,其中每个样本是一个 3通道(RGB)的 256x256 图像。-1 表示自动计算批次大小。

def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  
    out = out.view(-1, 3, 256, 256)  
    return out

(2)数据加载

data_path = os.path.abspath('D:\probject\pythonProject1\pytorch\CycleGAN\data'):定义了数据的路径,使用os.path.abspath()将相对路径转换为绝对路径。
image_size = 256:指定图像的大小为256x256。
batch_size = 1:定义了批处理的大小为1。

data_path = os.path.abspath('D:\probject\pythonProject1\pytorch\CycleGAN\data')
image_size = 256
batch_size = 1

(3)图像转换

transform = transforms.Compose([: 创建一个由多个图像转换操作组成的管道。
transforms.Resize(int(image_size * 1.12), Image.BICUBIC): 将图像大小调整为原始大小的 1.12 倍。这样做是为了在后续的随机裁剪中提供更多的裁剪选择。
transforms.RandomCrop(image_size): 从调整大小后的图像中随机裁剪出 256x256 像素大小的区域。
transforms.RandomHorizontalFlip(): 以 50% 的概率随机水平翻转图像。
transforms.ToTensor(): 将 PIL 图像转换为 PyTorch 张量。
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)):对图像进行归一化处理,将每个通道的像素值从 [0, 1] 范围转换为 [-1, 1] 范围。

transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), 
                                                  Image.BICUBIC), 
            transforms.RandomCrop(image_size), 
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

3.随机读取图像进行预处理

(1)函数参数

batch_size: 一个整数,表示每个批次中图像的数量。默认值为1。

def _get_train_data(batch_size=1):

(2)数据路径

train_a_filepath: 训练集A的文件路径。
train_b_filepath: 训练集B的文件路径。

	train_a_filepath = data_path + '\\trainA\\'
    train_b_filepath = data_path + '\\trainB\\'

(3)读取文件列表

train_a_list: 读取训练集A目录中的所有文件名。
train_b_list: 读取训练集B目录中的所有文件名。

   train_a_list = os.listdir(train_a_filepath)
   train_b_list = os.listdir(train_b_filepath)

(4)初始化结果列表

train_a_result: 存储处理后的训练集A图像。
train_b_result: 存储处理后的训练集B图像。

    train_a_result = []
    train_b_result = [] 

(5)随机采样

numlist: 从0到训练集A长度之间的范围中随机采样 batch_size 个索引。

numlist = random.sample(range(0, len(train_a_list)), batch_size)

(6)读取和预处理图像

对于 numlist 中的每个索引 i: 读取训练集A和B中对应的文件名。 使用 PIL.Image.open
打开图像文件,并将其转换为RGB格式。 应用之前定义的 transform 方法对图像进行预处理(包括调整大小、裁剪、翻转和归一化)。
将预处理后的图像添加到 train_a_result 和 train_b_result 列表中。

	for i in numlist:
        a_filename = train_a_list[i]
        a_img = Image.open(train_a_filepath + a_filename).convert('RGB')
        res_a_img = transform(a_img)
        train_a_result.append(torch.unsqueeze(res_a_img, 0))
        
        b_filename = train_b_list[i]
        b_img = Image.open(train_b_filepath + b_filename).convert('RGB')
        res_b_img = transform(b_img)
        train_b_result.append(torch.unsqueeze(res_b_img, 0))
        

(7)返回结果

使用 torch.cattrain_a_resulttrain_b_result
列表中的图像堆叠成一个批次,并返回这两个批次的图像。

4.残差网络块

残差块是一种常用的构建块,用于深度卷积神经网络,特别是在
ResNet(残差网络)架构中。它允许网络在学习过程中保留和利用之前层的信息,通过跳跃连接(shortcut
connections)来解决深层网络训练过程中的梯度消失问题。

(1)构造函数

def __init__(self, in_features): 构造函数接收一个参数 in_features,表示输入特征图的通道数。
super(ResidualBlock, self).__init__(): 调用父类 nn.Module 的构造函数。
self.block_layer: 定义一个顺序模型 nn.Sequential,包含残差块的所有层。

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block_layer = nn.Sequential

(2)残差块层

nn.ReflectionPad2d(1):使用反射填充(padding)来扩展输入张量的边界。这种填充方式在边缘反射输入数据,以保持边缘信息的连续性。
nn.Conv2d(in_features, in_features, 3): 使用 3x3的卷积核进行卷积操作,输入和输出通道数相同。
nn.InstanceNorm2d(in_features):应用实例归一化(Instance Normalization)来对每个样本的特征图进行归一化处理。这与批量归一化(Batch Normalization)不同,它不对整个批次的数据进行归一化,而是对单个样本的特征图进行归一化。
nn.ReLU(inplace=True): 应用 ReLU 激活函数,并设置 inplace=True以便直接修改输入张量,减少内存使用。

(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features))

(3)跳跃连接

return x + self.block_layer(x): 这是残差块的核心,它将输入张量 x 与
self.block_layer(x) 的输出相加,形成跳跃连接。这样,即使 self.block_layer
的输出为零(即网络未能学习到任何东西),输入 x 仍然可以通过跳跃连接直接传递到下一层,从而保持了信息的流通。

	def forward(self, x):
        return x + self.block_layer(x)

5.生成器网络

生成器的目的是将输入图像从一个域转换到另一个域。

(1)构造函数

super(Generator, self).__init__(): 调用父类 nn.Module 的构造函数。
model: 初始化一个列表,用于存储生成器网络中的层。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

(2)编码器部分

nn.ReflectionPad2d(3): 使用反射填充(padding)来扩展输入张量的边界。
nn.Conv2d(3, 64, 7): 使用 7x7 的卷积核将输入图像(3 通道)转换为 64 通道的特征图。
nn.InstanceNorm2d(64):应用实例归一化。
nn.ReLU(inplace=True): 应用 ReLU 激活函数。
for _ in range(2):重复以下层两次,以逐渐减少特征图的尺寸。
nn.Conv2d(in_features, out_features, 3,stride=2, padding=1): 使用 3x3 的卷积核,步长为 2,进行降采样。
nn.InstanceNorm2d(out_features): 应用实例归一化。
nn.ReLU(inplace=True):应用 ReLU 激活函数。

model = [nn.ReflectionPad2d(3), 
                 nn.Conv2d(3, 64, 7), 
                 nn.InstanceNorm2d(64), 
                 nn.ReLU(inplace=True)]

        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 
                                3, stride=2, padding=1), 
            nn.InstanceNorm2d(out_features), 
            nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2

(3)残差块部分

for _ in range(9): 重复添加 9 个残差块,这些块是 CycleGAN 生成器的核心,用于学习域之间的映射。

 	for _ in range(9):
            model += [ResidualBlock(in_features)]

(4)解码器部分

out_features = in_features // 2: 准备进行上采样,将特征图的尺寸加倍。
for _ in range(2): 重复以下层两次,以逐渐增加特征图的尺寸。
nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1): 使用 3x3 的转置卷积核,步长为 2,进行上采样。
nn.InstanceNorm2d(out_features): 应用实例归一化。
nn.ReLU(inplace=True): 应用 ReLU 激活函数。

    out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(
                    in_features, out_features, 
                    3, stride=2, padding=1, output_padding=1), 
                nn.InstanceNorm2d(out_features), 
                nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

(5)输出层

nn.ReflectionPad2d(3): 使用反射填充。
nn.Conv2d(64, 3, 7): 使用 7x7的卷积核将特征图转换回 3 通道的图像。
nn.Tanh(): 应用 Tanh 激活函数,将输出值范围映射到 [-1, 1]。

	model += [nn.ReflectionPad2d(3), 
                  nn.Conv2d(64, 3, 7), 
                  nn.Tanh()]

(6)模型初始化

self.gen = nn.Sequential( * model): 将所有层组合成一个顺序模型。

self.gen = nn.Sequential( * model)

(7)前向传播

def forward(self, x): 定义前向传播函数。
x = self.gen(x): 通过生成器网络传递输入 x。
return x: 返回生成器的输出。

	def forward(self, x):
        x = self.gen(x)
        return x 

6.判别器网络

(1)构造函数

super(Discriminator, self).__init__(): 调用父类 nn.Module 的构造函数。
self.dis: 定义一个顺序模型 nn.Sequential,包含判别器网络的所有层。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential

(2)判别器层

nn.Conv2d(3, 64, 4, 2, 1, bias=False): 使用 4x4 的卷积核,步长为2,进行降采样,输入通道数为 3(RGB),输出通道数为 64。
nn.LeakyReLU(0.2, inplace=True): 应用Leaky ReLU 激活函数,设置斜率为 0.2。
for _ in range(3): 重复以下层三次,以逐渐减少特征图的尺寸。
nn.Conv2d(in_features, out_features, 4, 2, 1, bias=False): 使用 4x4 的卷积核,步长为 2,进行降采样。
nn.InstanceNorm2d(out_features): 应用实例归一化。
nn.LeakyReLU(0.2, inplace=True): 应用 Leaky ReLU 激活函数。

(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

(3)全卷积网络部分

nn.Conv2d(256, 512, 4, padding=1): 使用 4x4 的卷积核,不进行降采样,输入通道数为256,输出通道数为 512。
nn.InstanceNorm2d(512): 应用实例归一化。
nn.LeakyReLU(0.2, inplace=True): 应用 Leaky ReLU 激活函数。
nn.Conv2d(512, 1, 4, padding=1):使用 4x4 的卷积核,不进行降采样,输入通道数为 512,输出通道数为 1。

			nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1, 4, padding=1))  

(4)输出

return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1):对判别器输出的特征图进行平均池化操作,然后将其展平为一维向量。这个一维向量将作为最终的判别结果,其长度为 1,表示输入图像的真实性(接近 1表示真实,接近 0 表示假)。

	def forward(self, x):
        x = self.dis(x)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

7.缓存生成器

(1)构造函数

def __init__(self, max_size=50): 定义了一个构造函数 init,用于在创建ReplayBuffer 对象时初始化其属性。
self.max_size = max_size: 初始化缓冲区的大小。
self.data = []: 初始化一个空列表 self.data,用于存储缓存的数据。

class ReplayBuffer():
#     """
#     缓存队列,若不足则新增,否则随机替换
#     """
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

(2)push_and_pop 方法

def push_and_pop(self, data): 定义了一个方法,用于将新数据推入缓冲区,并在需要时弹出旧数据。
to_return = []: 初始化一个空列表 to_return,用于存储从缓冲区中弹出的数据。
for element in data.data:: 遍历传入的数据 data.data 中的每个元素。
element = torch.unsqueeze(element, 0):将每个元素展平为一维张量。这通常是为了确保张量的形状与预期的形状匹配,以便后续的操作可以正确执行。
if len(self.data) < self.max_size:: 如果缓冲区中还没有达到最大容量,则将新元素添加到缓冲区。
self.data.append(element): 将新元素添加到缓冲区。
to_return.append(element): 将新元素添加到 to_return 列表中。
else:: 如果缓冲区已满,则随机替换缓冲区中的一个元素。
if random.uniform(0,1) > 0.5:: 如果随机数大于 0.5,则从缓冲区中随机选择一个元素替换。
i = random.randint(0, self.max_size-1): 随机选择一个索引。
to_return.append(self.data[i].clone()): 将缓冲区中的元素复制并添加到 to_return列表中。
self.data[i] = element: 用新元素替换缓冲区中的元素。
else:: 如果随机数小于或等于 0.5,则直接添加新元素到 to_return 列表中。
to_return.append(element): 将新元素添加到 to_return 列表中。
return Variable(torch.cat(to_return)): 返回 to_return 列表中所有元素的拼接张量。Variable 是一个 PyTorch 类,用于表示可变的张量。torch.cat 函数用于将多个张量拼接在一起。

def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

8.训练生成对抗网络(GAN)

fake_A_buffer = ReplayBuffer(): 创建了一个名为 fake_A_buffer 的 ReplayBuffer实例。ReplayBuffer是一个用于缓存和随机替换数据的结构,在训练循环中用于缓存生成器生成的假图像,以便在后续的训练步骤中用于训练判别器。
fake_B_buffer = ReplayBuffer(): 创建了一个名为 fake_B_buffer 的 ReplayBuffer实例。这个缓冲区的作用与 fake_A_buffer 类似,用于缓存从生成器 netG_B2A 生成的假图像。

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

netG_A2B = Generator(): 创建了一个名为 netG_A2B 的 Generator 实例。Generator是一个用于生成新图像的神经网络,在这里,它将从域 A 生成域 B 的图像。
netG_B2A = Generator(): 创建了一个名为 netG_B2A 的 Generator 实例。这个生成器将从域 B生成域 A 的图像。
netD_A = Discriminator(): 创建了一个名为 netD_A 的 Discriminator实例。Discriminator 是一个用于判断图像是否真实的神经网络,在这里,它用于判断 A 类图像是否真实。
netD_B = Discriminator(): 创建了一个名为 netD_B 的 Discriminator实例。这个判别器用于判断 B 类图像是否真实。

netG_A2B = Generator()
netG_B2A = Generator()
netD_A = Discriminator()
netD_B = Discriminator()

criterion_GAN = torch.nn.MSELoss(): 定义了一个名为 criterion_GAN 的 MSELoss
损失函数。这个损失函数用于计算 GAN 损失,即判别器对真实图像和假图像的预测之间的差异。
criterion_cycle = torch.nn.L1Loss(): 定义了一个名为 criterion_cycle 的 L1Loss损失函数。这个损失函数用于计算循环一致性损失,即生成器生成的图像与其输入图像之间的差异。
criterion_identity = torch.nn.L1Loss(): 定义了一个名为 criterion_identity 的 L1Loss损失函数。这个损失函数用于计算身份损失,即生成器生成的图像与其输入图像之间的差异。

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

d_learning_rate = 3e-4 : 定义了判别器的学习率。
g_learning_rate = 3e-4:定义了生成器的 learning rate。
optim_betas = (0.5, 0.999): 定义了优化器的超参数betas,这是用于计算梯度下降的动量项的值。

d_learning_rate = 3e-4  
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)

9.优化器

g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=d_learning_rate): 创建了一个名为 g_optimizer 的Adam 优化器实例。Adam 是一种常用的优化算法,用于调整神经网络的权重。这里,itertools.chain函数用于将两个生成器的参数合并为一个单一的迭代器,以便于一起优化。lr 参数指定了学习率,它用于控制权重更新的速度。

da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate):创建了一个名为 da_optimizer 的 Adam 优化器实例,用于训练判别器 netD_A。

db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate):创建了一个名为 db_optimizer 的 Adam 优化器实例,用于训练判别器 netD_B。

g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), 
                                         netG_B2A.parameters()), 
            lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)

10.训练循环的迭代次数

num_epochs = 100: 定义了训练循环的迭代次数。epoch是一个训练周期,在这个周期内,所有数据都会被遍历一次。在这里,训练循环将执行 100 个周期。

num_epochs = 100

11.训练循环

for epoch in range(num_epochs):: 开始一个循环,该循环将执行指定的次数(由 num_epochs定义)。
real_a, real_b = _get_train_data(batch_size): 从数据集中获取一批真实图像real_a 和 real_b。
target_real = torch.full((batch_size,), 1).float():创建一个全为 1 的张量 target_real,用于指示真实图像。
target_fake =torch.full((batch_size,), 0).float(): 创建一个全为 0 的张量target_fake,用于指示假图像。
g_optimizer.zero_grad():清除生成器的梯度,以便于下一次前向传播和反向传播时不会累积梯度。

for epoch in range(num_epochs): 

    real_a, real_b = _get_train_data(batch_size)
    target_real = torch.full((batch_size,), 1).float()
    target_fake = torch.full((batch_size,), 0).float()
    
	g_optimizer.zero_grad()

12.训练生成器

same_B = netG_A2B(real_b).float(): 使用生成器 netG_A2B 从真实图像 real_b生成相似的图像 same_B。

loss_identity_B = criterion_identity(same_B, real_b) * 5.0: 计算same_B 和 real_b 之间的身份损失,并乘以 5.0 以增加其权重。

same_A = netG_B2A(real_a).float(): 使用生成器 netG_B2A 从真实图像 real_a生成相似的图像 same_A。

loss_identity_A = criterion_identity(same_A, real_a) * 5.0: 计算same_A 和 real_a 之间的身份损失,并乘以 5.0 以增加其权重。

fake_B = netG_A2B(real_a).float(): 使用生成器 netG_A2B 从真实图像 real_a 生成假图像fake_B。

pred_fake = netD_B(fake_B).float(): 使用判别器 netD_B 判断 fake_B 是否为假图像。

loss_GAN_A2B = criterion_GAN(pred_fake, target_real): 计算判别器对 fake_B的预测和真实图像的损失,即 GAN 损失。

fake_A = netG_B2A(real_b).float(): 使用生成器 netG_B2A 从真实图像 real_b 生成假图像fake_A。

pred_fake = netD_A(fake_A).float(): 使用判别器 netD_A 判断 fake_A 是否为假图像。

loss_GAN_B2A = criterion_GAN(pred_fake, target_real): 计算判别器对 fake_A的预测和真实图像的损失,即 GAN 损失。

recovered_A = netG_B2A(fake_B).float(): 使用生成器 netG_B2A 从假图像 fake_B生成恢复的图像 recovered_A。

loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0: 计算recovered_A 和 real_a 之间的循环一致性损失,并乘以 10.0 以增加其权重。

recovered_B = netG_A2B(fake_A).float(): 使用生成器 netG_A2B 从假图像 fake_A生成恢复的图像 recovered_B。

loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0: 计算recovered_B 和 real_b 之间的循环一致性损失,并乘以 10.0 以增加其权重。

loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB): 将所有损失加在一起,得到生成器的总损失。

loss_G.backward(): 对总损失进行反向传播,计算每个参数的梯度。

g_optimizer.step():会对生成器的所有参数进行梯度更新,以最小化生成器损失函数。

# 第一步:训练生成器
    same_B = netG_A2B(real_b).float()
    loss_identity_B = criterion_identity(same_B, real_b) * 5.0   
    same_A = netG_B2A(real_a).float()
    loss_identity_A = criterion_identity(same_A, real_a) * 5.0
    
    fake_B = netG_A2B(real_a).float()
    pred_fake = netD_B(fake_B).float()
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
    fake_A = netG_B2A(real_b).float()
    pred_fake = netD_A(fake_A).float()
    loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
    recovered_A = netG_B2A(fake_B).float()
    loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0
    recovered_B = netG_A2B(fake_A).float()
    loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0  
    loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + 
              loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
    loss_G.backward()    
    g_optimizer.step()

13.训练判别器

da_optimizer.zero_grad(): 清除判别器 A 的梯度,以便于下一次前向传播和反向传播时不会累积梯度。

pred_real = netD_A(real_a).float(): 使用判别器 A 来判断真实图像 real_a 是否为真实图像。

loss_D_real = criterion_GAN(pred_real, target_real): 计算判别器 A对真实图像的预测和真实图像的损失,即 GAN 损失。

fake_A = fake_A_buffer.push_and_pop(fake_A): 从 fake_A_buffer 中获取一批fake_A 图像,这些图像是从生成器 A 生成的假图像。

pred_fake = netD_A(fake_A.detach()).float(): 使用判别器 A 来判断 fake_A是否为假图像。由于 fake_A 是从 fake_A_buffer 中获取的,它已经与生成器的梯度解耦,因此不需要梯度信息。

loss_D_fake = criterion_GAN(pred_fake, target_fake): 计算判别器 A 对fake_A 的预测和假图像的损失,即 GAN 损失。

loss_D_A = (loss_D_real + loss_D_fake) * 0.5: 将判别器 A的真实图像损失和假图像损失加在一起,得到判别器 A 的总损失。

loss_D_A.backward(): 对判别器 A 的总损失进行反向传播,计算每个参数的梯度。
da_optimizer.step(): 使用之前计算的梯度来更新判别器 A 的参数。

   # 第二步:训练判别器
    # 训练判别器A
    da_optimizer.zero_grad()
    pred_real = netD_A(real_a).float()
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_A = fake_A_buffer.push_and_pop(fake_A)
    pred_fake = netD_A(fake_A.detach()).float()
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_A = (loss_D_real + loss_D_fake) * 0.5
    loss_D_A.backward()
    da_optimizer.step()
    # 训练判别器B
    db_optimizer.zero_grad()
    pred_real = netD_B(real_b)
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_B = fake_B_buffer.push_and_pop(fake_B)
    pred_fake = netD_B(fake_B.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_B = (loss_D_real + loss_D_fake) * 0.5
    loss_D_B.backward()
	db_optimizer.step()

14.损失打印,存储伪造图片

print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}' .format(epoch, loss_G.data.item(), loss_D_A.data.item(), loss_D_B.data.item())):打印当前训练周期(epoch)的损失,包括生成器损失(loss_G)和两个判别器损失(loss_D_A 和 loss_D_B)。
if (epoch + 1) % 20 == 0 or epoch == 0:: 检查当前训练周期是否是 20的倍数,或者是否是第一个周期。如果是,则执行以下操作。
b_fake = to_img(fake_B.data): 将判别器 B的输入(fake_B)转换回图像格式。
a_fake = to_img(fake_A.data): 将判别器 A的输入(fake_A)转换回图像格式。
a_real = to_img(real_a.data): 将真实图像 A 转换回图像格式。
b_real = to_img(real_b.data): 将真实图像 B 转换回图像格式。
save_image(a_fake,'../tmp/a_fake.png'): 将 a_fake 图像保存到文件 …/tmp/a_fake.png。
save_image(b_fake, '../tmp/b_fake.png'): 将 b_fake 图像保存到文件…/tmp/b_fake.png。
save_image(a_real, '../tmp/a_real.png'): 将 a_real图像保存到文件 …/tmp/a_real.png。
save_image(b_real, '../tmp/b_real.png'):将 b_real 图像保存到文件 …/tmp/b_real.png。

 #损失打印,存储伪造图片
    print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'
      .format(epoch, loss_G.data.item(), loss_D_A.data.item(), 
              loss_D_B.data.item()))
    if (epoch + 1) % 20 == 0 or epoch == 0:  
        b_fake = to_img(fake_B.data)
        a_fake = to_img(fake_A.data)
        a_real = to_img(real_a.data)
        b_real = to_img(real_b.data)
        save_image(a_fake, '../tmp/a_fake.png') 
        save_image(b_fake, '../tmp/b_fake.png') 
        save_image(a_real, '../tmp/a_real.png') 
        save_image(b_real, '../tmp/b_real.png') 

在这里插入图片描述
在这里插入图片描述

全部代码

from random import randint
import numpy as np 
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools   
 def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  
    out = out.view(-1, 3, 256, 256)  
    return out

# 数据加载 
data_path = os.path.abspath('D:\probject\pythonProject1\pytorch\CycleGAN\data')
image_size = 256
batch_size = 1

transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), 
                                                  Image.BICUBIC), 
            transforms.RandomCrop(image_size), 
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
def _get_train_data(batch_size=1):
    
    train_a_filepath = data_path + '\\trainA\\'
    train_b_filepath = data_path + '\\trainB\\'
    
    train_a_list = os.listdir(train_a_filepath)
    train_b_list = os.listdir(train_b_filepath)
    
    train_a_result = []
    train_b_result = [] 
    
    numlist = random.sample(range(0, len(train_a_list)), batch_size)
    
    for i in numlist:
        a_filename = train_a_list[i]
        a_img = Image.open(train_a_filepath + a_filename).convert('RGB')
        res_a_img = transform(a_img)
        train_a_result.append(torch.unsqueeze(res_a_img, 0))
        
        b_filename = train_b_list[i]
        b_img = Image.open(train_b_filepath + b_filename).convert('RGB')
        res_b_img = transform(b_img)
        train_b_result.append(torch.unsqueeze(res_b_img, 0))
        
    return torch.cat(train_a_result, dim=0), torch.cat(train_b_result, dim=0)

# """
# 残差网络block
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block_layer = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features))
        
    def forward(self, x):
        return x + self.block_layer(x)
# 生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
             
        model = [nn.ReflectionPad2d(3), 
                 nn.Conv2d(3, 64, 7), 
                 nn.InstanceNorm2d(64), 
                 nn.ReLU(inplace=True)]

        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 
                                3, stride=2, padding=1), 
            nn.InstanceNorm2d(out_features), 
            nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2

        for _ in range(9):
            model += [ResidualBlock(in_features)]

        out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(
                    in_features, out_features, 
                    3, stride=2, padding=1, output_padding=1), 
                nn.InstanceNorm2d(out_features), 
                nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

        model += [nn.ReflectionPad2d(3), 
                  nn.Conv2d(64, 3, 7), 
                  nn.Tanh()]

        self.gen = nn.Sequential( * model)
        
    def forward(self, x):
        x = self.gen(x)
        return x 
# 判别器 

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1, 4, padding=1))        
        
    def forward(self, x):
        x = self.dis(x)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)


class ReplayBuffer():
#     """
#     缓存队列,若不足则新增,否则随机替换
#     """
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []
        
    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))
    
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

netG_A2B = Generator()
netG_B2A = Generator()
netD_A = Discriminator()
netD_B = Discriminator()

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

d_learning_rate = 3e-4  # 3e-4
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)

g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), 
                                         netG_B2A.parameters()), 
            lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)

num_epochs = 100
for epoch in range(num_epochs): 

    real_a, real_b = _get_train_data(batch_size)
    target_real = torch.full((batch_size,), 1).float()
    target_fake = torch.full((batch_size,), 0).float()
    
    g_optimizer.zero_grad()
    
    # 第一步:训练生成器
    same_B = netG_A2B(real_b).float()
    loss_identity_B = criterion_identity(same_B, real_b) * 5.0   
    same_A = netG_B2A(real_a).float()
    loss_identity_A = criterion_identity(same_A, real_a) * 5.0
    
    fake_B = netG_A2B(real_a).float()
    pred_fake = netD_B(fake_B).float()
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
    fake_A = netG_B2A(real_b).float()
    pred_fake = netD_A(fake_A).float()
    loss_GAN_B2A = criterion_GAN(pred_fake, target_real)
    recovered_A = netG_B2A(fake_B).float()
    loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0
    recovered_B = netG_A2B(fake_A).float()
    loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0  
    loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + 
              loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
    loss_G.backward()    
    g_optimizer.step()
    
    
    # 第二步:训练判别器
    # 训练判别器A
    da_optimizer.zero_grad()
    pred_real = netD_A(real_a).float()
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_A = fake_A_buffer.push_and_pop(fake_A)
    pred_fake = netD_A(fake_A.detach()).float()
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_A = (loss_D_real + loss_D_fake) * 0.5
    loss_D_A.backward()
    da_optimizer.step()
    # 训练判别器B
    db_optimizer.zero_grad()
    pred_real = netD_B(real_b)
    loss_D_real = criterion_GAN(pred_real, target_real)
    fake_B = fake_B_buffer.push_and_pop(fake_B)
    pred_fake = netD_B(fake_B.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    loss_D_B = (loss_D_real + loss_D_fake) * 0.5
    loss_D_B.backward()
    db_optimizer.step()
    
    
    #损失打印,存储伪造图片
    print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'
      .format(epoch, loss_G.data.item(), loss_D_A.data.item(), 
              loss_D_B.data.item()))
    if (epoch + 1) % 20 == 0 or epoch == 0:  
        b_fake = to_img(fake_B.data)
        a_fake = to_img(fake_A.data)
        a_real = to_img(real_a.data)
        b_real = to_img(real_b.data)
        save_image(a_fake, '../tmp/a_fake.png') 
        save_image(b_fake, '../tmp/b_fake.png') 
        save_image(a_real, '../tmp/a_real.png') 
        save_image(b_real, '../tmp/b_real.png') 
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1865636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构——带头双向循环链表(c语言实现)

目录 1.单链表和双向链表对比 2.双向链表实现 2.1 创建新节点 2.2 链表初始化 2.3 尾插 2.4 头插 2.5 尾删 2.6 头删 2.7 查找 2.8 指定位置后插入数据 2.9 删除指定节点 2.10 销毁链表 2.11 打印链表 前言&#xff1a; 我们在前几期详细地讲解了不带头单…

EthernetIP IO从站设备数据 转opc ua项目案例

1 案例说明 设置网关采集EthernetIP IO设备数据把采集的数据转成opc ua协议转发给其他系统。 2 VFBOX网关工作原理 VFBOX网关是协议转换网关&#xff0c;是把一种协议转换成另外一种协议。网关可以采集西门子&#xff0c;欧姆龙&#xff0c;三菱&#xff0c;AB PLC&#xff0…

day41--Redis(三)高级篇之最佳实践

Redis高级篇之最佳实践 今日内容 Redis键值设计批处理优化服务端优化集群最佳实践 1、Redis键值设计 1.1、优雅的key结构 Redis的Key虽然可以自定义&#xff0c;但最好遵循下面的几个最佳实践约定&#xff1a; 遵循基本格式&#xff1a;[业务名称]:[数据名]:[id]长度不超过…

STM32音频应用开发:DMA与定时器的高效协作

摘要: 本文章将深入浅出地介绍如何使用STM32单片机实现音频播放功能。文章将从音频基础知识入手&#xff0c;逐步讲解音频解码、DAC转换、音频放大等关键环节&#xff0c;并结合STM32 HAL库给出具体的代码实现和电路设计方案。最后&#xff0c;我们将通过一个实例演示如何播放W…

FPGA SATA高速存储设计

今天来讲一篇如何在fpga上实现sata ip&#xff0c;然后利用sata ip实现读写sata 盘的目的&#xff0c;如果需要再速度和容量上增加&#xff0c;那么仅仅需要增加sata ip个数就能够实现增加sata盘&#xff0c;如果仅仅实现data的读写整体来说sata ip设计比较简单&#xff0c;下面…

大模型产品的“命名经济学”:名字越简单,产品越火爆?

文 | 智能相对论 作者 | 陈泊丞 古人云&#xff1a;赐子千金&#xff0c;不如教子一艺&#xff1b;教子一艺&#xff0c;不如赐子一名。 命名之妙&#xff0c;玄之又玄。 早两年&#xff0c;大模型爆火&#xff0c;本土厂商在大模型产品命名上可谓下足了功夫&#xff0c;引…

小公司选择高水平LabVIEW团队外包

小公司选择高水平LabVIEW团队外包的优点可以从多个方面进行详细说明&#xff0c;包括专业技能、成本效益、时间效率、技术支持、风险管理和灵活性等。以下是具体的优点分析&#xff1a; 1. 专业技能和经验 优点&#xff1a; 高水平专业技能&#xff1a;高水平的LabVIEW团队通…

力扣316.去除重复字母

力扣316.去除重复字母 从左到右遍历每个字母 若当前字母比栈顶字母小 并且右边仍然后栈顶字母出现弹出栈顶字母 最后加入当前字母 class Solution {public:string removeDuplicateLetters(string s) {//记录每个字母出现次数;当前字符串中字母是否出现vector<int> lef…

单目标应用:基于吸血水蛭优化器(Blood-Sucking Leech Optimizer,BSLO)的微电网优化(MATLAB代码)

一、微电网模型介绍 微电网多目标优化调度模型简介_vmgpqv-CSDN博客 参考文献&#xff1a; [1]李兴莘,张靖,何宇,等.基于改进粒子群算法的微电网多目标优化调度[J].电力科学与工程, 2021, 37(3):7 二、吸血水蛭优化器求解微电网 2.1算法简介 吸血水蛭优化器&#xff08;B…

AI技术与艺术的融合:开创性的用户界面与产品体验

引言 近年来&#xff0c;人工智能&#xff08;AI&#xff09;的飞速发展改变了我们的生活和工作方式。AI技术不仅在算力和模型上取得了重大进步&#xff0c;更在用户界面和产品体验方面迎来了突破。近日&#xff0c;科技博客 Stratechery 的文章以及硅谷投资基金 AI Grant 的两…

platform 设备驱动实验

platform 设备驱动实验 Linux 驱动的分离与分层 代码的重用性非常重要&#xff0c;否则的话就会在 Linux 内核中存在大量无意义的重复代码。尤其是驱动程序&#xff0c;因为驱动程序占用了 Linux内核代码量的大头&#xff0c;如果不对驱动程序加以管理&#xff0c;任由重复的…

注意!短视频的致命误区,云微客教你避开!

为什么你做短视频就是干不过同行&#xff1f;因为你总想着拍剧情、段子这些娱乐视频&#xff0c;还想着当网红做IP人设&#xff0c;但是这些内容跟你的盈利没有半毛钱关系&#xff0c;况且难度大、见效慢&#xff0c;还不是精准客户。 以上这些就代表你走进了短视频的误区&…

Linux常用环境变量PATH

Linux常用环境变量 一、常用的默认的shell环境变量二、环境变量 PATH三、持久化修改环境变量四、常用的环境变量 一、常用的默认的shell环境变量 1、当我们在shell命令行属于一个命令&#xff0c;shell解释器去解释这个命令的时候&#xff0c;需要先找到这个命令. 找到命令有两…

边缘混合计算智慧矿山视频智能综合管理方案:矿山安全生产智能转型升级之路

一、智慧矿山方案介绍 智慧矿山是以矿山数字化、信息化为前提和基础&#xff0c;通过物联网、人工智能等技术进行主动感知、自动分析、快速处理&#xff0c;实现安全矿山、高效矿山的矿山智能化建设。旭帆科技TSINGSEE青犀基于图像的前端计算、边缘计算技术&#xff0c;结合煤…

k8s学习--Kruise Rollouts 基本使用

文章目录 Kruise Rollouts简介什么是 Kruise Rollouts&#xff1f;核心功能 应用环境一、OpenKruise部署1.安装helm客户端工具2. 通过 helm 安装 二、Kruise Rollouts 安装2. kubectl plugin安装 三、Kruise Rollouts 基本使用(多批次发布)1. 使用Deployment部署应用2.准备Roll…

当前的网安行业绝对不是高薪行业

昨天&#xff0c;面试了一个刚毕业两年的同学小A。第一学历为某大专&#xff0c;第二学历为某省地区的本科院校。面试过程表现一般偏下&#xff0c;但动不动就要薪资15K 这个人&#xff0c;我当场就PASS了。主要原因是&#xff0c;并非是否定小A同学的能力&#xff0c;而是他…

Axure 教程 | 雅虎新闻焦点

主要内容 在雅虎首页&#xff0c;新闻焦点大图和焦点小图同步切换轮播&#xff0c;本课程我们来学习如何实现这个效果。 交互说明 1.页面载入后&#xff0c;切换当前屏幕显示的5张焦点图&#xff0c;小图标处以横线提示当前焦点图。 2.鼠标移入焦点大图&#xff0c;新闻标题显示…

打破数据分析壁垒:SPSS复习必备(九)

有序定性资料统计推断 1.分类 单向有序行列表 双向有序属性相同行列表 双向有序属性不同行列表 2.单向有序行列表 秩和检验 ① 两组单向有序分类资料 ②多组单向有序定性资料 步骤&#xff1a; 1.建立检验假设和确定检验水准 2.编秩 3.求秩和 4.确定检验统计量 5…

MAC 查看公钥私钥

电脑配置过公钥私钥&#xff0c;现在需要查看&#xff1a; 1、 查看本地是否存在SSH密钥 命令&#xff1a;ls -al ~/.ssh 如果在输出的文件列表中发现id_rsa和id_rsa.pub的存在&#xff0c;证明本地已经存在SSH密钥&#xff0c;请执行第3步 2、 生成SSH密钥 命令&#xff1…

【Java】已解决java.nio.channels.OverlappingFileLockException异常

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决java.nio.channels.OverlappingFileLockException异常 在Java的NIO&#xff08;New I/O&#xff09;编程中&#xff0c;java.nio.channels.OverlappingFileLockException是一…