BasicSR项目(通用图像超分、修复、增强工具库)介绍

news2024/9/22 3:46:58

项目地址:https://github.com/XPixelGroup/BasicSR
文档地址:https://github.com/XPixelGroup/BasicSR-docs/releases
在这里插入图片描述

BasicSR 是一个开源项目,旨在提供一个方便易用的图像、视频的超分、复原、增强的工具箱。BasicSR 代码库从2018年4月20日开始第一个提交,然后随着做研究、打比赛、发论文,逐渐发展与完善起来。它从最开始的针对超分辨率算法到后来拓展到其他更多复原增强相关的算法,
因此,BasicSR 中 SR 的涵义也从 Super-Resolution 延拓到 Super-Restoration。2022年5月9日,BasicSR 迎来新的里程碑,它加入到 XPixel 大家庭中,和更多的小伙伴们一起致力于把 BasicSR 建设得更好!

1、基本说明

1.1 支持的模型

在BasicSR项目中支持的模型如下所示,虽然数量不多,但是可以轻易的添加其他模型结构到BasicSR项目中。很多最新的图像超分论文也是基于BasicSR项目完成,但代码更新没有同步到BasicSR库中。

1.2 运行环境要求

Python 和 Python 库 (对于 Python 库,我们提供了相应的安装脚本):
a) Python >= 3.7 (推荐使用Anaconda或者Miniconda)
b) PyTorch >= 1.7:目前深度学习领域广泛使用的深度学习框架

1.3 项目安装

打开https://github.com/XPixelGroup/BasicSR?tab=readme-ov-file,下载项目,然后在命令行里执行: pip install -e .
在这里插入图片描述
在上图可以看到,特殊原因导致安装失败。参考:https://blog.csdn.net/weixin_46455141/article/details/131353266 ,执行命令,pip config set global.index-url https://mirrors.aliyun.com/pypi/simple 更换源,然后重新执行安装命令 pip install -e . 可以看到成功安装。在这里插入图片描述

1.4 特殊算子支持

通过1.3方式安装的库不支持DCN(可变形卷积)、StyleGAN 中的特定的算子,比如:upfirdn2d, fused_act。安装时需要附加额外信息,支持可变形卷积。若无特殊需求,可以忽略。
需编译特殊算子的安装命令如下,可以看到是多了一个环境变量参数BASICSR_EXT=True

BASICSR_EXT=True pip install -e .

作者也提到,可以在执行代码时加入BASICSR_JIT=True参数,即时加加加载载载 (JIT) PyTorch C++ 编译算子

BASICSR_JIT=True python inference/inference_stylegan2.py

二者对比如下:
在这里插入图片描述

2、项目代码结构

2.1 基本结构

红色 表示和跑实验直接相关的文件,即我们平时打交道最多的文件;
蓝色 表示其他与 BasicSR 存在相关的代码文件;
通常只需要了解红色的部分即可。
在这里插入图片描述
basicsr目录下是该库的核心代码,其目录结构如下所示。可以看到关于模型有archs与models,archs才是模型网络结构与forward的定义。
在这里插入图片描述

2.2 models详情

models目录下详情如下,包含多种模型结构,主要以SRModel与SRRANModel为基类,最原始的基类是BaseModel。
在这里插入图片描述
快速过base_model.py,可以发现有函数model_ema,用于实现模型参数的指数更新。

对比SRModel与SRRANModel,可以发现SRRANModel多了一个net_d相关的参数(应该是鉴别器),对net_d进行检索,可以发现,SRRANModel,同样多了cri_gan(loss函数)的使用。与之对应,配置文件里当有相应的配置项。

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

在观察RealESRNetModel与RealESRGANModel,可以发现其在feed_data函数上与原始基类不一样,代码含量极大。具体如下所示,可以看出其是包含了在线下采样策略。对其入参data进行分析,可以看到多了kernel1,kernel2,sinc_kernel等属性项,这表明这两类模型使用的dataload与原始基类模型不一样。

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM sharpen the GT images
            if self.opt['gt_usm'] is True:
                self.gt = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

2.3 archs详情

arch下是具体的超分模型或者是gan超分模型中的生成器
在这里插入图片描述
任意打开一个文件,如ridnet_arch.py,可以发现关键代码如下, 只要在模型类上添加@ARCH_REGISTRY.register()修饰,即可注册为BasicSR库中的模型。然后,模型只要能正常forward即可。

import torch
import torch.nn as nn

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, make_layer
@ARCH_REGISTRY.register()
class RIDNet(nn.Module):

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 num_block=4,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0)):
        super(RIDNet, self).__init__()

        self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
        self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)

        self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
        self.body = make_layer(
            EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
        self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        res = self.sub_mean(x)
        res = self.tail(self.body(self.relu(self.head(res))))
        res = self.add_mean(res)

        out = x + res
        return out

在通过对basicsr\archs_init_.py进行分析,可以看到只会将以‘_arch.py’结尾的模型文件添加到库中。

import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net

2.4 losses详情

losses目录下的文件到比较少,但通过对__init__.py进行分析可以发现只会将以‘_loss.py’结尾的文件注册到系统中。

在这里插入图片描述
通过对basic_loss.py进行查看,只要5种loss。但需要注意的是,PerceptualLoss是感知损失,需要依赖vgg模型对y_true与y_pred进行推理然后计算中间层特征的差异。
在这里插入图片描述
WeightedTVLoss是一种不需要y_true的loss,其主要目的是使梯度信息最少,然后使超分后的数据更加平滑。其实现代码如下所示:

@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    def __init__(self, loss_weight=1.0, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)

    def forward(self, pred, weight=None):
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss

CharbonnierLoss的核心是charbonnier_loss函数,其对charbonnier_loss输出的结果进行加权。可以看到charbonnier_loss与rmse loss类似,但有一个eps充当类似的正则化参数。

@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    return torch.sqrt((pred - target)**2 + eps)

2.5 dataloader详情

dataloader对应着basicsr\data目录下的文件,通过对__init__.py进行分析可以发现只会将以‘_dataset.py’结尾的文件注册到系统中。可以看到一共有7个dataset文件,表明其支持7种存储结构下的数据集。
在这里插入图片描述
FFHQDataset与SingleImageDataset 通过对代码中__getitem__函数分析,可以看到FFHQDataset与支持数据中只有一种图片。SingleImageDataset与FFHQDataset类似,也是支持只有一种图片的数据集。

@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                index = random.randint(0, self.__len__())
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

使用代码如下,可以看到支持一种ffhq_256.lmdb的数据结构。

datasets:
  train:
    name: FFHQ
    type: FFHQDataset
    dataroot_gt: datasets/ffhq/ffhq_256.lmdb
    io_backend:
      type: lmdb

    use_hflip: true
    mean: [0.5, 0.5, 0.5]
    std: [0.5, 0.5, 0.5]

PairedImageDataset 通过对代码中__getitem__函数分析,可以看到PairedImageDataset支持数据是需要gt_path与lq_path。使用代码如下,可以看到需要配置dataroot_gt与dataroot_lq;meta_info_file与filename_tmpl只是可选参数。

datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
    dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub
    dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub
    meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
    # (for lmdb)
    # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
    # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
    filename_tmpl: '{}'
    io_backend:
      type: disk
      # (for lmdb)
      # type: lmdb

    gt_size: 128
    use_hflip: true
    use_rot: true

    # data loader
    num_worker_per_gpu: 6
    batch_size_per_gpu: 16
    dataset_enlarge_ratio: 100
    prefetch_mode: ~

  val:
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk

其返回数据结构为

{'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

RealESRGANDataset 是用于RealESRGAN模型的数据加载器。格式数据的使用如下所示,可以看到需要设置各种数据在线下采样的参数配置。但最为关键的是dataroot_gt与meta_info,但从使用中可以看出在meta_info的txt中只是存储了图像的相对文件名。

# Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
                self.paths = [os.path.join(self.gt_folder, v) for v in paths]
datasets:
  train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K
    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
    io_backend:
      type: disk

    blur_kernel_size: 21
    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob: 0.1
    blur_sigma: [0.2, 3]
    betag_range: [0.5, 4]
    betap_range: [1, 2]

    blur_kernel_size2: 21
    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob2: 0.1
    blur_sigma2: [0.2, 1.5]
    betag_range2: [0.5, 4]
    betap_range2: [1, 2]

    final_sinc_prob: 0.8

    gt_size: 256
    use_hflip: True
    use_rot: False

其返回数据结构为

{'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}

RealESRGANPairedDataset 也是用于RealESRGAN的数据加载器,但其参数结构、返回数据结构与PairedImageDataset是一样的。返回结构为:

{'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

**VideoTestDataset ** 是针对图像序列(视频)的数据加载,其同样需要设置’dataroot_gt’,‘dataroot_lq’,每一个路径要求的数据格式为如下所示。

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

返回的数据格式为:


        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

REDSDatasetVimeo90KDataset 是特定数据的加载方法

2.6 metrics详情

metrics目录下的是评价指标,目前通__init__.py文件可以看到只支持’calculate_psnr’, ‘calculate_ssim’, ‘calculate_niqe’ 三种。其中niqe是一种无参考的评价指标,我们可以将自行将其他

对应配置文件中的使用代码如下:


  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false
    ssim:
      type: calculate_ssim
      crop_border: 0
      test_y_channel: false
    niqe:
      type: calculate_niqe
      crop_border: 4
      num_thread: 8

2.7 options详情(配置文件)

options目录与核心代码目录平级,在项目根路径下。主要包含train与test两个分支,里面存储的是对应的使用配置(含模型结构、数据加载器配置、loss配置等)。
以options\train\RealESRGAN\train_realesrgan_x2plus.yml为例

# general settings
name: train_RealESRGANx2plus_400k_B12G4
model_type: RealESRGANModel  #指定模型结构
scale: 2
num_gpu: auto  # auto: can infer from your visible devices automatically. official: 4 GPUs  #指定gpu数量
manual_seed: 0

# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]

# the second degradation process
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]

gt_size: 256
queue_size: 180

# dataset and data loader settings
datasets:
  train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K
    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
    io_backend:
      type: disk

    blur_kernel_size: 21
    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob: 0.1
    blur_sigma: [0.2, 3]
    betag_range: [0.5, 4]
    betap_range: [1, 2]

    blur_kernel_size2: 21
    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob2: 0.1
    blur_sigma2: [0.2, 1.5]
    betag_range2: [0.5, 4]
    betap_range2: [1, 2]

    final_sinc_prob: 0.8

    gt_size: 256
    use_hflip: True
    use_rot: False

    # data loader
    num_worker_per_gpu: 5
    batch_size_per_gpu: 12
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  # Uncomment these for validation
  # val:
  #   name: validation
  #   type: PairedImageDataset
  #   dataroot_gt: path_to_gt
  #   dataroot_lq: path_to_lq
  #   io_backend:
  #     type: disk

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 23
  num_grow_ch: 32
  scale: 2

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 64
  skip_connection: True

# path
path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
  param_key_g: params_ema
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [400000]
    gamma: 0.5

  total_iter: 400000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1.0
    style_weight: 0
    range_norm: false
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1

  net_d_iters: 1
  net_d_init_iters: 0

# Uncomment these for validation
# validation settings
# val:
#   val_freq: !!float 5e3
#   save_img: True

#   metrics:
#     psnr: # metric name
#       type: calculate_psnr
#       crop_border: 4
#       test_y_channel: false

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500

3、其他关键信息

3.1 训练与测试

训练命令如下,主要是-opt对应的yml文件,该文件即为2.7中对应的配置项

python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml

完整的命令行参数如下:

  • -opt,配置文件的路径,一般采用这个命令配置训练或者测试的 yml 文件。
  • – laucher,用于指定 distibuted training 的,比如 pytorch 或者 slurm。默认是 none,
    即单卡非 distributed training。
  • – auto_resume,是否自动 resume,即自动查找最近的 checkpoint ,然后 resume。
  • – debug,能够快速帮助 debug。
  • – local_rank,这个不用管,是 distributed training 中程序自动会传入。
  • – force_yml,方便在命令行中修改 yml 中的配置文件。

测试命令如下:

python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml

3.2 模型保存与训练状态恢复

训练的时候, checkpoints 会保存两个文件:

  1. 网 络 参 数 .pth 文 件 。 在 每 个 实 验 的 models 文 件 夹 中 , 文 件 名 诸
    如:net_g_5000.pth、net_g_10000.pth
  2. 包含 optimizer 和 scheduler 信息的 .state 文件。在每个实验的 training_states 文件夹中,
    文件名诸如:5000.state、10000.state
    根据这两个文件,就可以 resume 了。

对应的参数配置如下,对应pretrain_network_g与resume_state的配置

path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
  param_key_g: params_ema
  strict_load_g: true
  resume_state: True # 默认为 ~, 表示删除该参数

也可以在命令行中加入 ‘–auto_resume’,程序就会找到保存的最近的模型
参数和状态,并加载进来,接着训练啦。

分布式训练 单机多卡 8 GPU训练命令

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch

单机多卡 4 GPU训练命令

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch

Slurm训练 4 GPU训练命令

GLOG_vmodule=MemcachedClient=-1 
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"

3.3 模型EMA

EMA (Exponential Moving Average),指数移动平均。 它是用来“平均”一个变量在历史上的值。使用怎样的权重平均呢?如名字所示,随着时间,越是过往的时间,以一个指数衰减的权重来平均。

在 BasicSR 里面,EMA 一般作用在模型的参数上。它的效果一般是:
• 稳定训练效果。GAN 训练的结果一般瑕疵更少,视觉效果更好
• 对于以 PSNR 为目的的模型,其 PSNR 一般会更高一些

由于开启 EMA的代价几乎可以不计,所以我们推荐开启 EMA。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【QT】Qt事件

目录 前置知识 事件概念 常见的事件描述 进入和离开事件 代码示例&#xff1a; 鼠标事件 鼠标点击事件 鼠标释放事件 鼠标双击事件 鼠标滚轮动作 键盘事件 定时器事件 开启定时器事件 窗口相关事件 窗口移动触发事件 窗口大小改变时触发的事件 扩展 前置知识…

知识改变命运 第七集(下):Java中数组的定义与使用

4. 数组练习 4.1 数组转字符串 import java.util.Arrays int[] arr {1,2,3,4,5,6}; String newArr Arrays.toString(arr); System.out.println(newArr); // 执行结果 [1, 2, 3, 4, 5, 6]使用这个方法后续打印数组就更方便一些. Java 中提供了 java.util.Arrays 包, 其中包含…

SwiftUI 截图(snapshot)视频画面的极简方法

功能需求 在 万物皆可截图:SwiftUI 中任意视图(包括List和ScrollView)截图的通用实现 这篇博文中,我们实现了在 SwiftUI 中截图几乎任何视图的功能,不幸的是它对视频截图却无能为力。不过别着急,我们还有妙招。 在上面的演示图片中,我们在 SwiftUI 中可以随心所欲的截图…

【ZooKeeper学习笔记】

1. ZooKeeper基本概念 Zookeeper官网&#xff1a;https://zookeeper.apache.org/index.html Zookeeper是Apache Hadoop项目中的一个子项目&#xff0c;是一个树形目录服务Zookeeper翻译过来就是动物园管理员&#xff0c;用来管理Hadoop&#xff08;大象&#xff09;、Hive&…

浪潮信息F-OCC算法夺冠,自动驾驶感知技术再创新高

浪潮信息&#xff0c;作为行业领先的AI技术提供商&#xff0c;其AI团队在近期举办的全球权威CVPR 2024自动驾驶国际挑战赛(Autonomous Grand Challenge)中大放异彩&#xff0c;凭借“F-OCC”算法模型以48.9%的卓越成绩&#xff0c;一举夺得占据栅格和运动估计(Occupancy & …

前端Vue组件化实践:打造仿京东天猫商品属性选择器组件

在前端开发领域&#xff0c;随着业务需求的日益复杂和技术的不断进步&#xff0c;传统的整体式应用开发模式已逐渐显得捉襟见肘。面对日益庞大的系统&#xff0c;每次微小的功能修改或增加都可能导致整个逻辑结构的重构&#xff0c;形成牵一发而动全身的困境。为了解决这一问题…

基于Node.js将个人网站部署到ECS

基于Node.js将个人网站部署到云端ECS 一、如何购买ECS以及如何使用学生认证优惠&#xff1f;1.进入阿里云网站&#xff0c;进行学生认证2.购买学生优惠&#xff0c;免费试用一个月3.重置个人密码 二、服务器的配置以及与宝塔面板的链接1.个人电脑打开终端&#xff08;winR->…

探索性数据分析:使用Python与Pandas库实现数据洞察

探索性数据分析&#xff1a;使用Python与Pandas库实现数据洞察 引言 在当今数据驱动的时代&#xff0c;数据分析已成为决策制定、策略规划和业务优化的关键环节。无论是商业智能、金融分析还是市场研究&#xff0c;数据分析都扮演着至关重要的角色。Pandas库作为Python生态系统…

一文SpringCloud

Springcloud 什么是Springcloud&#xff1f; 官网&#xff1a;Spring Cloud Data Flow Spring Cloud是一系列框架的有序集合。它利用Spring Boot的开发便利性巧妙地简化了分布式系统基础设施的开发&#xff0c;如服务发现注册、配置中心、消息总线、负载均衡、断路器、数据监控…

Flat Ads:金融APP海外广告投放素材的优化指南

在当今全球化的数字营销环境中,金融APP的海外营销推广已成为众多金融机构与开发者最为关注的环节之一。面对不同地域、文化及用户习惯的挑战,如何优化广告素材,以吸引目标受众的注意并促成有效转化,成为了广告主们亟待解决的问题。 作为领先的全球化营销推广平台,Flat Ads凭借…

树莓派PICO使用INA226测量电流和总线电压(3)

上一篇文章我们讲了如何测试电流&#xff0c;但是INA226有一个非常典型的问题&#xff0c;那就是误差比较大&#xff0c;因为采样电阻非常小&#xff0c;我的开发板用的是100mΩ的采样电阻&#xff0c;在设定中我也用的是这个采样电阻值&#xff0c;但事实上&#xff0c;测试得…

文件内容查阅

cat concatenate files and print on the standard output Linux中一个最简单的且最常用的命令是cat命令。其功能是在终端设备上显示文件内容。 cat命令-n选项用于显示行号。 tac concatenate and print files in reverse tac命令的功能是用于反向显示文件内容&#xff0c;即…

【Qt 基础】绘图

画笔 QPen pen; pen.setWidth(3); // 线条宽度 pen.setColor(Qt::red);// 画笔颜色 pen.setStyle(Qt::DashLine);// 线条样式 pen.setCapStyle(Qt::RoundCap);// 线端样式 pen.setJoinStyle(Qt::BevelJoin);// 连接样式 painter.setPen(pen);线条 线端 连接 画刷 QBrush bru…

css设置弹性flex后,如果设置100vh高度不撑满的原因

问题 父元素设置height为100%&#xff0c;有两个子元素&#xff0c;第一个设置height:100vh&#xff0c;第二个设置flex:1&#xff0c;此时第一个高度无法撑满盒子 原因解决方式 当父元素设置display为flex,第一个div设置高度64px,剩一个div设置高度为flex&#xff1a;1,这时…

数据处理-Matplotlib 绘图展示

文章目录 1. Matplotlib 简介2. 安装3. Matplotlib Pyplot4. 绘制图表1. 折线图2. 散点图3. 柱状图4. 饼图5. 直方图 5. 中文显示 1. Matplotlib 简介 Matplotlib 是 Python 的绘图库&#xff0c;它能让使用者很轻松地将数据图形化&#xff0c;并且提供多样化的输出格式。 Ma…

Qt项目中添加自定义文件夹,进行整理归类

Qt项目中添加文件夹进行归类 1、在windows的工程目录下创建一个文件夹&#xff0c;如widgets 2、将.h 、.cpp、.ui文件拷贝到windows该文件夹widgets 3、在qt工程中&#xff0c;根目录右键&#xff0c;选择添加现有文件&#xff0c;批量选择 .h 、.cpp、.ui文件之后&#xf…

初识影刀:EXCEL根据部门筛选低值易耗品

第一次知道这个办公自动化的软件还是在招聘网站上&#xff0c;了解之后发现对于办公中重复性的工作还是挺有帮助的&#xff0c;特别是那些操作非EXCEL的重复性工作&#xff0c;当然用在EXCEL上更加方便&#xff0c;有些操作比写VBA便捷。 下面就是一个了解基本操作后&#xff…

开发总结 - H5/web C端评论区开发逻辑

1. 背景 平时做的系统都是偏公司业务的系统&#xff0c;这次开发了一个用户评论的功能&#xff0c;同时开发了web版和H5版本的&#xff0c;因为没有做过这种C端的常用的功能&#xff0c;所以记录一下此次的开发&#xff0c;从参考友商设计到独立思考业务之间的区别再到实际开发…

everything搜索不到任何文件-设置

版本&#xff1a; V1.4.1.1024 (x64) 问题&#xff1a;搜索不到任何文件 click:[工具]->[选项]->下图所示 将本地磁盘都选中包含

2024 辽宁省大学数学建模竞赛B 题 钢铁产品质量优化完整思路 代码 结果分享(仅供学习)

冷轧带钢是钢铁企业的高附加值产品,其产品质量稳定性对于钢铁企业的经济效益具有非常重要的影响。在实际生产中&#xff0c;冷连轧之后的带钢需要经过连续退火处理来消除因冷轧产生的内应力并提高其机械性能。连续退火的工艺流程如图1 所示&#xff0c;一般包括加热、保温、缓冷…