深度学习之超分辨率算法——FRCNN

news2024/12/20 21:55:46

– 对之前SRCNN算法的改进

    1. 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。
    2. 改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。
    3. 共享其中的映射层,如果需要训练不同上采样倍率的模型,只需要修改最后的反卷积层大小,就可以训练出不同尺寸的图片。
  • 模型实现
  • 在这里插入图片描述
import math
from torch import nn


class FSRCNN(nn.Module):
    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),
            nn.PReLU(d)
        )
        # 添加入多个激活层和小卷积核
        self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])
        self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
        self.mid_part = nn.Sequential(*self.mid_part)
        # 最后输出
        self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,
                                            output_padding=scale_factor-1)

        self._initialize_weights()

    def _initialize_weights(self):
        # 初始化
        for m in self.first_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x

以上代码中,如起初所说,将SRCNN中给的输出修改为转置卷积,并且在中间添加了多个11卷积核和多个线性激活层。且应用了权重初始化,解决协变量偏移问题。
备注:1
1卷积核虽然在通道的像素层面上,针对一个像素进行卷积,貌似没有什么作用,但是卷积神经网络的特性,我们在利用多个卷积核对特征图进行扫描时,单个卷积核扫描后的为sum©,那么就是尽管在像素层面上无用,但是在通道层面上进行了融合,并且进一步加深了层数,使网络层数增加,网络能力增强。

  • 上代码
  • train.py

训练脚本

import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # 训练文件
    parser.add_argument('--train-file', type=str,help="the dir of train data",default="./Train/91-image_x4.h5")
    # 测试集文件
    parser.add_argument('--eval-file', type=str,help="thr dir of test data ",default="./Test/Set5_x4.h5")
    # 输出的文件夹
    parser.add_argument('--outputs-dir',help="the output dir", type=str,default="./outputs")
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = FSRCNN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.mid_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

test.py 测试脚本

import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = FSRCNN(scale_factor=args.scale).to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    # 保存图片
    output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))

datasets.py

数据集的读取

import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

工具文件utils.py

  • 主要用来测试psnr指数,图片的格式转换(悄悄说一句,opencv有直接实现~~~)
import torch
import numpy as np


def calc_patch_size(func):
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


先跑他个几十轮~
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2262921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VSCode 搭建Python编程环境 2024新版图文安装教程(Python环境搭建+VSCode安装+运行测试+背景图设置)

名人说:一点浩然气,千里快哉风。—— 苏轼《水调歌头》 创作者:Code_流苏(CSDN) 目录 一、Python环境安装二、VScode下载及安装三、VSCode配置Python环境四、运行测试五、背景图设置 很高兴你打开了这篇博客,更多详细的安装教程&…

使用Docker启用MySQL8.0.11

目录 一、Docker减小镜像大小的方式 1、基础镜像选择 2、减少镜像层数 3、清理无用文件和缓存 4、优化文件复制(COPY和ADD指令) 二、Docker镜像多阶段构建 1、什么是dockers镜像多阶段构建 1.1 概念介绍 1.2 构建过程和优势 2、怎样在Dockerfil…

Windows安全中心(病毒和威胁防护)的注册

文章目录 Windows安全中心(病毒和威胁防护)的注册1. 简介2. WSC注册初探3. WSC注册原理分析4. 关于AMPPL5. 参考 Windows安全中心(病毒和威胁防护)的注册 本文我们来分析一下Windows安全中心(Windows Security Center…

Hive其一,简介、体系结构和内嵌模式、本地模式的安装

目录 一、Hive简介 二、体系结构 三、安装 1、内嵌模式 2、测试内嵌模式 3、本地模式--最常使用的模式 一、Hive简介 Hive 是一个框架,可以通过编写sql的方式,自动的编译为MR任务的一个工具。 在这个世界上,会写SQL的人远远大于会写ja…

时空AI赋能低空智能科技创新

随着人工智能技术的不断进步,时空人工智能(Spatio-Temporal AI,简称时空AI)正在逐渐成为推动低空经济发展的新引擎。时空AI结合了地理空间智能、城市空间智能和时空大数据智能,为低空智能科技创新提供了强大的数据支持…

java 通过jdbc连接sql2000方法

1、java通过jdbc连接sql2000 需要到三个jar包:msbase.jar mssqlserver.jar msutil.jar 下载地址:https://download.csdn.net/download/sunfor/90145580 2、将三个jar包解压到程序中的LIB下: 导入方法: ①在当前目录下&#xff…

web实验二

web实验二 2024.12.19 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>青岛理工大学</title>&l…

纯前端实现更新检测

通过判断打包后的html文件中的js入口是否发生变化&#xff0c;进而实现前端的代码更新 为了使打包后的文件带有hash值&#xff0c;需要对vite打包进行配置 import { defineConfig } from vite; import vue from vitejs/plugin-vue; import { resolve } from path; import AutoI…

云原生周刊:Kubernetes v1.32 正式发布

开源项目推荐 Helmper Helmper 简化了将 Helm Charts导入OCI&#xff08;开放容器倡议&#xff09;注册表的过程&#xff0c;并支持可选的漏洞修复功能。它确保您的 Helm Charts不仅安全存储&#xff0c;还能及时应用最新的安全修复。该工具完全兼容 OCI 标准&#xff0c;能够…

【游戏中orika完成一个Entity的复制及其Entity异步落地的实现】 1.ctrl+shift+a是飞书下的截图 2.落地实现

一、orika工具使用 1)工具类 package com.xinyue.game.utils;import ma.glasnost.orika.MapperFactory; import ma.glasnost.orika.impl.DefaultMapperFactory;/*** author 王广帅* since 2022/2/8 22:37*/ public class XyBeanCopyUtil {private static MapperFactory mappe…

如何在繁忙的生活中找到自己的节奏?

目录 一、理解生活节奏的重要性 二、分析当前生活节奏 1. 时间分配 2. 心理状态 3. 身体状况 4. 生活习惯 1. 快慢适中 2. 张弛结合 3. 与目标相符 三、掌握调整生活节奏的策略 1. 设定优先级 2. 合理规划时间 3. 学会拒绝与取舍 4. 保持健康的生活方式 5. 留出…

CORDIC 算法实现 _FPGA

注&#xff1a;本文为 “CORDIC 算法” 相关文章合辑。 未整理去重。 如有内容异常&#xff0c;请看原文。 Cordic 算法的原理介绍 乐富道 2014-01-28 23:05 Cordic 算法知道正弦和余弦值&#xff0c;求反正切&#xff0c;即角度。 采用用不断的旋转求出对应的正弦余弦值&…

鸿蒙学习笔记:用户登录界面

文章目录 1. 提出任务2. 完成任务2.1 创建鸿蒙项目2.2 准备图片资源2.3 编写首页代码2.4 启动应用 3. 实战小结 1. 提出任务 本次任务聚焦于运用 ArkUI 打造用户登录界面。需呈现特定元素&#xff1a;一张图片增添视觉感&#xff0c;两个分别用于账号与密码的文本输入框&#…

着色器 (三)

今天&#xff0c;是我们介绍opengl着色器最后一章&#xff0c;着色器(Shader)是运行在GPU上的小程序。这些小程序为图形渲染管线的某个特定部分而运行。从基本意义上来说&#xff0c;着色器只是一种把输入转化为输出的程序。着色器也是一种非常独立的程序&#xff0c;因为它们之…

leetcode:3285. 找到稳定山的下标(python3解法)

难度&#xff1a;简单 有 n 座山排成一列&#xff0c;每座山都有一个高度。给你一个整数数组 height &#xff0c;其中 height[i] 表示第 i 座山的高度&#xff0c;再给你一个整数 threshold 。 对于下标不为 0 的一座山&#xff0c;如果它左侧相邻的山的高度 严格大于 thresho…

深度学习之超分辨率算法——SRGAN

更新版本 实现了生成对抗网络在超分辨率上的使用 更新了损失函数&#xff0c;增加先验函数 SRresnet实现 import torch import torchvision from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size3, stride1, n_inchannels64):super(ConvBlock…

集成方案 | Docusign + 金蝶云,实现合同签署流程自动化!

本文将详细介绍 Docusign 与金蝶云的集成步骤及其效果&#xff0c;并通过实际应用场景来展示 Docusign 的强大集成能力&#xff0c;以证明 Docusign 集成功能的高效性和实用性。 在当今商业环境中&#xff0c;流程的无缝整合与数据的实时性对于企业的成功至关重要。金蝶云&…

数据结构----链表头插中插尾插

一、链表的基本概念 链表是一种线性数据结构&#xff0c;它由一系列节点组成。每个节点包含两个主要部分&#xff1a; 数据域&#xff1a;用于存储数据元素&#xff0c;可以是任何类型的数据&#xff0c;如整数、字符、结构体等。指针域&#xff1a;用于存储下一个节点&#…

Service Discovery in Microservices 客户端/服务端服务发现

原文链接 Client Side Service Discovery in Microservices - GeeksforGeeks 原文链接 Server Side Service Discovery in Microservices - GeeksforGeeks 目录 服务发现介绍 Server-Side 服务发现 实例&#xff1a; Client-Side 服务发现 实例&#xff1a; 服务发现介绍…

Git连接远程仓库(超详细)

目录 一、Gitee 远程仓库连接 1. HTTPS 方式 2. SSH公钥方式 &#xff08;1&#xff09;账户公钥 &#xff08;2&#xff09;仓库公钥 仓库的 SSH Key 和账户 SSH Key 的区别&#xff1f;​ 二、GitHub远程仓库连接 1. HTTPS方式 2.SSH公钥方式 本文将介绍如何通过 H…