Pytorch训练RCAN QAT超分模型

news2024/9/30 15:18:47

Pytorch训练RCAN QAT超分模型

  • 版本信息
  • 测试步骤
    • 准备数据集
    • 创建容器
    • 生成文件列表
      • 创建文件列表的代码
      • 执行脚本,生成文件列表
    • 训练RCAN模型
      • 准备工作
      • 修改开源代码
      • 编写训练代码
      • 执行训练脚本
    • 可视化

本文以RCAN超分模型为例,演示了QAT的训练过程,步骤如下:

  • 先训练FP32模型
  • 再加载FP32训练的权值,进行QAT训练
  • 连续5次loss没有下降则停止训练
  • 为了加快演示,当psnr大于33.0时就停止训练
  • 采用tensorboard观察Loss曲线

版本信息

属性
训练环境 搭建步骤
GPU型号 NVIDIA GeForce RTX 3080 12GB
数据集下载链接 http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
开源模型结构 https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

测试步骤

准备数据集

wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip

创建容器

按https://editor.csdn.net/md/?articleId=136176989的步骤构建镜像

docker stop rcan_dev
docker rm rcan_dev
nvidia-docker run -ti -e NVIDIA_VISIBLE_DEVICES=all --privileged \
				--net=host -p 6006:6006 -v $PWD:/home -w /home  \
				-v /mnt/disk/RCAN/:/RCAN --name rcan_dev  cuda_dev_image:v1.0 /bin/bash
conda activate ai_dev

生成文件列表

创建文件列表的代码

# generate_datalist.py

import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

train_HR_path = './DIV2K_train_HR'
train_LR_path = './DIV2K_train_LR_bicubic/X2'
valid_HR_path = './DIV2K_valid_HR'
valid_LR_path = './DIV2K_valid_LR_bicubic/X2'

train_file = 'datalist_div2k_train.txt'
valid_file = 'datalist_div2k_valid.txt'

def get_images(input_path, format='png'):
    names = [os.path.splitext(fname)[0]
            for fname in os.listdir(input_path)
            if fname.endswith(format)]
    names.sort()
    return names

def get_folders(input_path):
    names = [directory 
            for directory in os.listdir(input_path)
            if os.path.isdir(os.path.join(input_path, directory))]
    names.sort()
    return names

the_train_file = open(train_file, 'w')
image_names = get_images(train_HR_path)
for image_name in image_names:
    the_train_file.write('DIV2K_train_LR_bicubic/X2/' + image_name + 
            'x2.png' + ' ' + 'DIV2K_train_HR/' + image_name + '.png' + '\n')
the_train_file.close()

the_valid_file = open(valid_file, 'w')
image_names = get_images(valid_HR_path)
for image_name in image_names: 
    the_valid_file.write('DIV2K_valid_LR_bicubic/X2/' + image_name + 
            'x2.png' + ' ' + 'DIV2K_valid_HR/' + image_name + '.png' + '\n')
the_valid_file.close()

执行脚本,生成文件列表

cd /RCAN/
unzip DIV2K_train_HR.zip
unzip DIV2K_valid_HR.zip
unzip DIV2K_train_LR_bicubic_X2.zip
unzip DIV2K_valid_LR_bicubic_X2.zip
python generate_datalist.py

训练RCAN模型

准备工作

# 安装依赖
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple

# 设置环境变量
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# 下载开源模型源码
cd /RCAN/
mkdir model
curl -L -o model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
curl -L -o model/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
curl -L -o model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
curl -L -o template.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

修改开源代码

  • model/rcan.py

image-20240220142852491

image-20240220144639916

  • model/common.py

    image-20240220143210588

编写训练代码

# train.py

import os
import torch
import torch.nn as nn
import torch.optim as optim
import json
import copy
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.quantization.quantize_fx import prepare_qat_fx,convert_fx
from torch.ao.quantization import qconfig
from torch.ao.quantization.fake_quantize import *
from torch.ao.quantization.observer import *
from torch.utils import tensorboard
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage.color import rgb2hsv, hsv2rgb
import imageio
import random
import numpy as np

def _apply(func, x):

    if isinstance(x, (list, tuple)):
        return [_apply(func, x_i) for x_i in x]
    elif isinstance(x, dict):
        y = {
   }
        for key, value in x.items():
            y[key] = _apply(func, value)
        return y
    else:
        return func(x)

def get_patch(*args, patch_size=96, scale=2, input_large=False):
    ih, iw = args[0].shape[:2]

    if not input_large:
        p = scale
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1472704.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

壹[1],图像源

1,工具名称:图像源 2,参数说明 2.1,图像源 注: 本地图像,使用本地图片以及本地图像文件夹 相机,连接的相机 SDK,使用相机的SDK,而不是海康SDK 2.2,像素格式 注&…

Jeecg项目部署

说明:Jeecg是一款低代码开发平台,简单说是一款现成的项目,该项目集成了许多功能,我们可以在这个项目之上开发自己的业务代码。 本文介绍Jeecg项目的部署,包括后端jeecg-boot项目、前端vue3项目。前端项目在本地Window…

VScode连接远端服务器一直输入密码解决方法

文章目录 1 关闭远程连接2打开命令面板3 输入remote-ssh: kill vs code server on host… 1 关闭远程连接 2打开命令面板 3 输入remote-ssh: kill vs code server on host… remote-ssh: kill vs code server on host… 然后一路回车(选中出问题的主机),输一遍密码…

真正理解微软Windows程序运行机制——窗口机制(第一部分)

我是荔园微风,作为一名在IT界整整25年的老兵,今天说说Windows程序的运行机制。经常被问到MFC到底是一个什么技术,为了解释这个我之前还写过帖子,但是很多人还是不理解。其实这没什么,我在学生时代也被这个问题困绕过。…

【日常聊聊】Sora- 探索AI视频模型的无限可能

🍎个人博客:个人主页 🏆个人专栏:日常聊聊 ⛳️ 功不唐捐,玉汝于成 目录 前言 正文 方向一:技术解析 方向二:应用场景 方向三:未来展望 方向四:伦理与创意 方向…

深入理解JS的执行上下文、词法作用域和闭包(下)

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

【PX4SimulinkGazebo联合仿真】在Simulink中使用ROS2控制无人机沿自定义圆形轨迹正向飞行(带偏航角控制)并在Gazebo中可视化

在Simulink中使用ROS2控制无人机沿自定义圆形轨迹正向飞行(带偏航角控制)并在Gazebo中可视化 系统架构Matlab官方例程Control a Simulated UAV Using ROS 2 and PX4 Bridge运行所需的环境配置PX4&Simulink&Gazebo联合仿真实现方法建立Simulink模…

vue3自定义实现悬浮固定按钮组件

目录 一、需求描述二、代码解读三、结果展示 一、需求描述 需要5个固定的悬浮圆,居于页面的右侧。鼠标悬浮在圆上面会显示对应的文字提示其中包含返回顶部悬浮圆,当页面滑至底部时出现,点击页面滑到顶部。点击按钮会给出弹窗 二、代码解读…

LCR 172. 统计目标成绩的出现次数

解题思路&#xff1a;二分查找 题解一 class Solution {public int countTarget(int[] scores, int target) {// 搜索右边界 rightint i 0, j scores.length - 1;while(i < j) {int m (i j) / 2;if(scores[m] < target) i m 1;else j m - 1;}int right i;// 若数…

UE5 C++ Gas开发 学习记录(一)

一个新坑,在TPS的空余时间学习 创建了自己,敌人的BaseCharacter和子类,创建了Gamemode,创建了Controller AuraCharacterBase.h // Fill out your copyright notice in the Description page of Project Settings. #pragma once #include "CoreMinimal.h" #include &…

H264/H265基本编码参数1

本文主要讲解一些视频编码相关的基本概念 像素 像素是图像的基本单元&#xff0c;一个个像素就组成了图像。你可以认为像素就是图像中的一个点。我们来直观地看看像素是怎么组成图像的。在下面这张图中&#xff0c;你可以看到一个个方块&#xff0c;这些方块就是像素。 分辨…

【GameFramework框架内置模块】4、内置模块之调试器(Debugger)

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址QQ群&#xff1a;398291828 大家好&#xff0c;我是佛系工程师☆恬静的小魔龙☆&#xff0c;不定时更新Unity开发技巧&#xff0c;觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录&#xff1a;…

开启数字内容创作的新时代

目录 技术解析 未来展望 技术解析 Sora是一款由OpenAI开发的先进AI视频模型&#xff0c;其技术架构基于深度学习和自然语言处理技术。该模型的核心算法原理包括使用深度神经网络进行视频内容的理解、生成和互动。 在技术架构方面&#xff0c;Sora采用了一种混合的神经网络结…

五种多目标优化算法(NSWOA、MOJS、MOAHA、MOPSO、NSGA2)性能对比(提供MATLAB代码)

一、5种多目标优化算法简介 1.1NSWOA 1.2MOJS 1.3MOAHA 1.4MOPSO 1.5NSGA2 二、5种多目标优化算法性能对比 为了测试5种算法的性能将其求解9个多目标测试函数&#xff08;zdt1、zdt2 、zdt3、 zdt4、 zdt6 、Schaffer、 Kursawe 、Viennet2、 Viennet3&#xff09;&#xff0…

15:00面试,15:06就出来了,问的问题过于变态了。。。

我从一家小公司转投到另一家公司&#xff0c;期待着新的工作环境和机会。然而&#xff0c;新公司的加班文化让我有些始料未及。虽然薪资相对较高&#xff0c;但长时间的工作和缺乏休息使我身心俱疲。 就在我逐渐适应这种高强度的工作节奏时&#xff0c;公司突然宣布了一则令人…

EXCEL如何从另一个表查找匹配信息

目录 1.背景&#xff1a;我们有一个目标呈现表&#xff0c;想要从另一个表中查询得到信息&#xff0c;比如根据身份证id查询该id的名字、性别等个人基本信息&#xff0c;或者从另一个财务信息表查询该id的工资信息等&#xff1b; 2.基础方法&#xff1a;利用VLOOKUP函数根据单…

NGINX服务器配置实现加密的WebSocket连接WSS协议

一、背景 最近在做小程序开发&#xff0c;需要在nginx中配置websocket加密模式&#xff0c;即wss。初次配置wss时&#xff0c;踩了两个小时的坑&#xff0c;本文将踩坑过程分享给大家&#xff0c;有需要用到的伙伴可以直接copy即可实现&#xff0c;节省宝贵时间。 二、WebSo…

VS2022调试技巧(一)

什么是bug&#xff1f; 在1945年&#xff0c;美国科学家Grace Hopper在进行计算机编程时&#xff0c;发现一只小虫子钻进了一个真空管&#xff0c;导致计算机无法正常工作。她取出虫子后&#xff0c;计算机恢复了正常&#xff0c;由此&#xff0c;她首次将“Bug”这个词用来描…

用html编写的小广告板

用html编写的小广告板 相关代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</tit…

刘知远LLM——Transformer与预训练模型

文章目录 注意力机制原理介绍注意力机制的各种变式注意力机制的特点 Transformer结构概述Transformer整体结构 输入层byte pair encodingpositional encoding Transformer BlockEncoder BlockMulti-Head Attention Decoder Block其他tricks总结 预训练语言模型语言建模概述预训…