pix2pix代码

news2024/12/25 12:17:44

看一下模型图:
在这里插入图片描述
首先定义生成器G,和CGAN不同的是,pix2pix并没有输入噪声,而是采用dropout来增加随即性。然后生成器输入x,输出y都是一些图片。最后按照原文,G是一个U-Net shape的,除了上采样和下采样,最重要的是跳连接。

import torch
import torch.nn as nn
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(
            features * 2, features * 4, down=True, act="leaky", use_dropout=False
        )
        self.down3 = Block(
            features * 4, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down4 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down5 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down6 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
        )

        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up3 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up4 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False
        )
        self.up5 = Block(
            features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False
        )
        self.up6 = Block(
            features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False
        )
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):#(1,3,256,256)
        d1 = self.initial_down(x)#(1,64,128,128)
        d2 = self.down1(d1)#(1,128,64,64)
        d3 = self.down2(d2)#(1,256,32,32)
        d4 = self.down3(d3)#(1,512,16,16)
        d5 = self.down4(d4)#(1,512,8,8)
        d6 = self.down5(d5)#(1,512,4,4)
        d7 = self.down6(d6)#(1,512,2,2)
        bottleneck = self.bottleneck(d7)#(1,512,1,1)
        up1 = self.up1(bottleneck)#(1,512,2,2)
        up2 = self.up2(torch.cat([up1, d7], 1))#(1,512,4,4)
        up3 = self.up3(torch.cat([up2, d6], 1))#(1,512,8,8)
        up4 = self.up4(torch.cat([up3, d5], 1))#(1,512,16,16)
        up5 = self.up5(torch.cat([up4, d4], 1))#(1,256,32,32)
        up6 = self.up6(torch.cat([up5, d3], 1))#(1,128,64,64)
        up7 = self.up7(torch.cat([up6, d2], 1))#(1,64,128,128)
        return self.final_up(torch.cat([up7, d1], 1))#(1,3,256,256)
def test():
    x = torch.randn((1, 3, 256, 256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)
if __name__ == "__main__":
    test()

这里随机生成一个和真实数据集大小的tensor进行验证。
首先使用第一个卷积:和常见的卷积不同的是卷积核大小为4,padd为reflect,后面没加BN,加的是LeakyReLU。
原始大小为(1,3,256,256)经过第一个卷积后变为((1,64,128,128))
在这里插入图片描述
接着经过6个down,就是encoder的连续下采样。
在这里插入图片描述
看一个,其他的也都一样。
在这里插入图片描述
在Block内部:指定down,leakyrelu和dropout。如果指定了down那么使用步长为2的卷积进行下采样,如果未指定就使用转置卷积,后面紧接BN和leakyrelu。最后在encoder中不使用dropout。
在这里插入图片描述
在encoder和decoder中间是bottleneck。是一个卷积加relu。
在这里插入图片描述
在这里插入图片描述
需要注意的是encoder的图片通道变换:不同于ResNet。
在这里插入图片描述
在decoder中首先进行上采样,才能和encoder对应层concat,否则大小不一无法concat。
在这里插入图片描述
通过设置参数down为False,那么就采用的转置卷积,设置的激活函数为relu,且decoder前三层使用dropout。这些是和encoder不一样的地方。
在这里插入图片描述
最终进过一个转置卷积和tanh得到最终的输出和原图像大小一样。
上述代码实现的是:
在这里插入图片描述
接着是辨别器:由原始论文知道采用的是patchGAN。在代码中也是通过卷积实现的。

import torch
import torch.nn as nn
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = features[0]#64
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)#(1,6,256,256)
        x = self.initial(x)#(1,64,128,128)
        x = self.model(x)
        return x


def test():
    x = torch.randn((1, 3, 256, 256))
    y = torch.randn((1, 3, 256, 256))
    model = Discriminator(in_channels=3)
    preds = model(x, y)#(1,1,30,30)
    print(model)
    print(preds.shape)
if __name__ == "__main__":
    test()

辨别器D的输入有两个,因为本质还是CGAN,所以一个输入为生成的图片,另一个输入为condition,也就是x。
分为三步,首先将condition和生成的图片concat在一起,接着经过一个卷积来增大通道数,最后进过辨别器。
在这里插入图片描述
1:concat
2:拼接后的通道扩充到64,步长为2.
在这里插入图片描述
3:遍历feature,layers里面有四个卷积,采用的是CONV+BN+LeakyReLU形式,最后输出的通道为1.输出大小为30x30.
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述辨别器model:

Sequential(
  (0): CNNBlock(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (1): CNNBlock(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (2): CNNBlock(
    (conv): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (3): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
)

接着看train:

import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import MapDataset
from generator_model import Generator
from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image

torch.backends.cudnn.benchmark = True
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )


def main():
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    train_dataset = MapDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = MapDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if config.SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder="/home/Projects/ZQB/a/PyTorch-GAN-master/implementations/pix2pix-pytorch/results")


if __name__ == "__main__":
    main()

1:实例化辨别器,生成器,设置优化器和损失函数。
在这里插入图片描述
2:传入预训练的权重:
在这里插入图片描述
定义数据集:我们采用的给素描上色的数据集。
在这里插入图片描述
数据集结构:我们根据设定好的数据集位置,加载train文件下的图片。
在这里插入图片描述
我们到dataset中,主要看getitem中,如何加载处理数据的。

import numpy as np
import config
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image


class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        t = np.unique(image)
        print(t)
        input_image = image[:, :600, :]#(512,600,3)
        target_image = image[:, 600:, :]#(512,424,3)

        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]#(256,256,3)
        target_image = augmentations["image0"]#(256,256,3)

        input_image = config.transform_only_input(image=input_image)["image"]#(3,256,256)
        target_image = config.transform_only_mask(image=target_image)["image"]#(3,256,256)

        return input_image, target_image


if __name__ == "__main__":
    dataset = MapDataset("data/train/")
    loader = DataLoader(dataset, batch_size=5)
    for x, y in loader:
        print(x.shape)
        save_image(x, "x.png")
        save_image(y, "y.png")
        import sys

        sys.exit()

首先根据索引,我们找到对应的图片并读入:
在这里插入图片描述
接着对图片进行划分:因为原始图片的input和target是连在一起的。
在这里插入图片描述
在这里插入图片描述
将图片拆分:
input:
在这里插入图片描述
target:
在这里插入图片描述
然后将两个图片裁切到256x256.再将两个图片进行变换:
在这里插入图片描述
也就是执行mydataset时候输出的是input和配对的target。
回到train中:通过trainloader对图片进行加载用于训练。
在这里插入图片描述
同理对val文件夹的图片进行加载,用于val。
在这里插入图片描述
然后就是正式训练:
在这里插入图片描述
将模型,数据,优化器,损失函数都传到train中:
在这里插入图片描述
在train_fn函数中,首先添加一个进度条,接着将input和target都输入到cuda中。
训练判别器:
将真实的x,y输入到判别器中输出的真我们希望为1,将x输入到生成器生成的假y和真实的x(作为condition)输入到判别器中,我们希望输出0。
在这里插入图片描述
训练生成器:真实的x和虚假的y输入到D中,我们希望D判别不出来,即输出为1.还有一个L1损失,即真实的标签和虚假的生成之间的损失。然后两个损失加起来作为生成器损失。
在这里插入图片描述
接着我们保存D和G的权重。
然后保存图片:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/537035.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openwrt的openclash提示【更新失败,请确认设备闪存空间足够后再试】

网上搜索了一下,问题应该是出在“无法从网络下载内核更新包”或者“无法识别内核的版本号” 解决办法:手动下载(我是只搞了DEV内核就搞定了TUN和Meta没有动) --> 上传到路由器上 --> 解压缩 --> 回到openclash界面更新配…

技术干货 | 结构光技术及其实现三维成像的主要原理

原创 | 文BFT机器人 3D表面成像的一种主要方法是基于“结构光”的使用 即使用专门设计的二维空间变化强度模式对场景进行主动照明 3D结构光的整个系统包含结构光投影设备、摄像机、图像采集和处理系统。其过程就是投影设备发射光线到被测物体上,摄像机拍摄在被测物体…

Element 格式化表单文本内容

提示&#xff1a;本文效果通用于其它形式文本格式化&#xff0c;此处以 Element 表单为例。 数据库内容 数据存储格式为 text&#xff0c;换行符号 \r\n 前端代码 <el-table-column v-if"columns.visible(changeContent)" prop"changeContent" labe…

线程池的位运算详解

前言 翻阅 Java 线程池的源码&#xff0c;可以看到用到了大量的位运算操作&#xff0c;本文来分析下这些位运算是如何计算的&#xff0c;以及最后算出的结果是什么。 正文 阅读之前&#xff0c;必须熟悉一下内容 & 与运算| 或运算&#xff5e; 取反<< 左移负数的二进…

迪赛智慧数——柱状图(象形标识图):当代职场人心愿清单TOP10

效果图 职场人十大心愿&#xff1a;“脱单”位列第一&#xff0c;“容貌焦虑”成新难题。 除脱单之外&#xff0c;如今职场人的十大心愿就寄托了人类的高质量梦想&#xff0c;比如财务自由、到点下班、提前退休、有房有车…… 不过&#xff0c;让人心疼的是&#xff0c;不少人…

「网安人必看」安全行业主流证书,你知道如何选择吗

现在&#xff0c;越来越多单位为了满足国家安全法律法规的要求&#xff0c;成立独立的网络安全部门&#xff0c;招聘网络安全人才&#xff0c;组建 SRC&#xff08;安全响应中心&#xff09;&#xff0c;为自己的产品、应用、数据保卫护航。短短几年间&#xff0c;网络安全工程…

Goby 漏洞更新 | Weaver e-cology ofsLogin.jsp 用户登陆绕过漏洞

漏洞名称&#xff1a;Weaver e-cology ofsLogin.jsp 用户登陆绕过漏洞 English Name&#xff1a;Weaver e-cology ofsLogin.jsp User Login Bypass Vulnerability CVSS core: 9.3 影响资产数&#xff1a;92980 漏洞描述&#xff1a; 泛微协同管理应用平台&#xff08;e-co…

Python os模块详解

1. 简介 os就是“operating system”的缩写&#xff0c;顾名思义&#xff0c;os模块提供的就是各种 Python 程序与操作系统进行交互的接口。通过使用os模块&#xff0c;一方面可以方便地与操作系统进行交互&#xff0c;另一方面页也可以极大增强代码的可移植性。如果该模块中相…

【SQLserver】sqlserver数据库还原

这里的还原主要是指一个数据备份文件导入到本地 用到的工具&#xff1a; SQLServerManagement Studio18 1、打开本地库&#xff0c;在数据库右键&#xff0c;点击“还原数据库” 2、选择需要还原的文件&#xff0c;这里选设备&#xff0c;后面选择 3、弹窗点击添加按钮&am…

MyBatis中的别名机制

在我们使用MyBatis中的select语句时&#xff0c;需要指定resultType的值&#xff0c;即查询对象的类型&#xff0c;该值是对象的完整类名&#xff0c;看起来非常的繁琐&#xff0c;因此MyBatis中有了别名机制。 使用步骤 在mybatis-config.xml文件中添加< typeAliases >…

Vue--》探索Pinia:Vue状态管理的未来

目录 Pinia的讲解与使用 Pinia的安装与使用 store数据操作 解构store数据 actions-getters的使用 Pinia常用API 持久化插件 Pinia的讲解与使用 Pinia 是由 Eduardo San Martin Morote 创建的&#xff0c;这是一个轻量级的、使用 Vue3 Composition API 的状态管理库。Pi…

谷粒商城二十五springCloud之Sleuth+Zipkin 服务链路追踪

为什么用 分布式系统庞大而复杂&#xff0c;服务众多&#xff0c;调用关系网也非常复杂&#xff0c; 服务上线以后如果出现了某些错误&#xff0c;错误的异常就很难定位。一个请求可能调用了非常多的链路&#xff0c;我们需要知道到底哪一块儿出现了错误。 最终希望有一个链…

Java基础学习(16)多线程

Java基础学习多线程 一、多线程1.1 什么是多线程1.2 多线程的两个概念1.2.1 并发 1.3 多线程的实现方式1.4 多线程的成员方法1.5 线程的生命周期 二、线程安全1.6 同步方法1.7 锁lock1.8 死锁1.8 生产者和消费者 (等待唤醒机制)1.9 等待唤醒机制(阻塞队列方式实现&#xff09;1…

IntelliJ IDEA汉化

IntelliJ IDEA汉化 描述解决办法 描述 在开发过程中&#xff0c;我们想让界面现实为汉语&#xff0c;那么我们就需要对IDEA工具进行汉化&#xff0c;目前版本的IDEA汉化都非常简单&#xff0c;请看下述实现步骤。 解决办法 下述汉化方法&#xff0c;全家桶软件都通用。 打开…

attention机制

油管attention机制解释 油管的attention机制视频。 基础形态 如下图所示&#xff0c;假设现在有4个向量&#xff0c; v 1 v_1 v1​到 v 4 v_4 v4​。我们以 v 3 v_3 v3​为视角&#xff0c;看它是怎么得到 y 3 y_3 y3​的。首先用 v 3 v_3 v3​和全部4个向量做点乘&#xff…

不依赖硬件,可以无限扩展的闹钟组件

在实际的开发项目中&#xff0c;很多时候我们需要定时的做一些事情&#xff0c;举例&#xff1a; ①路上的路灯&#xff0c;每天晚上6:00准时打开&#xff0c;每天早上6:00准时关闭&#xff1b;②定时闹钟&#xff0c;起床上班。这些行为其实都是定时任务–闹钟。 大部分单片机…

【MySql】数据库设计过程

目录 概念数据库设计&#xff1a; 逻辑数据库设计&#xff1a; 物理数据库设计&#xff1a; ->需求分析&#xff08;收集需求和理解需求,“源”&#xff09; ->概念数据库设计&#xff08;建立概念模型:"E-R图/IDEF1X"&#xff09; ->逻辑数据库设计&…

idle_in_transaction_session_timeout idle_session_timeout

这两个参数都是用来控制PostgreSQL数据库中会话的超时时间的。 idle_in_transaction_session_timeout idle_in_transaction_session_timeout参数用于控制在事务中处于空闲状态的会话的超时时间。如果一个会话在事务中处于空闲状态超过了指定的时间&#xff0c;则该会话将被终…

Rocky Linux 9.2 正式版发布 - RHEL 下游免费发行版

Rocky Linux 由 CentOS 项目的创始人 Gregory Kurtzer 领导。 请访问原文链接&#xff1a;https://sysin.org/blog/rocky-linux-9/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org 以社区方式驱动的企业 Linux Rocky Linux 是…

Android Studio 基础 之 使用 okhttp 长连接,流式获取数据的方法简单整理了

Android Studio 基础 之 使用 okhttp 长连接&#xff0c;流式获取数据的方法简单整理了 目录 Android Studio 基础 之 使用 okhttp 长连接&#xff0c;流式获取数据的方法简单整理了 一、简单介绍 二、实现原理 三、注意事项 四、效果预览 五、实现关键 六、关键代码 七…