SEAN代码(1)

news2025/1/20 10:54:49

代码地址
首先定义一个trainer。

trainer = Pix2PixTrainer(opt)

在Pix2PixTrainer内部,首先定义Pix2PixModel模型。

self.pix2pix_model = Pix2PixModel(opt)

在Pix2PixModel内部定义生成器,判别器。

self.netG, self.netD, self.netE = self.initialize_networks(opt)

在initialize_networks内部定义功能。

netG = networks.define_G(opt)
netD = networks.define_D(opt) if opt.isTrain else None
netE = networks.define_E(opt) if opt.use_vae else None

首先看生成器:

def define_G(opt):
    netG_cls = find_network_using_name(opt.netG, 'generator')#netG=spade
    return create_network(netG_cls, opt)

输入的参数是opt.netG,在option中对应的是spade。在find_network_using_name中:

def find_network_using_name(target_network_name, filename):#spade,generator
    target_class_name = target_network_name + filename#spadegenerator
    module_name = 'models.networks.' + filename#models.networks.generator
    network = util.find_class_in_module(target_class_name, module_name)#<class 'models.networks.generator.SPADEGenerator'>
    assert issubclass(network, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % network

    return network

根据target_network_name和对应的filename输入到find_class_in_module中:

def find_class_in_module(target_cls_name, module):
    target_cls_name = target_cls_name.replace('_', '').lower()#spadegenerator
    clslib = importlib.import_module(module)#import_module()返回指定的包或模块
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        exit(0)

    return cls

我们通过import_module函数载入module这个模块,module对应的是models.networks.generator。即clslib 就是generator文件中的类。我们遍历clslib的字典,如果name等于spadegenerator,令cls = clsobj。
即network等于cls。

network = util.find_class_in_module(target_class_name, module_name)

这里有两个语法问题:
①:导入importlib,调用import_module()方法,根据输入的字符串可以获得模块clslib ,clslib 可以调用models.networks.generator文件下所有的属性和方法。
在这里插入图片描述
在generator内部是:
在这里插入图片描述可以通过clslib.SPADEGenerator来实例化SPADEGenerator,然后再调用SPADEGenerator内部的方法。
举个例子:新建三个文件。
在这里插入图片描述
train:
在这里插入图片描述
用不到test,在tt文件内部中导入train中的类s。
在这里插入图片描述
因为是同级目录,直接导入字符串train即可,如果不在同级目录,需要导入前一个目录。
接着a就会变成一个module,即train。然后实例化train文件夹下的类s。最后调用类s的方法kill和qqq。
输出:
在这里插入图片描述
②: dict,该属性可以用类名或者类的实例对象来调用,用**类名直接调用 dict,会输出该由类中所有类属性组成的字典;**而使用类的实例对象调用 dict,会输出由类中所有实例属性组成的字典。
参考
这里SPADEGenerator继承了BaseNetwork,对于具有继承关系的父类和子类来说,父类有自己的 dict,同样子类也有自己的 dict,它不会包含父类的 dict
例子:按上面的例子,a是一个module,查看a的__dict__:
在这里插入图片描述
输出:
在这里插入图片描述
回到代码中:我们输出的network就是类<class ‘models.networks.generator.SPADEGenerator’>。
下一步我们创建网络:在这里插入图片描述
在这里插入图片描述
cls对应的是SPADEGenerator网络。
在SPADE中:

"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import Zencoder

class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='normal',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        self.Zencoder = Zencoder(3, 512)


        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
        #self.up = nn.Upsample(scale_factor=2, mode='bilinear')
    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == 'normal':#默认
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)

        sw = opt.crop_size // (2**num_up_layers)#256//32=16
        sh = round(sw / opt.aspect_ratio)#8

        return sw, sh

    def forward(self, input, rgb_img, obj_dic=None):
        seg = input
        x = F.interpolate(seg, size=(self.sh, self.sw))#(16,16)
        x = self.fc(x)#(b,1024,16,16)

        style_codes = self.Zencoder(input=rgb_img, segmap=seg)
        x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)

        x = self.up(x)
        x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)

        if self.opt.num_upsampling_layers == 'more' or \
           self.opt.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg, style_codes,  obj_dic=obj_dic)

        x = self.up(x)
        x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_3(x, seg, style_codes,  obj_dic=obj_dic)

        # if self.opt.num_upsampling_layers == 'most':
        #     x = self.up(x)
        #     x= self.up_4(x, seg, style_codes,  obj_dic=obj_dic)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = F.tanh(x)
        return x

首先计算潜在空间向量的大小:
在这里插入图片描述
接着计算style matrixST。对应文章的 :
在这里插入图片描述
在代码中:通过卷积,下采样,下采样,上采样,卷积。输出一个通道为512的向量。
在这里插入图片描述
接着是连续的四个上采样模块:
在这里插入图片描述
对应于:
在这里插入图片描述
在SPADEResnetBlock内部:使用ACE类定义了SEAN块。
在这里插入图片描述
在ACE内部定义了归一化的参数和噪声等。
在这里插入图片描述
下面设计python正则表达式,没学过,下去补。只能先用debug获得结果。
在这里插入图片描述
这里使用SynchronizedBatchNorm2d进行归一化:
在这里插入图片描述
γ和β通过卷积获得:
在这里插入图片描述
执行完上采样的四个SEAN块之后,最后进过一个卷积输出合成图像。这就是整个network的流程。
生成器打印参数:
在这里插入图片描述
接着是判别器:
按照生成器的逻辑,target_class_name=multiscalediscriminator,module_name=models.networks.discriminator
然后我们导入判别器模块。
在这里插入图片描述
在多尺度判别器内部:创建两个single_discriminator。
在这里插入图片描述
在这里插入图片描述
在单个判别器内部定义参数:在这里插入图片描述
定义判别器的输入:将label通道和RGB图片拼接后输入。
在这里插入图片描述
接着经过一个4x4大小步长为2的卷积,再经过两个步长为2的卷积,最后再经过输出通道为1,步长为1的卷积。将每一个卷积都注册到模型中。
在这里插入图片描述
即判别器由五个卷积组成。
将单个判别器注册到判别器中。注册两次,这样盘比起由10个卷积组成,且都有对应的吗名称。
在这里插入图片描述

MultiscaleDiscriminator(
  (discriminator_0): NLayerDiscriminator(
    (model0): Sequential(
      (0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model1): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model3): Sequential(
      (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    )
  )
  (discriminator_1): NLayerDiscriminator(
    (model0): Sequential(
      (0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model1): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model3): Sequential(
      (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    )
  )
)

这样生成器判别器都狗仔完毕,netE为空。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/966848.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

11.Redis的慢操作之rehash

Redis为什么快 它接收到一个键值对操作后&#xff0c;能以微秒级别的速度找到数据&#xff0c;并快速完成操作。 数据库这么多&#xff0c;为啥 Redis 能有这么突出的表现呢&#xff1f; 内存数据结构 一方面&#xff0c;这是因为它是内存数据库&#xff0c;所有操作都在内存上…

Redis—常用数据结构

Redis—常用数据结构 &#x1f50e;数据结构与内部编码 Redis 中常用的数据结构包括 Strings—字符串Hashes—哈希表Lists—列表Sets—集合Sorted sets—有序集合 Redis 底层在实现上述数据结构时, 会在源码层面针对上述实现进行特定优化, 以达到节省时间 / 节省空间的效果 …

卡片介绍、EMV卡组织、金融认证---安全行业基础篇2

一、卡片介绍 卡片是一种用于存储和传输数据的可携带式物品&#xff0c;通常由塑料或纸质材料制成。卡片通常具有特定的尺寸和形状&#xff0c;以适应各类读写设备。不同类型的卡片可以用于不同的应用&#xff0c;如身份验证、支付、门禁控制等。 接触卡 接触卡是一种需要与读…

量化策略:CTA,市场中性,指数增强

CTA 策略 commodity Trading Advisor Strategy&#xff0c;即“商品交易顾问策略”&#xff0c;也被称作管理期货策略。 期货T0&#xff0c;股票T1双向交易&#xff1a;就单向交易而言的&#xff0c;不仅能先买入再卖出&#xff08;做多&#xff09;&#xff0c;而且可以先卖…

Java异常(Error与Exception)与常见异常处理——第八讲

前言 前面我们讲解了Java的基础语法以及面向对象的思想,相信大家已经基本掌握了Java的基本编程。在之前代码中,我们也看到代码写错了编译器会提示报错,或者编译器没有提示,但是运行的时候报错了,比如前面的数组查询下标超过数组的长度。所以在使用计算机语言进行项目开发的…

CLIP:连接文本-图像

Contrastive Language-Image Pre-Training CLIP的主要目标是通过对比学习&#xff0c;学习匹配图像和文本。CLIP最主要的作用&#xff1a;可以将文本和图像表征映射到同一个表示空间 这是通过训练模型来预测哪个图像属于给定的文本&#xff0c;反之亦然。在训练过程中&#…

高频策略:抢盘口,做市,短期趋势

利润来源 价格短期趋势随机游走震荡 策略分类 抢盘口&#xff1a;盘口大单封堵&#xff0c;快速在盘口中双向下单&#xff0c;赚取价差做市&#xff1a;盘口买卖活跃&#xff0c;预测市价单击穿距离&#xff0c;在盘口外双向下单&#xff0c;赚取价差 挂单范围要小于市价击穿距…

十二、分组查询

1、分组查询 &#xff08;1&#xff09;基础语法&#xff1a; select 字段列表 from 表名 [where 条件] group by 分组字段名 [having 分组之后的过滤条件] &#xff08;2&#xff09;注意事项&#xff1a; &#xff08;3&#xff09;理解&#xff1a; select后的“字段列表…

personalized image enhancement 调研

Personalized Image Enhancement Using Neural Spline Color Transforms 这是TIP期刊 2020年的一篇论文&#xff0c;首先提出了一个能预测曲线的网络&#xff0c;预测一些锚点&#xff0c;根据锚点插值出连续的曲线&#xff0c;然后用曲线对raw image进行retouching。然后提出了…

ODC现已开源:与开发者共创企业级的数据库协同开发工具

OceanBase 开发者中心&#xff08;OceanBase Developer Center&#xff0c;以下简称 ODC&#xff09;是一款开源的数据库开发和数据库管理协同工具&#xff0c;从首个版本上线距今已经发展了三年有余&#xff0c;ODC 逐步由一款专为 OceanBase 打造的开发者工具演进成为支持多数…

第 361 场 LeetCode 周赛题解

A 统计对称整数的数目 枚举 x x x class Solution { public:int countSymmetricIntegers(int low, int high) {int res 0;for (int i low; i < high; i) {string s to_string(i);if (s.size() & 1)continue;int s1 0, s2 0;for (int k 0; k < s.size(); k)if …

快速为RPG辅助工具MTool增加更多快捷键(一键保存等)

起源&#xff1a;MTool是个好工具&#xff0c;本身固然好用&#xff0c;但是它本身的快捷键功能很少&#xff0c;虽然内置了一个录制工具&#xff0c;但是一个个的录&#xff0c;又麻烦&#xff0c;一般人也难以掌握 本文用快速方法增加更多快捷键&#xff0c;可以做到一键保存…

利用非线性解码模型从人类听觉皮层的活动中重构音乐

音乐是人类体验的核心&#xff0c;但音乐感知背后的精确神经动力学仍然未知。本研究分析了29名患者的独特颅内脑电图(iEEG)数据集&#xff0c;这些患者听了Pink Floyd的歌曲&#xff0c;并应用了先前在语音领域使用的刺激重建方法。本研究成功地从直接神经录音中重建了可识别的…

【性能优化】聊聊性能优化那些事

针对于互联网应用来说&#xff0c;性能优化其实就是一直需要做的事情&#xff0c;因为系统响应慢&#xff0c;是非常影响用户的体验&#xff0c;可能回造成用户流失。所以对于性能非常重要。最近正好接到一个性能优化的需求&#xff0c;需要对所负责的系统进行性能提升。目前接…

Python自动化小技巧22——获取中国高校排名数据

背景 【软科排名】2023年最新软科中国大学排名|中国最好大学排名 (shanghairanking.cn) 爬取这个网站所有的高校的数据&#xff0c;包括学习名称&#xff0c;层次&#xff0c;地区&#xff0c;分数等等信息&#xff1a;[办学层次,学科水平,办学资源,师资规模与结构,人才培养,…

红日靶场五(vulnstack5)渗透分析

环境搭建 win7 192.168.111.132&#xff08;仅主机&#xff09; 192.168.123.212&#xff08;桥接&#xff09; .\heart p-0p-0p-0win2008 ip: 192.168.111.131&#xff08;仅主机&#xff09; sun\admin 2020.comkali ip: 192.168.10.131&#xff08;nat&#xff09;vps&…

字节一面:说说地址栏输入 URL 敲下回车后发生了什么?

前言 最近博主在字节面试中遇到这样一个面试题&#xff0c;这个问题也是前端面试的高频问题&#xff0c;作为一名前端开发工程师&#xff0c;我们只有了解地址栏输入 URL 敲下回车后发生的事件&#xff0c;才知道性能优化如何下手&#xff0c;性能优化也是前端必备知识&#xf…

第一章_线程基础知识

先拜拜大神 Doug Lea&#xff08;道格.利&#xff09; java.util.concurrent在并发编程中使用的工具包 为什么学习并用好多线程极其重要 硬件方面 摩尔定律失效 摩尔定律&#xff1a;它是由英特尔创始人之一Gordon Moore&#xff08;戈登.摩尔&#xff09;提出来的。其内容为…

Centos 6.5 升级到Centos7指导手册

一、背景 某业务系统因建设较早&#xff0c;使用的OS比较过时&#xff0c;还是centos6.5的系统&#xff0c;因国产化需要&#xff0c;需将该系统升级到BClinux 8.6&#xff0c;但官方显示不支持centos 6.x升级到8&#xff0c;需先将centos6.5升级到centos7的最新版&#xff0c…

Python之作业(一)

Python之作业&#xff08;一&#xff09; 作业 打印九九乘法表 用户登录验证 用户依次输入用户名和密码&#xff0c;然后提交验证用户不存在、密码错误&#xff0c;都显示用户名或密码错误提示错误3次&#xff0c;则退出程序验证成功则显示登录信息 九九乘法表 代码分析 先…