SegFormer学习笔记(4)train续2

news2024/11/17 17:42:55

这次关注一下最关键的东西:用什么网络,用什么数据,预训练数据在哪里呢?

为了方便,重新贴一下 train.py

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from semseg.models import *
from semseg.datasets import * 
from semseg.augmentations import get_train_augmentation, get_val_augmentation
from semseg.losses import get_loss
from semseg.schedulers import get_scheduler
from semseg.optimizers import get_optimizer
from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from val import evaluate


def main(cfg, gpu, save_dir):
    start = time.time()
    best_mIoU = 0.0
    num_workers = mp.cpu_count()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #device = torch.device(cfg['DEVICE'])
    train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
    dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
    loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
    epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']
    
    traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])
    valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])

    trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform)
    valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform)
    
    model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)
    model.init_pretrained(model_cfg['PRETRAINED'])
    model = model.to(device)

    if train_cfg['DDP']: 
        sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
        model = DDP(model, device_ids=[gpu])
    else:
        sampler = RandomSampler(trainset)
    
    trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=sampler)
    valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)

    iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE']
    # class_weights = trainset.class_weights.to(device)
    loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)
    optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])
    scheduler = get_scheduler(sched_cfg['NAME'], optimizer, epochs * iters_per_epoch, sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])
    scaler = GradScaler(enabled=train_cfg['AMP'])
    writer = SummaryWriter(str(save_dir / 'logs'))

    for epoch in range(epochs):
        model.train()
        if train_cfg['DDP']: sampler.set_epoch(epoch)

        train_loss = 0.0
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        for iter, (img, lbl) in pbar:
            optimizer.zero_grad(set_to_none=True)

            img = img.to(device)
            lbl = lbl.to(device)
            
            with autocast(enabled=train_cfg['AMP']):
                logits = model(img)
                loss = loss_fn(logits, lbl)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            torch.cuda.synchronize()

            lr = scheduler.get_lr()
            lr = sum(lr) / len(lr)
            train_loss += loss.item()

            pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}")
        
        train_loss /= iter+1
        writer.add_scalar('train/loss', train_loss, epoch)
        torch.cuda.empty_cache()

        if (epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs:
            miou = evaluate(model, valloader, device)[-1]
            writer.add_scalar('val/mIoU', miou, epoch)

            if miou > best_mIoU:
                best_mIoU = miou
                torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}.pth")
            print(f"Current mIoU: {miou} Best mIoU: {best_mIoU}")

    writer.close()
    pbar.close()
    end = time.gmtime(time.time() - start)

    table = [
        ['Best mIoU', f"{best_mIoU:.2f}"],
        ['Total Training Time', time.strftime("%H:%M:%S", end)]
    ]
    print(tabulate(table, numalign='right'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='configs/custom.yaml', help='Configuration file to use')
    args = parser.parse_args()

    with open(args.cfg) as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)

    fix_seeds(3407)
    setup_cudnn()
    gpu = setup_ddp()
    save_dir = Path(cfg['SAVE_DIR'])
    save_dir.mkdir(exist_ok=True)
    main(cfg, gpu, save_dir)
    cleanup_ddp()

一、model_cfg

上面第32行,本质说的是

model_cfg = cfg['MODEL']

你看在custom.yaml中,

MODEL:                                    
  NAME          : SegFormer                                           # name of the model you are using
  BACKBONE      : MiT-B2                                                 # model variant
  PRETRAINED    : 'checkpoints/backbones/mit/mit_b2.pth'              # backbone model's weight

第42行,重量级代码来了

model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)

model_cfg['NAME']其实就是'SegFormer'

这就需要细心地你,注意第15、16行如下:

from semseg.models import *
from semseg.datasets import * 

那么,第42行,就是要实现SegFormer类,并且BACKBONE 为 MiT-B2

第43行,说的是预训练模型

model.init_pretrained(model_cfg['PRETRAINED'])

你会发现,init_pretrained是个多态的,在这里,由于model已经是SegFormer类,而在SegFormer中,继承了BaseModel,所以,执行的是BaseModel的init_pretrained.

所以,43行执行的是啥?

model.init_pretrained(model_cfg['PRETRAINED'])

预训练模型来自model_cfg['PRETRAINED']

对于我来说,

PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth' # backbone model's weight

细心的你,

BACKBONE : MiT-B2 # model variant

还没用上呢。

再看一遍segformer.py:

import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead


class SegFormer(BaseModel):
    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
        super().__init__(backbone, num_classes)
        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 768, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: Tensor) -> Tensor:
        y = self.backbone(x)
        y = self.decode_head(y)   # 4x reduction in image size
        y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False)    # to original image shape
        return y


if __name__ == '__main__':
    model = SegFormer('MiT-B0')
    # model.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu'))
    x = torch.zeros(1, 3, 512, 512)
    y = model(x)
    print(y.shape)

上面第11行,就用上了backbone。

二、model_cfg总结

MODEL:                                    
  NAME          : SegFormer                                           # name of the model you are using
  BACKBONE      : MiT-B2                                                 # model variant
  PRETRAINED    : 'checkpoints/backbones/mit/mit_b2.pth'              # backbone model's weight

NAME 决定了采用哪个类。

BACKBONE 决定了用哪个backbone

PRETRAINED 决定了预编译文件

他们之间是有约束关系的,不是随便乱选。

三、train_cfg

TRAIN:
  IMAGE_SIZE    : [512, 512]    # training image size in (h, w)
  BATCH_SIZE    : 2               # batch size used to train
  EPOCHS        : 6             # number of epochs to train
  EVAL_INTERVAL : 2             # evaluation interval during training
  AMP           : false           # use AMP in training
  DDP           : false           # use DDP training

四、dataset_cfg

DATASET:
  NAME          : HELEN                                          # dataset name to be trained with (camvid, cityscapes, ade20k)
  ROOT          : 'data/SmithCVPR2013_dataset_resized'                                      # dataset root path
  IGNORE_LABEL  : 255

这里有意思不?

NAME : HELEN

怎么解释?

五、eval_cfg

EVAL:
  MODEL_PATH    : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth'     # trained model file path
  IMAGE_SIZE    : [1024, 1024]                            # evaluation image size in (h, w)                       
  MSF: 
    ENABLE      : false                                   # multi-scale and flip evaluation  
    FLIP        : true                                    # use flip in evaluation  
    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in MSF evaluation    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/164132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM笔记——根据黑马jvm课程课件+自己总结

JVM一、内存结构1、程序计数器(PC Register)2、虚拟机栈(JVM Stacks)3、本地方法栈(Native Method Stacks)4、堆(Heap)5、方法区(Method Area)6、直接内存二、…

【爬虫】第七部分 scrapy

【爬虫】第七部分 scrapy 文章目录【爬虫】第七部分 scrapy7. scrapy7.1 基本使用7.2 项目的文件结构7.3 response的方法和属性7.4 小案例7.5 scrapy 工作原理7.6 管道的使用7.7 多管道下载7.8 下载分页类型和get请求的使用7.9 下载多层级类型7.10 post请求的使用总结7. scrapy…

清华大学出版——C语言从入门到精通(第4版)

《C语言从入门到精通(第4版)》是清华大学出版社出版的图书,该书从初学者的角度出发,以通俗易懂的语言,丰富多彩的实例,详细介绍了使用C语言进行程序开发需要掌握的各方面知识。《C语言从入门到精通&#xf…

YOLO v8详解

回顾一下YOLOv5 Backbone:CSPDarkNet结构,主要结构思想的体现在C3模块,这里也是梯度分流的主要思想所在的地方;PAN-FPN:双流的FPN,但是量化还是有些需要图优化才可以达到最优的性能,比如cat前后…

VSCode 配置Go环境,弹出警告“golps”等插件要求下载但下载时超时、失去连接等 解决方案

1. 背景: 下载完GO环境和VSCode的GO配套插件后,试图运行hello world程序,此时VSCode弹出警告: 提示有几个go的工具没有下载,于是我点击install 下载: 》下载时报错,一般是出现超时timeout错误…

57 mac 中 SIGINFO 信号, jdk8 支持, 但是 jdk9 不支持?

前言 问题来自于文章 shell脚本 后台启动 程序1 “tail -f log“, ctrl c 导致程序1中断 中的测试用例 Test07Signal2ParentProcess, 可以看到 我当时标记了一个 "todo, not work in hostpostVM9" 然后 问题是这样的, 我同一台机器, 然后 jdk8 带上 SIGINFO 去执行…

【已解决】右键以某应用打开xx文件时,没有“默认”选项怎么办

问题解决方案简单来说详细操作解释问题 右键以某应用打开xx文件时,没有“默认”选项 解决方案 简单来说 在注册表:计算机\HKEY_CURRENT_USER\Software\Microsoft\Windows\CurrentVersion\Explorer\FileExts\找到要打开的文件后缀名,删除…

mybatis plus基本使用初体验01

我们都知道MyBatis是目前比较常用的持久层框架;这个框架的使用也是很简单的,我们在使用的时候,只需要关注mapper的接口层和对应的xml文件即可。 但是MyBatis作为一个半自动框架,是需要我们自己手动编写sql语句的,对于…

Linux系统软件安装

在Linux上部署各类软件MySQL数据库管理系统安装部署简介注意MySQL5.7版本在CentOS系统安装安装配置MySQL8.0版本在CentOS系统安装安装配置MySQL5.7版本在Ubuntu(WSL环境)系统安装安装MySQL8.0版本在Ubuntu(WSL环境)系统安装安装To…

Web服务统一身份认证协议设计与实现

单点登录(SSO)是目前比较流行的企业业务整合的解决方案之一,它的机制是在企业网络用户访问企业网站时作一次身份认证,随后就可以对所有被授权的网络资源进行无缝的访问,而不需要多次输入自己的认证信息.Web服务具有松散耦合、语言中立、平台无关性、开放性的特性,通过对集中单点…

Qt扫盲-Qt 属性系统记录

Qt 属性系统记录一、概述二、属性声明三、通过元对象系统读写属性四、简单例子五、动态属性六、对一个类添加额外的属性一、概述 Qt 提供了一个复杂的属性系统,类似于一些编译器供应商提供的系统。然而,作为一个独立于编译器和平台的库,Qt并…

Java基础07——集合

Java基础07——集合一、集合和数组的对比二、ArrayList成员方法三、集合练习1. 添加数字并遍历2. 添加学生对象并遍历学生类测试类输出结果3. 添加用户对象并判断是否存在用户类测试类输出结果4. 添加手机对象并返回要求的数据(返回多个数据)手机类测试类…

【算法】Day06

努力经营当下,直至未来明朗! 文章目录1. BST二叉搜索树的后序遍历序列2. 二叉树中和为某一值的路径(二)[回溯法]3. 字符串的排列 [全排列问题]4. 最小的K个数 [topK问题]普通小孩也要热爱生活! 1. BST二叉搜索树的后序…

IF:6+ 综合分析揭示了一种炎症性癌症相关的成纤维细胞亚型在预测膀胱癌患者的预后和免疫治疗反应方面具有重要意义...

桓峰基因的教程不但教您怎么使用,还会定期分析一些相关的文章,学会教程只是基础,但是如果把分析结果整合到文章里面才是目的,觉得我们这些教程还不错,并且您按照我们的教程分析出来不错的结果发了文章记得告知我们&…

Linux 中断子系统(七):注册中断

Linux 注册中断的 API request_irq():不使用中断线程化request_threaded_irq():使用中断线程化中断线程化 为什么需要将中断下半部处理线程化,原因如下: 中断具有最高优先级,有中断发生时,会抢占进程,导致实时任务不能及时处理。中断上下文总是可以抢占进程上下文,这…

【PyTorch】教程:学习基础知识-(3) Datasets-DataLoader

Dataset & DataLoader PyTorch 提供了两个数据处理的基本方法:torch.utils.data.DataLoader torch.utils.data.Dataset 允许使用预加载的数据集以及自己的数据。 Dataset 存储样本及其对应的标签, DataLoader 在 Dataset 基础上封装了一个可迭代的对…

Python文本颜色设置

Python文本颜色设置实现过程:书写格式:数值表示的参数含义:常见开头格式:实例:实现过程: 终端的字符颜色是用转义序列控制的,是文本模式下的系统显示功能,和具体的语言无关。 转义序…

Acwing4699. 如此编码

某次测验后,顿顿老师在黑板上留下了一串数字 23333 便飘然而去。 凝望着这个神秘数字,小 P 同学不禁陷入了沉思…… 已知某次测验包含 nn 道单项选择题,其中第 i 题(1≤i≤n)有 ai 个选项,正确选项为 bi&…

CAS And Atomic

CAS(Compare And Swap 比较并交换),通常指的是这样一种原子操作:针对一个变量,首先比较它的内存值与某个期望值是否相同,如果相同,就给它赋一个新值,底层是能保证cas是原子性的CAS的应用 在Java 中,CAS 操作…

Android开发-AS学习(三)(布局)

相关文章链接:Android开发-AS学习(一)(控件)Android开发-AS学习(二)(控件)Android开发应用案例——简易计算器(附完整源码)二、布局2.1 Linearyout常见属性说…