【深度学习】CodeFormer训练过程,如何训练人脸修复模型CodeFormer

news2024/11/6 3:05:19

文章目录

  • BasicSR介绍
  • 环境
  • 数据
  • 阶段 I - VQGAN
  • 阶段 II - CodeFormer (w=0)
  • 阶段 III - CodeFormer (w=1)

代码地址:https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

论文的一些简略介绍:
https://qq742971636.blog.csdn.net/article/details/134562550

BasicSR介绍

CodeFormer整个项目都沿袭BasicSR,了解一下BasicSR很有必要:

https://mp.csdn.net/mp_blog/creation/success/135674803

环境

# git clone this repository
git clone https://github.com/sczhou/CodeFormer
cd CodeFormer

# create new anaconda env
conda create -n codeformer python=3.8 -y
conda activate codeformer

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia

# install python dependencies
pip3 install -r requirements.txt
python basicsr/setup.py develop

conda install -c conda-forge dlib (only for face detection or cropping with dlib)

数据

找一些高清人脸数据1024*1024。

人脸数据需要对齐,对齐方式为: https://qq742971636.blog.csdn.net/article/details/135521146

阶段 I - VQGAN

训练VQGAN:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch
CUDA_VISIBLE_DEVICES=0,2,3 python -m torch.distributed.launch --nproc_per_node=3 --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch # 指定三张显卡训练,对应VQGAN_512_ds32_nearest_stage1.yaml也是需要修改的

训练完VQGAN后,可以通过下面代码预先获得训练数据集的密码本序列,从而加速后面阶段的训练过程:

python scripts/generate_latent_gt.py

如果你不需要训练自己的VQGAN,可以在Release v0.1.0文档中找到预训练的VQGAN (vqgan_code1024.pth)和对应的密码本序列 (latent_gt_code1024.pth): https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

打开日志查看训练过程:

tensorboard --logdir="/ssd/xiedong/CodeFormer/tb_logger/20240116_182107_VQGAN-512-ds32-nearest-stage1" --bind_all

在这里插入图片描述

VQGAN本身就是一个图生图的网络,在中间使用transformer将特征图转为embedding. 而 CodeFormer就是要利用这每张图的embedding来进行面部修复。

下面代码里用vqgan_code1024.pth获取训练数据的密码本,vqgan_code1024.pth的encoder输出的是2563232的特征图,由embedding给到1*1024,最终所有图保存为一个pytorch文件。

import argparse
import glob
import numpy as np
import os
import cv2
import torch
from torchvision.transforms.functional import normalize
from tqdm import tqdm

from basicsr.utils import imwrite, img2tensor, tensor2img

from basicsr.utils.registry import ARCH_REGISTRY

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--test_path', type=str, default='/ssd/xiedong/FFHQ/faces_hq_sr')
    parser.add_argument('-o', '--save_root', type=str, default='/ssd/xiedong/FFHQ/lt_output')
    parser.add_argument('--codebook_size', type=int, default=1024)
    parser.add_argument('--ckpt_path', type=str, default='/ssd/xiedong/CodeFormer/weights/vqgan/vqgan_code1024.pth')
    args = parser.parse_args()

    if args.save_root.endswith('/'):  # solve when path ends with /
        args.save_root = args.save_root[:-1]
    dir_name = os.path.abspath(args.save_root)
    os.makedirs(dir_name, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_path = args.test_path
    save_root = args.save_root
    ckpt_path = args.ckpt_path
    codebook_size = args.codebook_size

    vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
                                               codebook_size=codebook_size).to(device)
    checkpoint = torch.load(ckpt_path)['params_ema']

    vqgan.load_state_dict(checkpoint)
    vqgan.eval()

    sum_latent = np.zeros((codebook_size)).astype('float64')
    size_latent = 32
    latent = {}
    latent['orig'] = {}
    latent['hflip'] = {}
    for i in ['orig', 'hflip']:
        # for i in ['hflip']:
        for img_path in tqdm(sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g')))):
            img_name = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if i == 'hflip':
                cv2.flip(img, 1, img)
            img = img2tensor(img / 255., bgr2rgb=True, float32=True)
            normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            img = img.unsqueeze(0).to(device)
            with torch.no_grad():
                # output = net(img)[0]
                # x, feat_dict = vqgan.encoder(img, True)
                x = vqgan.encoder(img)
                x, _, log = vqgan.quantize(x)
            # del output
            torch.cuda.empty_cache()

            min_encoding_indices = log['min_encoding_indices']
            min_encoding_indices = min_encoding_indices.view(size_latent, size_latent)
            latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()
            print(img_name, latent[i][img_name[:-4]].shape)

    latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')
    torch.save(latent, latent_save_path)
    print(f'\nLatent GT code are saved in {save_root}')

阶段 II - CodeFormer (w=0)

w=0 是需要模型完全追求抽象美学,w=1 是需要模型完全追求与原图相似。

在第一个阶段,得到了每张图对应的embedding。

训练密码本训练预测模块:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch

预训练CodeFormer第二阶段模型 (codeformer_stage2.pth)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

阶段 III - CodeFormer (w=1)

训练可调模块:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch

预训练CodeFormer模型 (codeformer.pth)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1402633.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React Hooks 源码解析:useEffect

React Hooks 源码解析(4):useEffect React 源码版本: v16.11.0源码注释笔记:airingursb/react 1. useEffect 简介 1.1 为什么要有 useEffect 我们在前文中说到 React Hooks 使得 Functional Component 拥有 Class Component 的…

206.反转链表(附带源码)

一、思路 二、代码 一、思路 将指针调转一个方向就行,很简单 做法: 定义2个指针:prev、 cur、 next 当next为空时,循环结束 思路清晰,操作清楚,开始敲代码。 二、代码 struct ListNode* reverseList(s…

Tide Quencher 8WS-Mal,TQ8WS-Mal,能够针对特定的荧光物质进行淬灭

您好,欢迎来到新研之家 文章关键词:Tide Quencher 8WS maleimide,TQ8WS maleimide ,Tide Quencher 8WS Mal,TQ8WS Mal,荧光淬灭剂Tide Quencher 8WS 马来酰亚胺 ,TQ8WS 马来酰亚胺 一、基本信…

【蓝桥杯备赛Java组】语言基础|竞赛常用库函数|输入输出|String的使用|常见的数学方法|大小写转换

🎥 个人主页:深鱼~🔥收录专栏:蓝桥杯🌄欢迎 👍点赞✍评论⭐收藏 目录 一、编程基础 1.1 Java类的创建 1.2 Java方法 1.3 输入输出 1.4 String的使用 二、竞赛常用库函数 1.常见的数学方法 2.大小写转…

链表的分类

链表的八种类别: 这三行属性结合,共有八种链表: 1.带头单向循环 2.带头双向循环 3.带头单向不循环 4.带头双向不循环 5.带头单向循环 6.带头双向循环 7.带头单向不循环 8.带头双向不循环 一.单向或双向 单向链表只有一个指向后续节点的指针 双向链表则有两个指针,分别…

高客单价企业必读:私域运营趋势分析与实操技巧

一、深入挖掘:场景洞察的新维度 当我们收到销售的群发信息时,通常会感到被打扰或骚扰,这是因为这些信息通常是基于广泛的受众群体发送的,缺乏针对个体消费者的定制化和个性化。这种缺乏个性化的沟通方式很容易被消费者视为不必要…

ITSS认证有用吗❓属于gj级证书吗❓

🔥ITSS由中国电子技术标准化研究院推出,包括“IT 服务工程师”和“IT 服务经理”两种认证。该系列认证符合GB/T 28827.1 的评估和ITSS服务资质升级要求。 🎯ITSS是受到gj认可的,在全国范围内对IT服务管理人员从业资格为一的权威的…

计算机网络学习The next day

在计算机网络first day中,我们了解了计算机网络这个科目要学习什么,因特网的概述,三种信息交换方式等,在今天,我们就来一起学习一下计算机网络的定义和分类,以及计算机网络中常见的几个性能指标。 废话不多…

express.js+mysql实现获取文章分类

var express require("express"); var router express.Router(); // 引入封装的获取验证码的方法 var art_handler require("../controllers/artcate"); // 获取文章分类的列表 router.get("/cates", art_handler.getArticleClassification)…

通付盾获2023年度移动互联网APP产品安全漏洞治理优秀案例 荣获工信部CAPPVD漏洞库技术支撑单位

为深入贯彻落实《网络产品安全漏洞管理规定》,规范移动互联网App产品安全漏洞发现、报告、修补和发布等行为,提升网络产品提供者安全漏洞管理意识,探索最前沿的漏洞挖掘技术发展趋势和创新应用,在上级主管部门指导支持下,1月16日&…

浅谈PCB设计与PCB制板的紧密关系

在现代电子领域,印刷电路板(PCB)是各种电子设备的核心组成部分。PCB设计和PCB制板是电子产品开发过程中不可或缺的两个重要环节。本文将深入探讨PCB设计与PCB制板之间的关系,以及如何通过协同工作实现高效的电子产品开发。 PCB设计…

【QT+QGIS跨平台编译】之三:【OpenSSL+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

文章目录 一、OpenSSL介绍二、OpenSSL配置三、Window环境下配置四、Linux环境下配置五、Mac环境下配置 一、OpenSSL介绍 OpenSSL是一个开放源代码的软件库包,应用程序可以使用这个包来进行安全通信,避免窃听,同时确认另一端连接者的身份。这…

WorkPlus AI助理私有化部署,助力企业降本增效

在当今数字化时代,提供卓越的客户服务成为了企业成功的重要因素。而AI智能客服技术的兴起,则成为了实现高效、快捷客户服务的利器。作为一款领先的AI助理解决方案,WorkPlus AI助理能够私有化部署,为企业打造私有知识库&#xff0c…

无缝衔接Stable Diffusion,一张照片几秒钟就能生成个性化图片-InstantID

最近一段时间基于扩散模型的图像处理方法遍地开花,接下来为大家介绍一种风格化图像的方法InstantID,可以通过仅有一张人脸照片,几秒钟内生成不同风格的人物照片。与传统方法需要多张参考图像和复杂的微调过程不同,InstantID只需一…

Linux: dev: glibc: 里面有很多的关于系统调用的函数

其实都没有实体源代码klogctl.c,而是通过编译时构造出来的源代码实体,比如klogctl这个函数,glibc的反汇编如下: 直接是0x67这个系统调用:103: Reading symbols from /usr/lib64/libc-2.28.so... (No debugg…

vue3前端开发,一篇文章看懂何谓pinia

vue3前端开发,pinia的基础练习第一节! 前言,pinia是为了取代vuex而诞生的产品,它有一些改进。去掉了之前的mutations。只有一个action,既可以支持异步,又支持同步。还提供了解构函数,可以把返回的对象内部属性和方法直…

Prometheus配置Grafana监控大屏(Docker)

拉取镜像 docker pull grafana/grafana挂载目录 mkdir /data/prometheus/grafana -p chmod 777 /data/prometheus/grafana临时启动 docker run -d -p 3000:3000 --name grafana grafana/grafana从容器拷贝配置文件至对应目录 docker exec -it grafana cat /etc/grafana/gra…

【RHCSA服务搭建实验】之apache

虚拟web主机类型 一、基于端口 1.vim /etc/httpd/conf.d/vhost2.conf ---- — 改变http服务默认访问路径 <directory /testweb1>allowoverride none 表示不允许覆盖其他配置require all granted 表示允许所有请求 </directory> <virtualhost 0.0.0.0:…

x-cmd pkg | jq - 命令行 JSON 处理器

目录 简介首次用户功能特点类似工具进一步探索 简介 jq 是轻量级的 JSON 处理工具&#xff0c;由 Stephen Dolan 于 2012 年使用 C 语言开发。 它的功能极为强大&#xff0c;语法简洁&#xff0c;可以灵活高效地完成从 JSON 数据中提取特定字段、过滤和排序数据、执行复杂的转…

有色金属市场分析:预计2023年产量增幅在3.5%左右

上周各有色金属品种走势接近&#xff0c;均呈现出周初持续走弱、最后两个交易日反弹的走势。影响有色金属行情的主线逻辑一个是美国债务上限谈判的进展情况&#xff0c;另一个是全球经济衰退的预期。上周四和上周五市场整体反弹&#xff0c;主要由于美国债务上限谈判出现进展&a…