利用MMSegmentation微调Mask2Former模型

news2024/11/23 13:14:17

前言

  • 本文介绍了专用于语义分隔模型的pythonmmsegmentationgithub项目地址,运行环境为Kaggle notebookGPUP100
  • 针对环境配置、预训练模型推理、在西瓜数据集上微调新sota模型mask2former模型,数据说明
  • 由于西瓜数据集较小,我们最后在组织病理切片肾小球数据集上微调了mask2former模型,数据说明
  • 该教程有部分参考github项目MMSegmentation_Tutorials,项目地址

环境配置

  • 跑通代码需要openmimmmsegmentationmmenginemmdetectionmmcv环境,mmcv环境在kaggle配置比较麻烦,需要预配置包,这里我将所有预配置包都打包好了,放到了数据集frozen-packages-mmdetection中,详情页
import IPython.display as display
!pip install -U openmim

!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .

!pip install "mmdet>=3.0.0rc4"

!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl

!pip install wandb
display.clear_output()
  • 实测运行上述代码,在kaggle中可以达到运行项目需求,无报错(2023年7月13日)。
  • 导入常用基础包
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Config

import matplotlib.pyplot as plt
%matplotlib inline

from mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()

from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot

# 忽略警告
import warnings
warnings.filterwarnings('ignore')

display.clear_output()
  • 创建文件夹,用于放置数据集、模型预训练权重和模型推理输出
# 创建 checkpoint 文件夹,用于存放预训练模型权重文件
os.mkdir('checkpoint')

# 创建 outputs 文件夹,用于存放预测结果
os.mkdir('outputs')

# 创建 data 文件夹,用于存放图片和视频素材
os.mkdir('data')
  • 分别下载pspnet、segformer、mask2former在cityscapes上的预训练权重,并保存在checkpoint文件夹中
# 从Model Zoo预训练模型,下载并保存在 checkpoint 文件夹中
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
  • 下载一些测试模型用的图片以及视频,并存放到data文件夹中。
# 伦敦街景图片
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data

# 上海驾车街景视频,视频来源:https://www.youtube.com/watch?v=ll8TgCZ0plk
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data

# 街拍视频,2022年3月30日
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()

图片推理

命令行推理

  • 使用命令行对图片进行推理,并使用PIL对结果进行可视化
  • 分别使用了pspnet模型和segformer模型进行推理
# pspnet模型
!python demo/image_demo.py \
        data/street_uk.jpeg \
        configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
        checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
        --out-file outputs/B1_uk_pspnet.jpg \
        --device cuda:0 \
        --opacity 0.5

display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')

请添加图片描述

# segformer模型
!python demo/image_demo.py \
        data/street_uk.jpeg \
        configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
        checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
        --out-file outputs/B1_uk_segformer.jpg \
        --device cuda:0 \
        --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')

请添加图片描述

  • 可以看到其实segformer的效果比pspnet模型效果要好,基本上能将不同物体分割开。

API推理

  • 使用mmsegmentation的Python API进行图片推理
  • 使用mask2former模型推理,并利用matplotlib对结果进行可视化
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')

if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()

display.clear_output()
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:,:,::-1])
plt.imshow(pred_mask, alpha=0.55) # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()

请添加图片描述

  • mask2former作为sota模型,效果确实非常棒!

视频推理

命令行推理

  • 不推荐,速度很慢
!python demo/video_demo.py \
        data/street_20220330_174028.mp4 \
        configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
        checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
        --device cuda:0 \
        --output-file outputs/B3_video.mp4 \
        --opacity 0.5

API推理

  • mask2former模型使用API对视频进行推理
# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')

if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

display.clear_output()

input_video = 'data/street_20220330_174028.mp4'

temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))

# 获取 Cityscapes 街景数据集 类别名和调色板
classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']

def pridict_single_frame(img, opacity=0.2):

    result = inference_model(model, img)

    # 将分割图按调色板染色
    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))

    show_img = (np.array(seg_img.convert('RGB')))*(1-opacity) + img*opacity

    return show_img

# 读入待预测视频
imgs = mmcv.VideoReader(input_video)

prog_bar = mmengine.ProgressBar(len(imgs))

# 对视频逐帧处理
for frame_id, img in enumerate(imgs):

    ## 处理单帧画面
    show_img = pridict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg' # 保存语义分割预测结果图像至临时文件夹
    cv2.imwrite(temp_path, show_img)

    prog_bar.update() # 更新进度条

# 把每一帧串成视频文件
mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir) # 删除存放每帧画面的临时文件夹
print('删除临时文件夹', temp_out_dir)

小样本数据集微调mask2former

  • 在西瓜语义分隔数据集上对模型进行微调

下载数据集

!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip

!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null # 解压

!rm -rf Watermelon87_Semantic_Seg_Mask.zip # 删除压缩包

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data

# 删除系统自动生成的多余文件
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'

# 删除多余文件
!for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done

# 验证多余文件已删除
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'

display.clear_output()

可视化探索语义分割数据集

  • 可视化语义信息
# 指定单张图像路径
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'

img = cv2.imread(img_path)
mask = cv2.imread(mask_path)

# 可视化原图叠加
plt.figure(figsize=(8, 8))
plt.imshow(img[:,:,::-1])
plt.imshow(mask[:,:,0], alpha=0.6) # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.show()

请添加图片描述

定义Dataset和Pipeline

  • Dataset部分,可以设定数值对应的具体类别,以及不同类别的标注颜色。图像格式,是否忽略类别0
  • Pipeline部分,可以设定训练、验证的数据处理步骤。以及规定图像裁剪尺寸
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # 类别和对应的 RGB配色
    METAINFO = {
        'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
        'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
    }
    
    # 指定图像扩展名、标注扩展名
    def __init__(self,
                 seg_map_suffix='.png',   # 标注mask图像的格式
                 reduce_zero_label=False, # 类别ID为0的类别是否需要除去
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
  • custom_dataset加入__init__.py文件
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]

"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
  • 定义数据集预处理通道
custom_pipeline = """
# 数据集路径
dataset_type = 'MyCustomDataset' # 数据集类名
data_root = 'Watermelon87_Semantic_Seg_Mask/' # 数据集路径(相对于mmsegmentation主目录)

# 输入模型的图像裁剪尺寸,一般是 128 的倍数,越小显存开销越少
crop_size = (640, 640)

# 训练预处理
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# 测试预处理
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# 训练 Dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline))

# 验证 Dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))

# 测试 Dataloader
test_dataloader = val_dataloader

# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# 测试 Evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

修改配置文件

  • 主要修改类别个数、预训练权重路径、初始化图片尺寸(一般为128的整数倍)、batch_size、缩放学习率(修改的比例是 base_lr_default * (your_bs / default_bs))、更改学习率衰减策略
  • 关于学习率:主要修改optimizer中的lr,不用修改optim_wrapper
  • 冻结模型的骨干网络,对mask2former来说可以加快训练
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
# 类别个数
NUM_CLASS = 6
# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# 预训练模型权重
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4


# 训练 Batch Size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader


cfg.optimizer.lr = cfg.optimizer.lr / 8

# 结果保存目录
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 4000 # 训练迭代次数
cfg.train_cfg.val_interval = 50 # 评估模型间隔
cfg.default_hooks.logger.interval = 50 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 50 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重

cfg.param_scheduler[0].end = cfg.train_cfg.max_iters
# 随机数种子
cfg['randomness'] = dict(seed=0)

cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • 保存配置文件
cfg.dump('custom_mask2former.py')
  • 开始训练
!python tools/train.py custom_mask2former.py
  • 选取最优模型,测试模型精度
# 取最佳模型权重
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# 测试精度
!python tools/test.py custom_mask2former.py '{best_pth}'
  • 输出:
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 |  97.1 |  97.1  |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 |  90.1 |  90.1  |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+

可视化训练指标

在这里插入图片描述

肾小球数据集微调模型

  • 在单类别数据集(组织病理切片肾小球)上微调mask2former模型
  • 首先清空工作目录、data文件夹和outputs文件
# 清空工作目录
!rm -r work_dirs/*
# 清空data文件夹
!rm -r data/*
# 清空outputs文件夹
!rm -r outputs/*

可视化探索语义分割数据集

# 指定图像和标注路径
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

mask = cv2.imread('/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024/VUHSK_1762_29.png')
# 查看类别
np.unique(mask)
  • 输出
array([0, 1], dtype=uint8)
  • 可视化语义分割信息
# n行n列可视化
n = 5

# 标注区域透明度,透明度越小,越接近原图
opacity = 0.65

fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))

for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
    
    # 载入图像和标注
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    
    # 可视化
    axes[i//n, i%n].imshow(img[:,:,::-1])
    axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)
    axes[i//n, i%n].axis('off') # 关闭坐标轴显示
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()

请添加图片描述

分割训练集与测试集

  • 新建各类训练、验证文件夹
# 新建图片训练、验证文件夹
!mkdir -p data/images/train
!mkdir -p data/images/val

# 新建mask训练、验证文件夹
!mkdir -p data/masks/train
!mkdir -p data/masks/val
  • 随机打乱数据,并按照90%训练集、10%测试集分割
def copy_file(og_images, og_masks, tr_images, tr_masks, thor):
    # 获取源文件夹中的所有文件名
    file_names = os.listdir(og_images)
    
    # 随机打乱文件名列表
    random.shuffle(file_names)
    
    # 计算分割点
    split_index = int(thor * len(file_names))
    
    # 复制训练集文件
    for file_name in file_names[:split_index]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'train', file_name)
        tr_mask = os.path.join(tr_masks, 'train', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

    # 复制验证集文件
    for file_name in file_names[split_index:]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'val', file_name)
        tr_mask = os.path.join(tr_masks, 'val', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

tr_images = 'data/images'
tr_masks = 'data/masks'

copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)

重新定义Dataset和Pipeline

  • 主要是修改类别及对应RGB配色
  • 以及dataload的路径信息
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # 类别和对应的RGB配色
    METAINFO = {
        'classes':['normal','sclerotic'],
        'palette':[[127,127,127],[251,189,8]]
    }
    
    # 指定图像扩展名、标注扩展名
    def __init__(self,img_suffix='.png',
                 seg_map_suffix='.png',   # 标注mask图像的格式
                 reduce_zero_label=False, # 类别ID为0的类别是否需要除去
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]

"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
  • 定义数据预处理管道
custom_pipeline = """
# 数据集路径
dataset_type = 'MyCustomDataset' # 数据集类名
data_root = 'data/' # 数据集路径(相对于mmsegmentation主目录)

# 输入模型的图像裁剪尺寸,一般是 128 的倍数,越小显存开销越少
crop_size = (640, 640)

# 训练预处理
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# 测试预处理
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

# 训练 Dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/train', seg_map_path='masks/train'),
        pipeline=train_pipeline))

# 验证 Dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/val', seg_map_path='masks/val'),
        pipeline=test_pipeline))

# 测试 Dataloader
test_dataloader = val_dataloader

# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])

# 测试 Evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

修改配置文件

cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
  • 更改配置文件
# 类别个数
NUM_CLASS = 2
# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# 预训练模型权重
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4


# 训练 Batch Size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader


cfg.optimizer.lr = cfg.optimizer.lr / 8

# 结果保存目录
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 40000 # 训练迭代次数
cfg.train_cfg.val_interval = 500 # 评估模型间隔
cfg.default_hooks.logger.interval = 50 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 2500 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重

# 随机数种子
cfg['randomness'] = dict(seed=0)

cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • 保存配置文件,并开始训练
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py

可视化训练指标

在这里插入图片描述

评估模型以及测试推理速度

  • 评估模型精度
# 取最佳模型权重
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# 测试精度
!python tools/test.py custom_mask2former.py '{best_pth}'
  • 输出:
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
  • 测试模型推理速度
# 测试FPS
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
  • 输出:
Done image [50 / 200], fps: 2.24 img / s
Done image [100/ 200], fps: 2.24 img / s
Done image [150/ 200], fps: 2.24 img / s
Done image [200/ 200], fps: 2.24 img / s
Overall fps: 2.24 img / s

Average fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/751318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【从零开始学习JAVA | 第二十八篇】不可变集合

目录 目录 前言: 不可变集合: 常见的不可变集合: 1.创建list的不可变集合: 2.创建map的不可变集合: 应用场景: 总结: 前言: 本文我们将为大家介绍JAVA中的不可变集合&#x…

第四十七章Java枚举(enum)详解:Java声明枚举类型、枚举(enum)类、EnumMap 与 EnumSet

枚举是一个被命名的整型常数的集合,用于声明一组带标识符的常数。枚举在曰常生活中很常见,例如一个人的性别只能是“男”或者“女”,一周的星期只能是 7 天中的一个等。类似这种当一个变量有几种固定可能的取值时,就可以将它定义为…

今天实习第二天 ,vue

由于这是一次新的项目,有几个技术,docker,vue 老板要我学习vue。 这里我安装的是node.js,但是操作的时候用的是idea,因为vue插件集成在idea中。 01.vue的技术 vue是MVVM的履行者。 02.第一个vue程序 第一步&#xff1…

C基础day9(2023.7.11)

一、Xmind整理&#xff1a; 二、课上练习&#xff1a; 练习1&#xff1a;实现字符串逆置 #include <stdio.h> #include <string.h> #include <stdlib.h> int main(int argc, const char *argv[]) {char str[]"hello";char *pstr;char *qstrstrlen…

UE编辑器灯光颜色,能量传入Shader流程

编辑器界面&#xff1a; 代码流程&#xff1a; FLinearColor ULightComponent::GetColoredLightBrightness() const {// Brightness in Lumensfloat LightBrightness ComputeLightBrightness();FLinearColor Energy FLinearColor(LightColor) * LightBrightness;if (bUseTem…

数学建模-拟合算法

这里的线性函数指的是参数为线性&#xff0c;而不是变量为线性。 yabx^2是线性的 用的比较多的是多项式拟合和自己定义的 拓展资料&#xff1a;工具箱曲线拟合类型评价解释 文件-导出代码 自动生成的代码修改图名和标签 如果不收敛&#xff0c;自己要修改初始值&#xf…

ES 跨集群搜索 Cross-cluster search (CCS)

跨集群查询 跨集群搜索(cross-cluster search)使你可以针对一个或多个远程集群运行单个搜索请求。 例如&#xff0c;你可以使用跨集群搜索来筛选和分析存储在不同数据中心的集群中的日志数据。 环境准备 角色IP系统dev172.16.122.244CentOS 7.9prod172.16.122.245CentOS 7.9 ES…

记忆——记忆宫殿——地点桩

地点桩图片 室内物品放置方法——时钟放置法 https://www.zhihu.com/question/34549534 地点桩的扩展和记忆 我告诉你一个让记忆宫殿数量翻125倍的方法&#xff0c;以后用一个地点桩就扔一个。 这方法是我在背了几本书后才在偶然中发现的&#xff0c;我叫他“五行推演法”&a…

ES(1)简介和安装

文章目录 简介倒排索引 安装 简介 ES是面向文档型数据库&#xff0c;一条数据在这里就是一个文档。 和关系型数据库大致关系如下: ES7.x中废除掉Type&#xff08;表&#xff09;的概念 倒排索引 要知道什么是倒排索引&#xff0c;就要先知道什么是正排索引 idcontent100…

判断数组中所有元素是否均为True numpy.alltrue()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 判断数组中所有元素是否均为True numpy.alltrue() [太阳]选择题 请问关于以下代码的说法错误的是&#xff1f; import numpy as np print(【执行】np.alltrue([True, True, True])) print(n…

C/C++图形库EasyX保姆级使用教程(四) 图片的展示与缩放

C/C图形库EasyX保姆级使用教程 第一章 Microsoft Visual Studio 2022和EasyX的下载及安装使用 第二章 图形化窗口设置以及简单图形的绘制 第三章 图形颜色的填充及相关应用 第四章 图片的展示与缩放 文章目录 C/C图形库EasyX保姆级使用教程前言一、图片的展示1.变量存储图片2.…

安全测试方法介绍(下)渗透测试

安全主要测试方法主要有&#xff1a;静态源代码审查&#xff0c;这个在编码阶段就可以进行&#xff0c;这个阶段如果出现问题&#xff0c;修复起来成本也比较低。程序发布之后可以进行渗透测试。前面的文章中我们为大家介绍了静态源代码审查的方法和策略&#xff0c;接下来本文…

【milvus】向量数据库,用来做以图搜图+人脸识别的特征向量

1. 安装milvus ref:https://milvus.io/docs 第一次装东西&#xff0c;要把遇到的问题和成功经验都记录下来。 1.Download the YAML file wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-standalone-docker-compose.yml -O docker-compose.yml看…

微信小程序中常见组件的使用

文章目录 微信小程序中常见组件的使用视图组件viewscroll-viewswipermovable-area 基础组件icontextrich-textprogress 表单组件buttoncheckbox、checkbox-grouplabelforminputpicker单列选择器多列选择器时间选择器&日期选择器&地区选择器 picker-viewradiosliderswit…

人工智能-神经网络

目录 1 神经元 2 MP模型 3 激活函数 3.1 激活函数 3.2 激活函数作用 3.3 激活函数有多种 4、神经网络模型 1 神经元 神经元是主要由树突、轴突、突出组成&#xff0c;树突是从上面接收很多信号&#xff0c;经过轴突处理后传递给突触&#xff0c;突触会进行选择性向下一级的…

[项目实战] 使用Idea构建单页面Vue3项目(不使用node、npm)

前言 某天张三对小花说&#xff0c;我需要在一台新电脑上实现一个前端的漂亮页面&#xff1a;比如京东手机首页(m.jd.com)。 小花这时吭哧吭哧的去新电脑上安装nodejs、npm&#xff0c;然后在本地使用npm构建vue3项目&#xff0c;在项目里下载安装element-plus、axios。下一步…

常用异常检测算法总结与代码实现[统计学方法/K近邻/孤立森林/DBSCAN/LOF/混合高斯GMM/自编码器AutoEncoder等]

这篇博文主要是延续前文系列的总结记录&#xff0c;这里主要是总结汇总日常主流的异常检测算法相关知识内容。 &#xff08;1&#xff09;基于统计方法的异常值检测 基于统计方法的异常值检测是一种常用的异常检测算法&#xff0c;它基于样本数据的统计特性来识别与其他样本显…

【RS】ENVI5.6 栅格数据坐标转换

ENVI是一个完整的遥感图像处理平台&#xff0c;广泛应用于科研、环境保护、气象、农业、林业、地球科学、遥感工程、水利、海洋等领域。目前ENVI已成为遥感影像处理的必备软件&#xff0c;包含辐射定标、大气校正、镶嵌裁剪、分类识别、阈值分割等多种功能。ENVI针对绝大部分的…

【三相STATCOM】使用D-Q控制的三相STATCOM技术【三相VSI STATCOM为R-L负载提供无功功率】(Simulink实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

COMEN 科曼C60、C80心电监护仪协议对接

通过网口&#xff0c;成功对接到参数&#xff1a; HR、NibpDia、NibpMean、NibpSys、Spo2、Resp、Sys、Dia、Mean、Temp、PR等数值