1. 项目结构
如何生成文件夹的文件目录呢?
按住shift键,右击你要生成目录的文件夹,选择“在此处打开Powershell窗口”
在命令窗口里输入命令“tree”,按回车。就会显示出目录结构。
├─.idea
│ └─inspectionProfiles
├─benchmark_results
├─data
│ ├─test
│ │ ├─Manga109
│ │ ├─Set14
│ │ ├─Set5
│ │ └─Urban100
│ ├─train_DIV2K_HR
│ └─valid_DIV2K_HR
├─epochs
├─pytorch_ssim
│ └─__pycache__
├─statistics
├─training_results
│ └─SRF_4
└─__pycache__
为了更好地记录这个代码文件夹的结构,我再把.py文件添上去
├─.idea
│ └─inspectionProfiles
├─benchmark_results
├─data
│ ├─test
│ │ ├─Manga109
│ │ ├─Set14
│ │ ├─Set5
│ │ └─Urban100
│ ├─train_DIV2K_HR
│ └─valid_DIV2K_HR
├─epochs
├─pytorch_ssim
│ └─__pycache__
├─statistics
├─training_results
│ └─SRF_4
├─data_utils.py
├─loss.py
├─model.py
├─README.md
├─test_benchmark.py
├─test_image.py
└─train.py
才拿到代码包的时候,每一个空文件夹下都有一个“.gitkeep文件”。
那么什么是“.gitkeep文件”呢?
因为Git 是一个文件追踪系统,所以Git 不会追踪一个空目录。当我们需要保留空目录的时候,“.gitkeep文件”可以使 Git 保留一个空文件夹。
2. 实验细节
算法名称 | SRGAN |
图像域 | RGB |
下采样方法 | 双三次核函数下采样4⨉ |
目标函数 | 内容损失+对抗损失 |
生成器 | SRResNet |
判别器 | VGG:判别HR与SR |
训练集 | DIV2K,800张 |
验证集 | DIV2K,100张 |
测试集 | Set5、Set14、BSD100、Urban100、Manga109 |
参数配置(在train.py中)
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
GPU
为了方便,然后训练集本来也不大,就在本地的NVIDIA GeForce RTX 3050上跑的
持续运行
因为经常把电脑背来背去的,会放进包里,所以要求电脑合上的时候程序也能继续运行。具体实现方法是:
1、点击开始图标,点击控制面板。
2、查看方式选择为“类别”,找到“硬件和声音”功能并点击。
3、在硬件和声音页面,找到更改电源按钮的功能选项并点击。
4、将“关闭盖子时”后方都设置为“不采取任何操作”,最后保存修改即可。
3. 项目解析
benchmark_results
训练完成后,训练结果会保存到benchmark_results 文件夹中
data
存放训练集、验证集、测试集的地方。
epochs
用于存放每个epoch训练得到的生成器和判别器的模型参数。
pytorch_ssim
计算结构相似性指数SSIM
statistics
存放记录每个epoch训练结果的表格,每跑10个epochs记录一次
training_results
存放验证集结果
里面有一个名为“SRF_4”的文件夹,意思是4⨉的双三次核函数下采样、放大因子为4。
“SRF_4”的文件夹存放着每一个epoch在验证集上的可视化结果,于展示图像超分辨率模型在训练过程中的性能表现。
每组图片包含三列:原始低分辨率图像(val_hr_restore
)、对应的高分辨率图像(val_hr
)以及模型生成的超分辨率图像(sr
)。
data_utils.py
数据集加载
train.py有一行导包代码
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
所以data_utils.py有一些有关数据集加载的函数以供train.py使用。
都是从库里导的包
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
对图像的一些处理
# 判断文件名是否为常见图像文件格式(不区分大小写)
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
# 根据裁剪尺寸和放大因子计算有效的裁剪尺寸,确保能被放大因子整除
def calculate_valid_crop_size(crop_size, upscale_factor):
return crop_size - (crop_size % upscale_factor)
# 定义高分辨率训练图像的变换操作:随机裁剪后转换为张量
def train_hr_transform(crop_size):
return Compose([
RandomCrop(crop_size),# 随机裁剪图像,裁剪尺寸为传入的crop_size
ToTensor(),# 将裁剪后的图像转换为张量格式
])
# 定义低分辨率训练图像的变换操作:先转换为PIL图像,缩放后再转换为张量
def train_lr_transform(crop_size, upscale_factor):
return Compose([
ToPILImage(),# 将高分辨率图像张量转换为PIL图像对象
Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),# 将图像缩小为原始尺寸除以放大因子,使用双三次插值
ToTensor()# 将缩放后的低分辨率图像转换为张量格式
])
# 定义用于显示图像的变换操作:调整大小、中心裁剪后转换为张量
def display_transform():
return Compose([
ToPILImage(), # 将图像转换为PIL图像对象
Resize(400), # 将图像大小调整为400(可能是为了统一显示尺寸)
CenterCrop(400), # 进行中心裁剪,确保关键部分完整
ToTensor() # 将处理后的图像转换为张量格式
])
从文件夹中加载和预处理训练图像数据
class TrainDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, crop_size, upscale_factor):
super(TrainDatasetFromFolder, self).__init__()
self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
self.hr_transform = train_hr_transform(crop_size)
self.lr_transform = train_lr_transform(crop_size, upscale_factor)
def __getitem__(self, index):
hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
lr_image = self.lr_transform(hr_image)
return lr_image, hr_image
def __len__(self):
return len(self.image_filenames)
从文件夹中加载和预处理验证图像数据
class ValDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, upscale_factor):
super(ValDatasetFromFolder, self).__init__()
self.upscale_factor = upscale_factor
self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
def __getitem__(self, index):
hr_image = Image.open(self.image_filenames[index])
w, h = hr_image.size
crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
hr_image = CenterCrop(crop_size)(hr_image)
lr_image = lr_scale(hr_image)
hr_restore_img = hr_scale(lr_image)
return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
def __len__(self):
return len(self.image_filenames)
从文件夹中加载和预处理测试图像数据
class TestDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, upscale_factor):
super(TestDatasetFromFolder, self).__init__()
self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
self.upscale_factor = upscale_factor
self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]
def __getitem__(self, index):
image_name = self.lr_filenames[index].split('/')[-1]
lr_image = Image.open(self.lr_filenames[index])
w, h