最近学了一个超轻量化的超分辨率网络SESR,效果还不错。
目录
- 一、 源码包
- 二、 数据集的准备
- 2.1 官网下载
- 2.2 网盘下载
- 三、 训练环境配置
- 四、训练
- 4.1 修改配置参数
- 4.2 导入数据集
- 4.3 2倍超分网络训练
- 4.3.1 训练SESR-M5网络
- 4.3.2 训练SESR-M5网络
- 4.3.3 训练SESR-M11网络
- 4.4.4 训练SESR-XL网络
- 4.4 2倍超分网络模型
- 4.5 修改模型保存格式
- 4.6 4倍超分网络训练
- 4.6.1 训练SESR-M5网络
- 4.6.2 训练SESR-M5网络
- 4.6.3 训练SESR-M11网络
- 4.6.4 训练SESR-XL网络
- 4.7 4倍超分网络模型
- 五、量化训练
- 5.1 量化训练模型
- 六、模型推理测试
- 七、超分效果
- 八、总结
一、 源码包
SESR官网的地址为:官网
我自己调整过的源码包地址为:SESR完整包 提取码:b80m
论文地址:论文
源码包推荐使用我给的,我注释过很多地方,看起来不吃力,且我自己添加了推理测试脚本。
下载好源码包解压后的样子如下:
二、 数据集的准备
获取数据集可以有两种方法:
2.1 官网下载
直接运行源码包中的脚本文件train.py,会自动先下载div2k数据集,但是下载的非常慢,高分辨率数据集有3G多,容易下蹦了。默认会下载到系统C盘下,具体路径为:
C:\Users\Administrator\tensorflow_datasets\downloads,每次下载失败后再次运行又会重新生成序列码并重新下载,很麻烦。
如下:
2.2 网盘下载
我提供了一个我下载好并整理好的数据集,文件存放对应关系我都整理好了,学者可以直接下载导入使用,下载链接为:
三、 训练环境配置
该网络结构是在TensorFlow框架下运行的TensorFlow版本是2.3,还有一个包的版本是tensorflow_datasets==4.1,Pyhton3.6版本,额。。。。。。。。。。。。。。。。。。
踩了很多坑,最后我自己调通的版本是TensorFlow-gpu2.9,Python 3.7版本,tensorflow_datasets4.8.2,如下:
安装好TensorFlow-GPU后先测试一下能不能正常调用GPU,测试方法参考:添加链接描述
四、训练
4.1 修改配置参数
打开train.py文件,里面有些配置参数根据自己电脑情况修改:
train.py脚本中对应上图修改的地方如下:
4.2 导入数据集
下载好我提供的数据集后,解压好讲整个tensorflow_datasets文件夹放到data文件夹中,并将tensorflow_datasets文件夹所在路径赋值给变量data_dir,代码中具体的修改地方如下:
4.3 2倍超分网络训练
根据自己需求选择要训练深度:
4.3.1 训练SESR-M5网络
其中m = 5,f = 16,feature_size = 256,具有折叠线性块:
python train.py
4.3.2 训练SESR-M5网络
m = 5,f = 16,feature_size = 256,扩展线性块:
python train.py --linear_block_type expanded
4.3.3 训练SESR-M11网络
其中m = 11,f = 16,feature_size = 64,具有折叠线性块:
python train.py --m 11 --feature_size 64
4.4.4 训练SESR-XL网络
其中m = 11,f = 16,feature_size = 64,具有折叠线性块:
python train.py --m 11 --int_features 32 --feature_size 64
4.4 2倍超分网络模型
通过上面步骤训练好后会在logs文件中自动保存权重文件和模型,如下:
上面各个文件代表内容为:
.pb:表示protocol buffers,是模型结构和参数的二进制序列化文件。存储了模型的网络结构,变量,权重等信息。是模型persist的主要文件。
.data-00000-of-00001:存储了模型变量的取值,即模型权重参数的值。模型训练完成后保存的权重。
.index:索引文件,存放了参数tensor的meta信息,如tensor名称、维度等。用于定位data文件中的tensor数据。
checkpoints文件:存储模型训练过程中的参数,用于恢复训练。
4.5 修改模型保存格式
上面是默认的保存方式,学长如果需要其他格式的自己修改保存方法,具体修改地方如下:
4.6 4倍超分网络训练
4倍超分网络得在2倍超分模型基础上训练才行,网络深度自己选择:
4.6.1 训练SESR-M5网络
其中m = 5,f = 16,feature_size = 256,具有折叠线性块:
python train.py --scale 4
4.6.2 训练SESR-M5网络
m = 5,f = 16,feature_size = 256,扩展线性块:
python train.py --linear_block_type expanded --scale 4
4.6.3 训练SESR-M11网络
其中m = 11,f = 16,feature_size = 64,具有折叠线性块:
python train.py --m 11 --feature_size 64 --scale 4
4.6.4 训练SESR-XL网络
其中m = 11,f = 16,feature_size = 64,具有折叠线性块:
python train.py --m 11 --int_features 32 --feature_size 64 --scale 4
4.7 4倍超分网络模型
训练好后,模型会自动保存在logs文件中,如下:
五、量化训练
运行以下命令,在训练时对网络进行调试,并生成TFLITE(用于x2 SISR、SESR-M5网络):
python train.py --quant_W --quant_A --gen_tflite
5.1 量化训练模型
训练好后自动保存在logs/x2_models文件下,如下:
六、模型推理测试
推理脚本是我自己写的,具体使用如下,根据需求自行选择:
七、超分效果
八、总结
以上就是超分辨率——SESR网络训练并推理测试的详细图文教程,总结不易,给个三连多多支持,谢谢!欢迎留言讨论。