stargan项目实战及源码解读

news2024/11/19 4:43:39

数据及代码链接见文末

​​​​​​​论文解析:Star GAN论文解析-CSDN博客

1.测试模块效果与实验分析

        测试数据需要准备两个文件夹src(源)和ref(目标),这两个文件夹下的文件夹名称代表各个domain。

运行测试模块:

python main.py --mode eval --num_domains 2 --w_hpf 1 \
               --resume_iter 100000 \
               --train_img_dir data/celeba_hq/train \
               --val_img_dir data/celeba_hq/val \
               --checkpoint_dir expr/checkpoints/celeba_hq \
               --eval_dir expr/eval/celeba_hq

或者指定参数:

 2.项目配置与数据源下载

        以人脸数据集为例,数据集下包含训练集和验证集,训练集和测试集下的文件夹代表一个一个domain 

        

        需要注意的是,数据集是做过特殊处理的,里面的人脸是对齐的,如果要训练自己的数据集,也需要做类似的处理 

环境配置:

  • 安装pytorch,默认为1.4版本,比1.4版本高也行
  • pip install ffmpeg
  • pip install opencv-python
  • pip install scikit-image
  • pip install pillow
  • pip install scipy
  • pip install tqdm
  • pip install munch

 常用参数

模型与损失函数相关

  

batch size

训练和测试输入与测试输出文件夹路径 

3.整体流程

         整个网络有四个网络组成,生成器、map映射网络、ecoder、判别器。

  • 生成网络,即对输入图像生成一张给定风格的图像
  • 映射网络,随机初始化一个向量,通过全连接层得到对应风格的转化向量。
  • ecoder:直接将图像编码为对应风格的向量
  • 判别器:对于输入图像,为每一种风格判断真假  

(1)生成器

        生成器生成特定风格的图像,生成器有U-net结构的网络堆叠而成,即先下采样,在上采样。此处的归一化策略采取Instance norm,即在实例维度进行归一化。并使用残差模块

代码

class Generator(nn.Module):
    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
        super().__init__()
        dim_in = 2**14 // img_size
        self.img_size = img_size
        self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1) #(in_channels,out_channels,kernel_size,stride,padding)
        self.encode = nn.ModuleList()
        self.decode = nn.ModuleList()
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True), # 在每个实例维度进行归一化
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 3, 1, 1, 0))

        # down/up-sampling blocks
        repeat_num = int(np.log2(img_size)) - 4
        if w_hpf > 0:
            repeat_num += 1
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            self.encode.append(
                ResBlk(dim_in, dim_out, normalize=True, downsample=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_in, style_dim,
                               w_hpf=w_hpf, upsample=True))  # stack-like
            dim_in = dim_out

        # bottleneck blocks
        for _ in range(2):
            self.encode.append(
                ResBlk(dim_out, dim_out, normalize=True)) # 残差模块
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))

        if w_hpf > 0:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.hpf = HighPass(w_hpf, device)

    def forward(self, x, s, masks=None):
        x = self.from_rgb(x)
        cache = {}
        for block in self.encode:
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                cache[x.size(2)] = x
            x = block(x)
        for block in self.decode:
            x = block(x, s)
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                mask = masks[0] if x.size(2) in [32] else masks[1]
                mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
                x = x + self.hpf(mask * cache[x.size(2)])
        return self.to_rgb(x)

 (2)Map映射网络

        map网络将随机初始化的隐向量转变为风格向量。 map映射网络主要由全连接层构成 

代码实现:

class MappingNetwork(nn.Module):
    def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
        super().__init__()
        layers = []
        layers += [nn.Linear(latent_dim, 512)]
        layers += [nn.ReLU()]
        for _ in range(3):
            layers += [nn.Linear(512, 512)]
            layers += [nn.ReLU()]
        self.shared = nn.Sequential(*layers)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Sequential(nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, style_dim))]

    def forward(self, z, y):
        h = self.shared(z)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        s = out[idx, y]  # (batch, style_dim)
        return s

 (3)判别器

        判别器用于判断生成图片和原始图片的真假。其也是由残差模块堆叠而成。具体来说,生成图片向量预测接近于1,原始图片预测接近于0。但是,与传统的生成器不同,这里的生成器对于每一个domain都要预测。

 

(4)style ecoder

        style ecoder为生成图片预测对应的风格向量。其输入为生成的图片,输出为风格向量。风格向量应该与生成这张图片时生成器输入的风格向量非常相近。其网络结构也与判别器相同。

4. 损失函数

1.Style reconstruction

         首先,在使用生成网络生成图片时,我们会输入一张图片和对应风格的向量s,然后生成得到对应风格的图片。在得到生成图片后,我们再使用ecoder将生成图片编码为对应风格的向量s'。很显然,我们希望s和s'足够接近。

 2.Style diversification(多样性损失)

首先,初始化2组向量z1和z2,然后经过map网络得到对应风格的编码s1和s2,很显然,s1和s2是不同的,我们现在希望根据s1和s2生成的结果差异越大越好,差异越大,多样性越高。即损失函数越大越好

 

3.Preserving source characteristics 

        可以理解为一种重构损失,我们希望生成的结果还是同一个人,因此,对于生成图片还原回去要与原来的输入图片足够接近。

4.Adversarial objective

即判别器损失,原始图片预测接近于1,而生成图像预测接近于0

总损失为上述损失的加权和

数据及代码链接:链接:https://pan.baidu.com/s/1aNlghgo6mtD4iWqNgMOWOQ?pwd=s206 
提取码:s206 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1570774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【随笔】Git 高级篇 -- 撤销变更(十四)

💌 所属专栏:【Git】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! 💖 欢迎大…

基于单片机便携式太阳能充电器系统设计

**单片机设计介绍,基于单片机便携式太阳能充电器系统设计 文章目录 一 概要二、功能设计三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机便携式太阳能充电器系统设计概要主要涉及利用单片机作为核心控制器件,结合太阳能充电技术和便携式…

java-网络编程socket-聊天室-先导

这边我会简单介绍一下聊天室的组成部分,和思路的引导 涉及知识点 java 中异常处理机制 和 io流和网络编程socket 简单回顾异常机制 Java中的异常机制是一种用于处理程序运行期间出现的错误或异常情况的机制。这种机制允许程序员定义在特定情况下可能发生的错误,并…

Revit 2025新功能一览~

Hello大家好!我是九哥~ Revit2025已经更新,安装后,简单试了下,还是挺不错的,流畅度啊,新功能啊,看来还是有听取用户意见的,接下来就简单看看都有哪些新功能。 好了,今天的…

小红书自动化仿写发文机器人了解一下

您好,我是码农飞哥(wei158556),感谢您阅读本文,欢迎一键三连哦。💪🏻 1. Python基础专栏,基础知识一网打尽,9.9元买不了吃亏,买不了上当。 Python从入门到精通…

CCIE-08-BGP-Listen

目录 实验条件网络拓朴实验目的 开始配置配置动态路由协议配置BGP检查邻居配置 实验条件 网络拓朴 实验目的 将R1配置成Listen状态,自动接收来自其它路由器的建邻居请求、建立邻居 开始配置 配置动态路由协议 这里用EIGRP来配置,保证网络的可达性&a…

2024年华为OD机试真题-推荐多样性-Java-OD统一考试(C卷)

题目描述: 推荐多样性需要从多个列表中选择元素,一次性要返回N屏数据(窗口数量),每屏展示K个元素(窗口大小),选择策略: 1. 各个列表元素需要做穿插处理,即先从…

ES11 学习

文章目录 1. Promise.allSettled2. Module 新增2.1 ! 动态导入 import()2.2 import.meta2.3 export * as obj from module 3. 字符串 matchAll()4. BigInt实际开发相关使用 5. globalThis6. 空值合并运算符7. 可选链操作符 1. Promise.allSettled Promise.allSettled() 返回一个…

应急响应实战笔记05Linux实战篇(2)

第2篇:捕捉短连接 0x00 前言 ​ 短连接(short connnection)是相对于长连接而言的概念,指的是在数据传送过程中,只在需要发送数据时,才去建立一个连接,数据发送完成后,则断开此连接…

Azure service tag 导致的Exchange online 无法发送邮件的问题

最近碰到一个比较有趣的客户问题。 这个客户一直在使用Exchange online 与自己在Azure Vnet 里面的exchange server交换邮件。 客户的网络架构如下图所示。 客户说之前从exchange online往外发邮件一直是好的,但是最近两周开始只有百分之3左右的邮件可以发出去,其他的都pen…

C语言中的结构体:高级特性与扩展应用

前言 结构体在C语言中的应用不仅限于基本的定义和使用,还包含一些高级特性和扩展应用,这些特性和应用使得结构体在编程中发挥着更加重要的作用。 一、位字段(Bit-fields) 在结构体中,我们可以使用位字段来定义成员…

【性能测试】接口测试各知识第2篇:学习目标,1. 理解接口的概念【附代码文档】

接口测试完整教程(附代码资料)主要内容讲述:接口测试,学习目标学习目标,2. 接口测试课程大纲,3. 接口学完样品,4. 学完课程,学到什么,5. 参考:,1. 理解接口的概念。学习目标,RESTFUL1. 理解接口的概念,2.什么是接口测试…

mybatis流式游标查询-导出DB大数据量查询OOM问题

问题场景 Mysql数据处理类型分以下三种 com.mysql.cj.protocol.a.result.ResultsetRowsStatic:普通查询,将结果集一次性全部拉取到内存 com.mysql.cj.protocol.a.result.ResultsetRowsCursor:游标查询,将结果集分批拉取到内存&…

Windows集群部署项目

目录 一,环境准备 1.1.安装MySQL 1.2.安装JDK 1.3.安装TomCat 1.4.安装Nginx 二,部署 2.1.后台服务部署 2.2.Nginx配置负载均衡及静态资源部署 一,环境准备 1.1.安装MySQL 可以参考博客:http://t.csdnimg.cn/A75bg 1.2.…

我为什么会选择Vim来开发Go项目及Vim IDE安装配置和操作

你好,我是孔令飞,字节跳动云原生资深研发、前腾讯云原生技术专家。《企业级 Go 项目开发实战》、《从零开发企业级 Go 应用》作者,欢迎加入 孔令飞的云原生实战营,助你进阶 Go 云原生高级开发工程师。 作为一名 Golang 开发&…

训练营十六天(二叉树part03)

104.二叉树的最大深度 力扣题目链接(opens new window) 题目 给定一个二叉树,找出其最大深度。 二叉树的深度为根节点到最远叶子节点的最长路径上的节点数。 说明: 叶子节点是指没有子节点的节点。 示例: 给定二叉树 [3,9,20,null,null,15,7]&…

从头开发一个RISC-V的操作系统(四)嵌入式开发介绍

文章目录 前提嵌入式开发交叉编译GDB调试,QEMU,MAKEFILE练习 目标:通过这一个系列课程的学习,开发出一个简易的在RISC-V指令集架构上运行的操作系统。 前提 这个系列的大部分文章和知识来自于:[完结] 循序渐进&#x…

leetcode 热题 100(部分)C/C++

leetcode 热题 100 双指针 盛最多水的容器 【mid】【双指针】 思路: 好久没写代码sb了,加上之前写的双指针并不多,以及有点思维定势了。我对双指针比较刻板的印象一直是两层for循环i,j,初始时i,j都位于左界附近&…

MySQL数据库 数据库基本操作(三):表的增删查改(中)

1. 数据库的约束 1.1 约束类型(一般发生于表的创建中) NOT NULL - 指示某列不能存储 NULL 值。UNIQUE - 保证某列的每行必须有唯一的值。DEFAULT - 规定没有给列赋值时的默认值。PRIMARY KEY - NOT NULL 和 UNIQUE 的结合。确保某列(或两个列多个列的结合&#xf…

【CSS】浮动笔记及案例

CSS浮动 1. 认识浮动 float属性可以指定一个元素沿着左侧或者是右侧放置,允许文本和内联元素环绕它 float属性最初只使用文字环绕图片但却是早起CSS最好用的左右布局方案 绝对定位、浮动都会让元素脱标,以达到灵活布局的目的可以通过float属性让元素脱…