basic_sr介绍

news2025/1/22 15:46:32

文章目录

  • pytorch基础知识和basicSR中用到的语法
    • 1.Sampler类与4种采样方式
    • 2.python dict的get方法使用
    • 3.prefetch_dataloader.py
    • 4. pytorch 并行和分布式训练
      • 4.1 选择要使用的cuda
      • 4.2 DataParallel使用方法
        • 常规使用方法
        • 保存和载入
      • 4.3 DistributedDataParallel
    • 5.wangdb 入门
      • 5.1 sign up(https://wandb.ai/site)
      • 5.2 安装和login
      • 5.3 demo
    • 5.model and train
      • 5.1 create model
      • 5.2 opt中设置
      • 5.2 SRModel 类

pytorch基础知识和basicSR中用到的语法

1.Sampler类与4种采样方式

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
pytorch源码阅读(三)Sampler类与4种采样方式

下面代码是自定义的一个采样器:
ratio控制扩充数据集的倍数
num_replicas是进程数,一般是world_size
rank: 当前进程的rank

其实目的就是把数据集的索引划分为num_replicas组,供每个进程(process) 处理
至于ratio,是为了使每个epoch训练的数据增多,for saving time when restart the dataloader after each epoch

import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler
    Support enlarging the dataset for iteration-based training, for saving
    time when restart the dataloader after each epoch

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

测试一下:

import numpy as np
if __name__ == "__main__":
    data = np.arange(20).tolist()
    en_sample = EnlargedSampler(data, 2, 0)
    en_sample.set_epoch(1)
    for i in en_sample:
        print(i)
    print('\n------------------\n')
    en_sample = EnlargedSampler(data, 2, 1)
    en_sample.set_epoch(1) # 设置为同一个epoch .  rank=0或者1时生成的index是互补的

    # 或者不用设置,默认为0即可。
    for i in en_sample:
        print(i)

结果:
在这里插入图片描述

2.python dict的get方法使用

在这里插入图片描述

3.prefetch_dataloader.py

在这里插入图片描述

载入本批数据的时候,预先载入下一批数据。主要看next函数

import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream) # 等待下一批处理完毕
        batch = self.batch # 赋值
        self.preload()     # 预先载入下一批
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()

4. pytorch 并行和分布式训练

4.1 选择要使用的cuda

当我们的服务器上有多个GPU,我们应该指明我们使用的GPU是哪一块,如果我们不设置的话,tensor.cuda()方法会默认将tensor保存到第一块GPU上,等价于tensor.cuda(0),这将会导致爆出out of memory的错误。我们可以通过以下两种方式继续设置。

  1. 在文件最开始部分
    #设置在文件最开始部分
    import os
    os.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2" # 设置默认的显卡
    
  2. 在命令行运行的时候设置
     CUDA_VISBLE_DEVICE=0,1 python train.py # 使用0,1两块GPU
    

4.2 DataParallel使用方法

常规使用方法
   model = UNetSeeInDark()
   model._initialize_weights()

   gpus = [0123]
   model = nn.DataParallel(model, device_ids=gpus)
   device = torch.device('cuda:0')
   model = model.to(device)
   # 如果不使用并行,只需要注释掉 model = nn.DataParallel(model, device_ids=gpus)
   # 如果要更改要使用的gpu, 更改gpus,和device中的torch.device('cuda:0')中的number即可
保存和载入

保存可以使用

# 因为model被DP wrap了,得先取出模型
save_model_path = os.path.join(save_model_dir, f'checkpoint_{epoch:05d}.pth')
# torch.save(model.state_dict(), save_model_path)
torch.save(model.module.state_dict(), save_model_path)

然后载入模型:

model_copy.load_state_dict(torch.load(m_path, map_location=device))

如果没有提出model.module进行保存
在载入的时候可能需要如下方式:

checkpoint = torch.load(m_path)
model_copy.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint.items()})

4.3 DistributedDataParallel

首先DataParallel是单进程多线程的方法,并且仅能工作在单机多卡的情况。而DistributedDataParallel方法是多进程,多线程的,并且适用与单机多卡和多机多卡的情况。即使在在单机多卡的情况下DistributedDataParallell也比DataParallel的速度更快。
目前还未深入理解:
深入理解Pytorch中的分布式训练
pytorch分布式训练
Pytorch中多GPU并行计算教程
PyTorch 并行训练极简 Demo

5.wangdb 入门

直接参看:https://docs.wandb.ai/quickstart
最详细的介绍和入门

5.1 sign up(https://wandb.ai/site)

在这里插入图片描述

5.2 安装和login

pip install wandb
wandb.login() 然后复制API key

5.3 demo

import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata
    config={
        "learning_rate": 0.02,
        "architecture": "CNN",
        "dataset": "CIFAR-100",
        "epochs": 10,
    }
)

# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2 ** -epoch + random.random() / epoch + offset

    # log metrics to wandb
    wandb.log({"acc": acc, "loss": loss})

# [optional] finish the wandb run, necessary in notebooks5b1bb8a27da51a7375b4b52c24a82fe1807877f1
wandb.finish()

运行之后:

wandb: Currently logged in as: wangty537. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.10
wandb: Run data is saved locally in D:\code\denoise\noise-synthesis-main\wandb\run-20230921_103737-j9ezjcqo
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run wobbly-jazz-1
wandb:  View project at https://wandb.ai/wangty537/my-awesome-project
wandb:  View run at https://wandb.ai/wangty537/my-awesome-project/runs/j9ezjcqo
wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb:  acc ▁▆▇██▇▇█
wandb: loss █▄█▁▅▁▄▁
wandb: 
wandb: Run summary:
wandb:  acc 0.88762
wandb: loss 0.12236
wandb: 
wandb:  View run wobbly-jazz-1 at: https://wandb.ai/wangty537/my-awesome-project/runs/j9ezjcqo
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20230921_103737-j9ezjcqo\logs

然后可以在 https://wandb.ai/home 查看相关信息
在这里插入图片描述

https://docs.wandb.ai/quickstart 还介绍了更多高阶应用。

5.model and train

5.1 create model

利用注册机制

# create model
model = build_model(opt)
def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.
    """
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model

5.2 opt中设置

model_type: SRModel
scale: 2

5.2 SRModel 类

BaseModel是基类

@MODEL_REGISTRY.register()
class SRModel(BaseModel):
    xxx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1106597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5秒用Java写一个快速排序算法?这个我在行

快速排序是一种非常高效的排序算法,由英国计算机科学家霍尔在1960年提出。它的基本思想是选择一个基准元素将待排序数组分成两部分,其中一部分的所有元素都比基准元素小,另一部分的所有元素都比基准元素大,然后对这两部分再分别进…

雷军在微博发文:小米澎湃 OS(Xiaomi HyperOS)正式版已完成封包

本心、输入输出、结果 文章目录 雷军在微博发文:小米澎湃 OS(Xiaomi HyperOS)正式版已完成封包前言搭载 小米澎湃 OS(Xiaomi HyperOS)的小米 14回顾 MIUI小米澎湃 OS(Xiaomi HyperOS) 相关跳转小…

spring boot MongoDB实战

项目搭建 <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 …

中文编程语言开发工具开发的软件案例:定制开发扫码识别位置程序适用于车间物品摆放管理

中文编程语言开发工具开发的软件案例&#xff1a;定制开发扫码识别位置程序适用于车间物品摆放管理 中文编程语言开发工具开发的软件案例&#xff1a;定制开发扫码识别位置程序适用于车间物品摆放管理&#xff0c; 中文编程系统化教程&#xff0c;不需英语基础。学习链接 htt…

string类模拟实现(c++)(学习笔记)

string 1.构造函数1.1 不带参构造1.2 带参数的构造函数1.3 合并两个构造函数。 2. 析构函数3.拷贝构造函数4. 赋值运算符重载5. size()/capacity()6. 解引用[]8.iterator迭代器7.Print()8.> 8. push_back()&append()8.1 reserve() 9. 10.insert()10.1 任意位置插入一个字…

在线客服软件的市场需求及前景如何?

随着互联网的不断发展&#xff0c;越来越多的企业开始意识到在线客服软件在客户服务中的重要作用。现在&#xff0c;各种形态的在线客服软件涌现出来&#xff0c;如何选择适合自己公司的线上客服软件成为了企业面临的一个挑战。本文将从市场需求和前景方面分析在线客服软件行业…

【Loopback Detection 环回检测以及原理解读】

Loopback Detection简介 Loopback Detection&#xff08;环回检测&#xff09;通过周期性发送环回检测报文来检测设备下挂网络是否存在环路。 网络中的环路会导致设备对广播、组播以及未知单播等报文进行重复发送&#xff0c;造成网络资源浪费甚至网络瘫痪。为了能够及时发现…

MS4553S双向电平转换器可pin对pin兼容TXB0102/TXS0102

MS4553S是一款双向电平转换器&#xff0c;可以用作混合电压的数字信号系统中。其使用两个独立构架的电源供电&#xff0c;A端供电电压范围是1.65V到5.5V&#xff0c;B端供电电压范围是2.3V到5.5V。可用在电压为1.8V、2.5V、3.3V和5V的信号转换系统中。当OE端为低电平时&#xf…

大学生毕业嵌入式和JAVA哪条未来更有前景?

今日话题&#xff0c;大学生毕业后选择嵌入式和Java两个岗位哪个更具前景&#xff1f;答案因个人情况而异。通常来说&#xff0c;对于零基础转行的同学&#xff0c;学习Java可能会是一个更广泛选择的建议&#xff0c;因为Java岗位更多&#xff0c;且不需要涉及硬件知识。然而&a…

Live800:一个优秀的客服应具备哪些技能?

一个优秀的客服应该具备哪些技能&#xff1f;这是每个企业在招聘和培训客服人员时都需要考虑的问题。一名优秀的客服不仅需要善于沟通&#xff0c;还需要具备专业知识、灵活应变、耐心细致等多方面的能力。在这篇文章中&#xff0c;我们将从多个方面探讨一个优秀客服应该具备什…

MATLAB——极限学习机参考程序

欢迎关注“电击小子程高兴的MATLAB小屋” %% I. 清空环境变量 clear all clc %% II. 训练集/测试集产生 %% % 1. 导入数据 load iris_data.mat %% % 2. 随机产生训练集和测试集 P_train []; T_train []; P_test []; T_test []; for i 1:3 temp_input features((i-…

Bug小能手系列(python)_12: 使用mne库读取.set文件报错 TypeError: ‘int‘ object is not iterable

使用mne库读取.set文件报错 0 引言1. 报错原因2. 推荐解决方案3. 总结 0 引言 在使用mne库读取.set文件&#xff0c;然后对文件进行处理。在运行过程中出现报错&#xff1a;TypeError: int object is not iterable 其中&#xff0c;代码库包的版本这里主要介绍mne的版本&…

MYSQL 连接

高频 SQL 50 题&#xff08;基础版&#xff09; - 学习计划 - 力扣&#xff08;LeetCode&#xff09;全球极客挚爱的技术成长平台 1378. 使用唯一标识码替换员工ID SELECT COALESCE(unique_id, NULL) AS unique_id,name FROM Employees LEFT JOIN EmployeeUNI ON Employees.…

7.MidBook项目经验之阿里OSS,微信支付(退款),定时任务,图表数据处理

1.阿里云实名认证 阿里云对象存储oss,标准高频率访问, 低频访问存储,归档存储(根本不经常访问) 冗余存储(备份) 读写权限(所有人还是自己访问) Component public class ConstantOssPropertiesUtils implements InitializingBean {不用注入,由spring创建bean使用流 MultipartFil…

SpringBoot整合RabbitMQ并实现消息发送与接收

系列文章目录 解析JSON格式参数 & 修改对象的key VUE整合Echarts实现简单的数据可视化 Java中运用BigDecimal对字符串的数值进行加减乘除等操作 List&#xff1c;HashMap&#xff1c;String,String&#xff1e;&#xff1e;实现自定义字符串排序&#xff08;key排序、Valu…

Node介绍(nvm安装和npm常用命令)

文章目录 Node 介绍为什么要学习 Node.jsNode.js 是什么Node能做什么nvm常用的 nvm 命令npm 快捷键npm 常用命令切换 npm 下包镜像源常用命令 Node 介绍 为什么要学习 Node.js 企业需求 具有服务端开发经验更改front-endback-end全栈开发工程师基本的网站开发能力 服务端前端…

TP4067带电池反接保护500MA线性锂电池充电芯片

概述 TP4067 是一款完整的单节锂电池充电器&#xff0c;带电池正负极反接保护输入电源正负极反接保护的单芯片&#xff0c;兼容大小3mA-600mA充电电流。采用涓流、恒流、恒压控制&#xff0c;SOT23-6封装与较少的外部元件数目使得TP4067成为便携式应用的理想选择.TP4067可以适…

ims-go项目搭建

通过集成开发工具Goland创建项目 整合Gin框架&#xff0c;在终端中输入如下命令&#xff1a; go get -u github.com/gin-gonic/gin 整合Gorm&#xff0c;安装命令如下&#xff1a; go get -u gorm.io/gorm 安装sqlserver驱动&#xff0c;安装命令如下&#xff1a; go get -u…

docker报错问题解决:Error Invalid or corrupt jarfile app.jar

文章目录 1.问题描述2.问题分析3.问题解决 1.问题描述 此时处在 /home/ubuntu/app 目录下&#xff0c;并且在该目录下有一个 jenkins-0.0.1-SNAPSHOT.jar。 我在 /home/ubuntu/app 目录下执行了 docker 容器运行命令&#xff1a; # 映射 8859 端口 # 容器名为 jenkins-demo #…

常用Python自动化测试框架有哪些?优缺点对比

随着技术的进步和自动化技术的出现&#xff0c;市面上出现了一些自动化测试框架。只需要进行一些适用性和效率参数的调整&#xff0c;这些自动化测试框架就能够开箱即用&#xff0c;大大节省了测试时间。而且由于这些框架被广泛使用&#xff0c;他们具有很好的健壮性&#xff0…