Training an RCAN QAT Super-Resolution Model with PyTorch
- Version Information
- Test Steps
  - Prepare the Dataset
  - Create the Container
  - Generate File Lists
    - Code to create the file lists
    - Run the script to generate the file lists
  - Train the RCAN Model
    - Preparation
    - Modify the Open-Source Code
    - Write the Training Code
    - Run the Training Script
    - Visualization
This article uses the RCAN super-resolution model as an example to demonstrate the QAT training workflow (a minimal sketch of the flow follows the list below). The steps are:
- Train an FP32 model first
- Load the FP32-trained weights and continue with QAT training
- Stop training when the loss has not decreased for 5 consecutive rounds
- To speed up the demo, also stop training once the PSNR exceeds 33.0
- Use TensorBoard to monitor the loss curve
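The sketch below outlines this FP32-to-QAT flow. It is a simplified outline rather than the full train.py developed later in this article: the FP32 model, the per-round training step, and the PSNR evaluation are placeholder callables, the "fbgemm" backend and the 1x3x48x48 example input shape are assumptions, and it assumes a recent PyTorch (1.13 or newer) where prepare_qat_fx takes an example_inputs argument.
# Minimal sketch of the FP32 -> QAT flow described above (not the article's full train.py).
# fp32_model / train_one_round / evaluate_psnr are placeholders supplied by the caller.
import copy
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

def run_qat(fp32_model, train_one_round, evaluate_psnr,
            max_rounds=100, target_psnr=33.0, patience=5):
    # Start from the FP32-trained weights, then insert fake-quant/observer modules.
    model = copy.deepcopy(fp32_model).train()
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")   # assumed backend
    example_inputs = (torch.randn(1, 3, 48, 48),)                 # assumed LR patch shape
    model = prepare_qat_fx(model, qconfig_mapping, example_inputs)

    best_loss, bad_rounds = float("inf"), 0
    for _ in range(max_rounds):
        loss = train_one_round(model)
        psnr = evaluate_psnr(model)
        if loss < best_loss:
            best_loss, bad_rounds = loss, 0
        else:
            bad_rounds += 1
        # Stop if the loss has not improved for `patience` rounds, or PSNR is good enough.
        if bad_rounds >= patience or psnr > target_psnr:
            break

    # Fold the fake-quant observers into a real INT8 model for deployment.
    return convert_fx(model.eval())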
Version Information
Attribute | Value |
---|---|
Training environment | Setup steps: see "Create the Container" below |
GPU model | NVIDIA GeForce RTX 3080 12GB |
Dataset download links | http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip |
Open-source model code | https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py |
Test Steps
Prepare the Dataset
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
Create the Container
Build the image by following the steps at https://editor.csdn.net/md/?articleId=136176989, then start the container:
docker stop rcan_dev
docker rm rcan_dev
nvidia-docker run -ti -e NVIDIA_VISIBLE_DEVICES=all --privileged \
--net=host -p 6006:6006 -v $PWD:/home -w /home \
-v /mnt/disk/RCAN/:/RCAN --name rcan_dev cuda_dev_image:v1.0 /bin/bash
conda activate ai_dev
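Optionally, a quick check inside the container (not part of the original steps) confirms that PyTorch in the ai_dev environment can see the GPU before training starts:
# optional sanity check: verify that PyTorch sees the GPU inside the container
import torch
print(torch.__version__)
print(torch.cuda.is_available())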
Generate File Lists
Code to create the file lists
# generate_datalist.py
import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
train_HR_path = './DIV2K_train_HR'
train_LR_path = './DIV2K_train_LR_bicubic/X2'
valid_HR_path = './DIV2K_valid_HR'
valid_LR_path = './DIV2K_valid_LR_bicubic/X2'
train_file = 'datalist_div2k_train.txt'
valid_file = 'datalist_div2k_valid.txt'
def get_images(input_path, format='png'):
    # Return the sorted base names (without extension) of all images in a directory.
    names = [os.path.splitext(fname)[0]
             for fname in os.listdir(input_path)
             if fname.endswith(format)]
    names.sort()
    return names

def get_folders(input_path):
    # Return the sorted names of all sub-directories in a directory.
    names = [directory
             for directory in os.listdir(input_path)
             if os.path.isdir(os.path.join(input_path, directory))]
    names.sort()
    return names

# Each line of a list pairs an LR image path with its HR counterpart.
the_train_file = open(train_file, 'w')
image_names = get_images(train_HR_path)
for image_name in image_names:
    the_train_file.write('DIV2K_train_LR_bicubic/X2/' + image_name +
                         'x2.png' + ' ' + 'DIV2K_train_HR/' + image_name + '.png' + '\n')
the_train_file.close()

the_valid_file = open(valid_file, 'w')
image_names = get_images(valid_HR_path)
for image_name in image_names:
    the_valid_file.write('DIV2K_valid_LR_bicubic/X2/' + image_name +
                         'x2.png' + ' ' + 'DIV2K_valid_HR/' + image_name + '.png' + '\n')
the_valid_file.close()
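Each line of the generated lists pairs a low-resolution image path with its high-resolution counterpart, separated by a space; for the first DIV2K training image the line looks like this:
DIV2K_train_LR_bicubic/X2/0001x2.png DIV2K_train_HR/0001.png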
Run the script to generate the file lists
cd /RCAN/
unzip DIV2K_train_HR.zip
unzip DIV2K_valid_HR.zip
unzip DIV2K_train_LR_bicubic_X2.zip
unzip DIV2K_valid_LR_bicubic_X2.zip
python generate_datalist.py
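DIV2K provides 800 training and 100 validation image pairs, so as an optional sanity check (not part of the original steps) the two lists should contain 800 and 100 lines respectively:
# optional sanity check: expect 800 training and 100 validation entries
for name in ('datalist_div2k_train.txt', 'datalist_div2k_valid.txt'):
    with open(name) as f:
        print(name, sum(1 for _ in f))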
Train the RCAN Model
Preparation
# Install dependencies
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple
# Set environment variables
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
# Download the open-source model code
cd /RCAN/
mkdir model
curl -L -o model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
curl -L -o model/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
curl -L -o model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
curl -L -o template.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py
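As an optional check (not part of the original steps), byte-compiling the downloaded files confirms they are valid Python source rather than an HTML error page from a failed download:
# optional check: confirm the four downloaded files parse as Python
import py_compile
for path in ('model/rcan.py', 'model/option.py', 'model/common.py', 'template.py'):
    py_compile.compile(path, doraise=True)
    print('ok', path)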
Modify the Open-Source Code
The following downloaded files need to be modified:
- model/rcan.py
- model/common.py
Write the Training Code
# train.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import json
import copy
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
# FX graph mode quantization utilities, QConfig definitions, fake-quant and observer modules
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization import qconfig
from torch.ao.quantization.fake_quantize import *
from torch.ao.quantization.observer import *
from torch.utils import tensorboard
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage.color import rgb2hsv, hsv2rgb
import imageio
import random
import numpy as np
def _apply(func, x):
    # Recursively apply `func` to every element of a nested list/tuple/dict structure.
    if isinstance(x, (list, tuple)):
        return [_apply(func, x_i) for x_i in x]
    elif isinstance(x, dict):
        y = {}
        for key, value in x.items():
            y[key] = _apply(func, value)
        return y
    else:
        return func(x)
def get_patch(*args, patch_size=96, scale=2, input_large=False):
    # Crop aligned random patches from the LR/HR image pair passed in `args`.
    ih, iw = args[0].shape[:2]
    if not input_large:
        p = scale
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size
ix = random.randrange(0