pix2pixHD---loss---损失函数

news2024/12/26 11:41:38

在Pix2PixHDModel代码中首先定义损失:
在这里插入图片描述
首先看第一个:输入的两个参数use_gan_feat_loss, use_vgg_loss默认为false,则前缀有not,所以两个参数都是True。

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        return loss_filter

则flag里面有五个True,zip函数将每一个值和True组合为一个元组。一共有五个。
接着看第二个:
在这里插入图片描述

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:            
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)

通过call函数调用,有两个输入,这里用了一个for循环,因为在辨别器中我们的输出列表里面有五个值。将pred值和target值输入到get_target_tensor得到target_tensor 。

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

这个地方就是获得一个和input大小一样的矩阵,矩阵的值由1或者0组成。
在这里插入图片描述
最后将pred和target进行损失计算,计算的结果进行累加:损失函数采用的MSELoss.
在这里插入图片描述
在这里插入图片描述
除了GANloss之外还有一个L1loss和VGGloss。
在这里插入图片描述
如果使用VGGloss即feature matching loss的话:

class VGGLoss(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()        
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]        
    def forward(self, x, y):              
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())        
        return loss

将真实图片和生成图片输入到VGG19中,得到的值进行L1损失计算,每一个值赋予一个权重。
在这里插入图片描述
在model中使用损失进行计算:

        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

        # Real Detection and Loss        
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Posibility Loss)
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
        loss_G_GAN = self.criterionGAN(pred_fake, True)   

对于辨别器来说,当输入的是假图片即生成的图片时,我们希望他输出为0,当输入的是真实图片时,我们希望辨别器输出为1.对于生成器来说,我们希望辨别器不能预测出来,即生成的都为1.这是三个GANLoss。
然后将真实图片输入到辨别器的输出和假图片输入到辨别器的输出进行一个l1损失计算。
在这里插入图片描述
最后将生成器生成的加图片和真实图片输入到VGG中得到的结果进行一个VGGloss计算。画了一下损失计算流程。
在这里插入图片描述
最后将所有损失输入到loss_filter,每一个和true组成一个元组,返回一个大列表,同时model在训练时候还输出另一个输出None,在infer时候输出假的图。
返回train中,整个model就搭建完毕。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/595053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PCIE知识点-022:PCIe 参考时钟结构

图1:参考时钟结构示意图[4] 1. Common Refclk Architecture Common Refclk Architecture,即同源参考时钟架构,PCIe收发设备共用一个时钟源,是目前是使用最为广泛的方案。 缺点: 对于适用同一 Common Clock 作为参考时…

第四章 运行时数据区

文章目录 前言一、🚗 双亲委派机制1、 问题的引出:是否会被外来程序对系统进行破坏2、总结3、双亲委派的优势4、沙箱安全机制5、其他 二、🚒 运行时数据区线程 三、🛺 PC 寄存器概述(记录下一条程序指令的地址&#xf…

Vulnhub: dpwwn: 1靶机

kali:192.168.111.111 靶机:192.168.111.131 信息收集 端口扫描 nmap -A -sC -v -sV -T5 -p- --scripthttp-enum 192.168.111.131 爆破出mysql的root用户为空密码 hydra -l root -P /usr/share/wordlists/rockyou.txt 192.168.111.131 -s 3306 mysq…

UI自动化 Xpath定位必知必会

目录 javascript xpath定位 定位单个元素: 定位多个元素: 验证xpath定位语法是否OK 尽量使用模糊匹配定位元素 模糊匹配contains 使用关联文本值定位 text() 在UI自动化测试中用的最频繁的就是xpath定位了,所以用好xpath定位至关重要&…

关于中断的几个小问题

1. intel 8259芯片中的IRQ2和int2的区别是什么? 答曰:IRQ2是芯片上的引脚,而int2是中断向量表的第2项,两者有很大区别。 Intel8259A芯片的中断引脚分别为: 主片: 0:8254时钟 1:键盘 …

【python代码】Kittle数据集的ground truth生成深度图攻略|彩色深度图|代码无恼运行

目录 1.明确KITTLE数据集特性 2.选择groundtruth 3.转换深度图 4.转换彩色深度图 1.明确KITTLE数据集特性 KITTI数据集包含了来自车载传感器的多模态数据,包括激光雷达、摄像头和GPS/惯性测量单元(IMU)等。该数据集主要采集于城市环境中…

B3645 数列前缀和 20

题目描述 给定一个长度为 n 的数列 a,请回答 q 次询问,每次给定 l,r,请求出 ()​mod p 的值,其中 p1,145,141。 输入格式 第一行是两个整数,依次表示数列长度 n 和询问次数 q。 第二行有 n 个整数,第 …

Qt下QTcpServer服务端识别多个QTcpSocket客户端

文章目录 Qt官方文档编写QTcpServerDemo和QTcpSocketDemo实现QTcpServerDemo实现QTcpSocketDemo 使用windeployqt生成程序运行所需依赖文件 Qt官方文档 QTcpSocket Class :https://doc.qt.io/qt-5/qtcpsocket.html QAbstractSocket Class:https://doc.q…

分布式RPC框架Dubbo详解

目录 1.架构演进 1.1 单体架构 1.2 垂直架构 1.3 分布式架构 1.4 SOA架构 1.5 微服务架构 2.RPC框架 2.1 RPC基本概念介绍 2.1.1 RPC协议 2.1.2 RPC框架 2.1.3 RPC与HTTP、TCP/ UDP、Socket的区别 2.1.4 RPC的运行流程 2.1.5 为什么需要RPC 2.2 Dubbo 2.2.1 Dub…

Redis的网络模型(未完成)

Redis是单线程还是多线程? 1)如果只是针对于Redis的核心业务部分(命令处理),答案是单线程 2)如果是说整个redis,那么就是多线程 在Redis的版本迭代过程中,在两个非常重要的时间节点上引入了对多线程的支持: 1)在Redis4.0版本中&am…

redis第二章-第一课-持久化rdb和aof以及混合模式

RDB快照 1.在默认情况下, Redis 将内存数据库快照保存在名字为 dump.rdb 的二进制文件中,默认存储在redis的当前目录下,也就是和redis.conf同级目录下,可以修改位置 2.你可以对 Redis 进行设置, 让它在“ N 秒内数据…

小白怎么入门CTF,看这个就够了(附学习笔记、靶场、工具包下载)

CTF靶场:CTF刷题,在校生备战CTF比赛,信安入门、提升自己、丰富简历之必备(一场比赛打出好成绩,可以让你轻松进大厂,如近期的各种CTF杯),在职人员可以工作意外提升信安全技能。 渗透…

00后太卷了,搞的我们这些老油条太难受了......

前几天我们公司一下子也来了几个新人,这些年前人是真能熬啊,本来我们几个老油子都是每天稍微加会班就打算走了,这几个新人一直不走,搞得我们也不好走。 2023年春招结束了,最近内卷严重,各种跳槽裁员&#x…

rk3568 点亮LCD(DP)

rk3568 rk3399 Android11/12 适配 DP DisplayPort(简称DP)是第一个依赖数据包化数据传输技术的显示通信端口。是一个由PC及芯片制造商联盟开发,视频电子标准协会标准化的数字式视频接口标准。主要用于视频源与显示器等设备的连接&#xff0c…

进程管理(笔记)

如果对内存寻址熟悉的话, 或者认真看过上一节的内容: 内存管理之内存寻址: https://blog.csdn.net/qq_40482358/article/details/130868188. 那么对linux系统中的进程管理应该已经有一个初步的认识了: cr3作为一个控制寄存器, 描述当前进程的页目录的物理内存基地址, 当进程切换…

SpringCloud Ribbon 学习

SpringCloud Ribbon 学习 文章目录 SpringCloud Ribbon 学习1. Ribbon 是什么?2. LB(Load Balance)3 Ribbon 架构图&机制4 Ribbon 常见负载均衡算法5 测试 1. Ribbon 是什么? Spring Cloud Ribbon 是基于 Netflix Ribbon 实现的一套客户端 负载均衡…

Docker容器与虚拟机(VM)大对比

Docker是一个开源应用容器引擎。Docker可以将应用程序与基本架构分开,从而快速交付软件。 传统虚拟机的运行需要占用较高的资源,包括磁盘空间、内存和处理器性能。每个虚拟机都需要完整的操作系统和应用程序副本,这在资源利用和启动时间上存…

好玩!AI文字RPG游戏;播客进入全AI时代?LangChain项目实践手册;OpenAI联创科普GPT | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🤖 Microsoft Build 中国黑客松挑战赛,进入AI新纪元 近期,伴随着人工智能的新一轮浪潮,Hackathon (黑…

python+pytest接口自动化之HTTP协议基础

目录 HTTP协议简介 HTTP协议特点 HTTP接口请求方法 HTTP与HTTPS区别 HTTP与TCP/IP区别 HTTP请求过程 总结 HTTP协议简介 HTTP 即 HyperText Transfer Protocol(超文本传输协议),是互联网上应用最为广泛的一种网络协议。所有的 WWW 文件…

WePY小程序框架实践指南

在小程序开发中,提高开发效率、优化代码质量和增强用户体验是每位开发者都追求的目标。 下面我们从小程序开发框架来讲讲如何帮助开发提效,其中 WePY 是一个稍微冷门一些的开发框架,基于 Vue.js 的小程序开发的框架,提供了更好的…