DP-GAN-生成器代码

news2024/11/17 6:35:12

在train文件中,对生成器和判别器分别进行更新,根据loss的不同,分别计算对于的损失:

loss_G, losses_G_list = model(image, label, "losses_G", losses_computer)
loss_D, losses_D_list = model(image, label, "losses_D", losses_computer)

在model中:

from models.sync_batchnorm import DataParallelWithCallback
import models.generator as generators
import models.discriminator as discriminators
import os
import copy
import torch
import torch.nn as nn
from torch.nn import init
import models.losses as losses
class DP_GAN_model(nn.Module):
    def __init__(self, opt):
        super(DP_GAN_model, self).__init__()
        self.opt = opt
        #--- generator and discriminator ---
        self.netG = generators.DP_GAN_Generator(opt).cuda()
        if opt.phase == "train" or opt.phase == "eval":
            self.netD = discriminators.DP_GAN_Discriminator(opt)
        self.print_parameter_count()
        self.init_networks()
        #--- EMA of generator weights ---
        with torch.no_grad():
            self.netEMA = copy.deepcopy(self.netG) if not opt.no_EMA else None
        #--- load previous checkpoints if needed ---
        self.load_checkpoints()
        #--- perceptual loss ---#
        if opt.phase == "train":
            if opt.add_vgg_loss:
                self.VGG_loss = losses.VGGLoss(self.opt.gpu_ids)
        self.GAN_loss = losses.GANLoss()
        self.MSELoss = nn.MSELoss(reduction='mean')

    def align_loss(self, feats, feats_ref):
        loss_align = 0
        for f, fr in zip(feats, feats_ref):
            loss_align += self.MSELoss(f, fr)
        return loss_align

    def forward(self, image, label, mode, losses_computer):
        # Branching is applied to be compatible with DataParallel
        if mode == "losses_G":
            loss_G = 0
            fake = self.netG(label)
            output_D, scores, feats = self.netD(fake)
            _, _, feats_ref = self.netD(image)
            loss_G_adv = losses_computer.loss(output_D, label, for_real=True)
            loss_G += loss_G_adv
            loss_ms = self.GAN_loss(scores, True, for_discriminator=False)
            loss_G += loss_ms.item()
            loss_align = self.align_loss(feats, feats_ref)
            loss_G += loss_align
            if self.opt.add_vgg_loss:
                loss_G_vgg = self.opt.lambda_vgg * self.VGG_loss(fake, image)
                loss_G += loss_G_vgg
            else:
                loss_G_vgg = None
            return loss_G, [loss_G_adv, loss_G_vgg]

        if mode == "losses_D":
            loss_D = 0
            with torch.no_grad():
                fake = self.netG(label)
            output_D_fake, scores_fake, _ = self.netD(fake)
            loss_D_fake = losses_computer.loss(output_D_fake, label, for_real=False)
            loss_ms_fake = self.GAN_loss(scores_fake, False, for_discriminator=True)
            loss_D += loss_D_fake + loss_ms_fake.item()
            output_D_real, scores_real, _ = self.netD(image)
            loss_D_real = losses_computer.loss(output_D_real, label, for_real=True)
            loss_ms_real = self.GAN_loss(scores_real, True, for_discriminator=True)
            loss_D += loss_D_real + loss_ms_real.item()
            if not self.opt.no_labelmix:
                mixed_inp, mask = generate_labelmix(label, fake, image)
                output_D_mixed, _, _ = self.netD(mixed_inp)
                loss_D_lm = self.opt.lambda_labelmix * losses_computer.loss_labelmix(mask, output_D_mixed, output_D_fake,
                                                                                output_D_real)
                loss_D += loss_D_lm
            else:
                loss_D_lm = None
            return loss_D, [loss_D_fake, loss_D_real, loss_D_lm]

        if mode == "generate":
            with torch.no_grad():
                if self.opt.no_EMA:
                    fake = self.netG(label)
                else:
                    fake = self.netEMA(label)
            return fake

        if mode == "eval":
            with torch.no_grad():
                pred, _, _ = self.netD(image)
            return pred

    def load_checkpoints(self):
        if self.opt.phase == "test":
            which_iter = self.opt.ckpt_iter
            path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")
            if self.opt.no_EMA:
                self.netG.load_state_dict(torch.load(path + "G.pth"))
            else:
                self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))
        elif self.opt.phase == "eval":
            which_iter = self.opt.ckpt_iter
            path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")
            self.netD.load_state_dict(torch.load(path + "D.pth"))
        elif self.opt.continue_train:
            which_iter = self.opt.which_iter
            path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")
            self.netG.load_state_dict(torch.load(path + "G.pth"))
            self.netD.load_state_dict(torch.load(path + "D.pth"))
            if not self.opt.no_EMA:
                self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))

    def print_parameter_count(self):
        if self.opt.phase == "train":
            networks = [self.netG, self.netD]
        else:
            networks = [self.netG]
        for network in networks:
            param_count = 0
            for name, module in network.named_modules():
                if (isinstance(module, nn.Conv2d)
                        or isinstance(module, nn.Linear)
                        or isinstance(module, nn.Embedding)):
                    param_count += sum([p.data.nelement() for p in module.parameters()])
            print('Created', network.__class__.__name__, "with %d parameters" % param_count)

    def init_networks(self):
        def init_weights(m, gain=0.02):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                init.xavier_normal_(m.weight.data, gain=gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        if self.opt.phase == "train":
            networks = [self.netG, self.netD]
        else:
            networks = [self.netG]
        for net in networks:
            net.apply(init_weights)


def put_on_multi_gpus(model, opt):
    if opt.gpu_ids != "-1":
        gpus = list(map(int, opt.gpu_ids.split(",")))
        model = DataParallelWithCallback(model, device_ids=gpus).cuda()
    else:
        model.module = model
    assert len(opt.gpu_ids.split(",")) == 0 or opt.batch_size % len(opt.gpu_ids.split(",")) == 0
    return model


def preprocess_input(opt, data):
    data['label'] = data['label'].long()
    if opt.gpu_ids != "-1":
        data['label'] = data['label'].cuda()
        data['image'] = data['image'].cuda()
    label_map = data['label']
    bs, _, h, w = label_map.size()
    nc = opt.semantic_nc
    if opt.gpu_ids != "-1":
        input_label = torch.cuda.FloatTensor(bs, nc, h, w).zero_()
    else:
        input_label = torch.FloatTensor(bs, nc, h, w).zero_()
    input_semantics = input_label.scatter_(1, label_map, 1.0)
    return data['image'], input_semantics


def generate_labelmix(label, fake_image, real_image):
    target_map = torch.argmax(label, dim = 1, keepdim = True)
    all_classes = torch.unique(target_map)
    for c in all_classes:
        target_map[target_map == c] = torch.randint(0,2,(1,)).cuda()
    target_map = target_map.float()
    mixed_image = target_map*real_image+(1-target_map)*fake_image
    return mixed_image, target_map

首先看生成器流程:
标签输入到生成器中得到fake image,fake image 和 real image 共同输入到判别器中得到中间变量输出,接着分别计算四个损失。我们需要明白生成器和辨别器模型的搭建,损失计算过程。
在这里插入图片描述
首先是生成器的组成:
在这里插入图片描述
在这里插入图片描述
输入标签大小是(b,c,h,w),首先z等于一个正态分布的随机数,大小为(b,64),接着view为(b,64,1,1),再扩张到(b,64,h,w)和(b,c,h,w)沿着通道维度拼接起来。将拼接的结果上采样到W和H大小。
在这里插入图片描述
其中在CityscapesDataset指定了:
在这里插入图片描述
则w=512//2^5=16,h=16/2=8.
在这里插入图片描述
令s等于input label,输入到pyrmid中,生成结果添加到列表中。

self.seg_pyrmid = nn.ModuleList([])
        if not self.opt.no_3dnoise:
            self.fc = nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 16 * ch, 3, padding=1)
            self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 32, 3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True)))
        else:
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * ch, 3, padding=1)
            self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc, 32, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))

        self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))
        for i in range(len(self.channels)-2):
            self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))         

而pyrmid是一个modulist,便利添加的每一个module,生成一个结果:
首先将标签图和噪声拼接起来经过一个3x3卷积,输出通道变为32,再经过一个1x1卷积,输出通道变为64.再经过经过5个步长为2的3x3卷积,下采样32倍。这样pyrmid列表中就有7个结果。
接着将已经采样的x输入到Fc中,输出通道是1024.这里需要清楚两个变量x,和pyrmid.
1:x是输入下采样到(H,W)大小的label+noise.
2:pyrmid是储存经过七次(五次下采样)卷积之后的label+noise。
接着将pyrmid最后一个值采样到x的大小。然后和pyrmid的第i个值拼接在一起。
在这里插入图片描述
对应于:
在这里插入图片描述
每拼接一次生成的值和经过Fc之后的label+noise共同作为输入:
在这里插入图片描述
输入到SPADE块中:
首先要判断SPAD的两个参数即输入通道是否相等。
在这里插入图片描述
在这里插入图片描述
如果相等就输入到SPADE模块,如果不等令变量等于输入值。
在这里插入图片描述
其中最后一个参数是类别值:在Cityscape数据集设定语义标签是34类。有一类是未知,加上噪声的64个通道。
在这里插入图片描述
SPADE:

class SPADE(nn.Module):
    def __init__(self, opt, norm_nc, label_nc):
        super().__init__()
        self.first_norm = get_norm_layer(opt, norm_nc)
        ks = opt.spade_ks
        nhidden = 128
        pw = ks // 2
        #self.mlp_shared = nn.Sequential(
        #    nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
        #    nn.ReLU()
        #)
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        normalized = self.first_norm(x)
        #segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        #actv = self.mlp_shared(segmap)
        actv = segmap
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)
        out = normalized * (1 + gamma) + beta
        return out

公式:
在这里插入图片描述
首先X经过一个norm层,即为分布式BN。
在这里插入图片描述
在这里插入图片描述
接着使用卷积学习β和γ。
在这里插入图片描述
在这里插入图片描述
卷积核大小都为3,padding为1。
接着经过bn之后的变量和γ相乘在和β相加,再和经过归一化之后的x相加。
在这里插入图片描述
接着:x和seg经过相同的norm操作。再进过一个LeakyReLU,再进行一个卷积层。中间有个midlayer过渡。
在这里插入图片描述
在这里插入图片描述
输出的结果经过一个跳连接得到最后输出。
在这里插入图片描述
经过SPADE之后的输出上采样两倍作为输入输入到下一个SPADE中。
最终输出一个通道为3的RGB图片。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/816926.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

环形链表 II(JS)

环形链表 II 题目 给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,…

企业数字化转型失败率达80%,面临哪些挑战?应该如何规划?

随着数字化在社会的飞速发展,人们的生活工作娱乐等方方面面都已经被数字化占领,数字化所衍生出的数字经济更是成为高速增长的国民经济支柱,而数据作为“副产品”也成功进化为第五大生产要素,发挥出巨大的价值,变成了个…

智慧展馆展厅人员定位系统解决方案:提升参观体验与管理效率

随着数字化时代的到来,展馆和展厅逐渐成为人们了解文化、艺术、科技等领域的重要窗口。 然而,传统的展馆和展厅存在着一些问题,例如参观者迷路、信息获取不及时、管理效率低下等。 为了提升参观体验和管理效率,研发智慧展馆展厅…

测试|Selenium之WebDriver常见API使用

测试|Selenium之WebDriver常见API使用 文章目录 测试|Selenium之WebDriver常见API使用1.定位对象(findElement)css定位xpath定位css选择器语法:xpath语法:校验结果 2.操作对象鼠标点击对象在对象上模拟按键输入clear清除对象输入的文本内容su…

TCP三次握手和四次挥手以及11种状态(一)

1、三次握手 置位概念:根据TCP的包头字段,存在3个重要的标识ACK、SYN、FIN ACK:表示验证字段 SYN:位数置1,表示建立TCP连接 FIN:位数置1,表示断开TCP连接 三次握手过程说明: 1、…

【自动化剧本】Role角色

目录 一、Roles模块1.1roles的目录结构1.2roles 内各目录含义解释1.3在一个 playbook 中使用 roles 的步骤 二、使用Role编写LNMP剧本2.1 搭建Nginx角色2.2搭建Mysql角色2.3搭建php角色2.4lnmp剧本 一、Roles模块 roles用于层次性、结构化地组织playbook。roles能够根据层次型结…

实战!聊聊工作中使用了哪些设计模式

实战!聊聊工作中使用了哪些设计模式 策略模式 业务场景 假设有这样的业务场景,大数据系统把文件推送过来,根据不同类型采取不同的解析方式。多数的小伙伴就会写出以下的代码: if(type"A"){//按照A格式解析}else if(t…

力扣468 验证IP地址

ipv4地址:1.必须是四个非空子串 2.每个非空子串不含前导零 3.子串里字符只能是0~255 ipv6地址:1.必须是八个非空子串 2。每段非空串得长度是否在1~4之间,且不含0-9,a-f,A-F之外得字符。 3.同时0-9也不允许含前导零 cl…

HTML5+CSS3小实例:带标题的3D多米诺人物卡片

实例:带标题的3D多米诺人物卡片 技术栈:HTML+CSS 效果: 源码: 【html】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport" content…

vue关闭ESlint

在 vue.config.js里边写上这一句代码 lintOnsave:false

Maven如何创建Java web项目(纯干货版)!!!

1.创建Maven项目。 2.创建完成后会来到这个界面。 3.在src/main目录下&#xff0c;建立webapp / WEB-INF/web.xml文件&#xff0c;并在web.xml文件中写入以下内容&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <web-app xmlns"http…

AI技术快讯:清华开源ChatGLM2双语对话语言模型

ChatGLM2-6B是一个开源项目&#xff0c;提供了ChatGLM2-6B模型的代码和资源。根据提供的搜索结果&#xff0c;以下是对该项目的介绍&#xff1a; 论文&#xff1a;https://arxiv.org/pdf/2103.10360.pdf ChatGLM2-6B是一个开源的双语对话语言模型&#xff0c;是ChatGLM-6B模…

批量生成ChunJun json任务脚本

最近在研究chunjun&#xff0c;它是一款稳定、易用、高效、批流一体的数据集成框架。一直在用chunjun做数据抽取测试&#xff0c;json任务重复地在写&#xff0c;感觉十分浪费时间&#xff0c;于是想写个自动生成json脚本。 1.设计模板 模板通过excel设计&#xff0c;主要记录…

【phaser微信抖音小游戏开发004】往画布上增加文本以及文本的操作

我们在states中创建st004.js的类&#xff0c;或者将states中的index.js直接重命名为st004.js&#xff0c;把里面的类名也修改为st004.如下图 在main.js中&#xff0c;引入st004,并设置启用的state为st004。如下图 接下来到states/st004里面&#xff0c;在create里面将文本修改一…

为什么不推荐用 index 做 key

之所以添加key属性&#xff0c;究其根本是因 diff算法。而在业务开发过程中特别是使用map, forEach 等遍历函数的时候往往随手就将index做为组件的key. 那么:key 到底有什么用&#xff1f; 当 Vue.js 用 v-for 正在更新已渲染过的元素列表时&#xff0c;它默认用就地复用策略 …

IP 工具

什么是IP 工具 IP 工具是用于轻松扫描和排除网络 IP 地址空间故障的网络工程工具。IP 工具使网络管理员能够审核、跟踪和监视 IP 地址、子网以及使用 IP 的设备和主机的性能。这个全面的网络工程工具集包括高级 IP 工具&#xff0c;如 Ping、系统资源管理器、MAC 地址解析器和…

看表情包学C语言 ——局部优先原则

&#x1f517; 【C语言趣味教程】专栏介绍&#x1f448; 猛戳了解&#xff01;&#xff01;&#xff01; Ⅰ. 作用域&#xff08;Scope&#xff09; 0x00 引入&#xff1a;什么是作用域&#xff1f; 变量和常量在程序中都是有作用范围的&#xff0c;这个范围我们称之为变量的 …

40k的offer拿到手!爽歪歪~

据说周一和就业喜报更配&#xff1f;快跟着我一起来看看2023上半年黑马软件测试学科的就业喜报&#xff1a; 从黑马软件测试学科的就业中&#xff0c;我们也能看到软件测试对于企业的重要性&#xff0c;一点也不比程序员差&#xff0c;他们拿到的薪资也能和程序员的高薪媲美&am…

Netty 执行了多次channelReadComplete()却没有执行ChannelRead()

[TOC](Netty 执行了多次channelReadComplete()) Survive by day and develop by night. talk for import biz , show your perfect code,full busy&#xff0c;skip hardness,make a better result,wait for change,challenge Survive. happy for hardess to solve denpendies.…

JAVA的回调机制、同步/异步调用

一、同步调用 同步调用是最基本的调用方式。类A的a()方法调用类B的b()方法&#xff0c;类A的方法需要等到B类的方法执行完成才会继续执行。如果B的方法长时间阻塞&#xff0c;就会导致A类方法无法正常执行下去。 二、异步调用 如果A调用B&#xff0c;B的执行时间比较长&#…