Preface
This article introduces mmsegmentation, a Python library dedicated to semantic segmentation models (GitHub project link); the runtime environment is a Kaggle notebook with a P100 GPU. It covers environment setup, inference with pretrained models, and fine-tuning the recent SOTA model mask2former on a watermelon dataset (see the data description). Because the watermelon dataset is small, we finally fine-tune mask2former on a histopathology glomerulus dataset (see the data description). This tutorial draws in part on the GitHub project MMSegmentation_Tutorials (project link).
Environment Setup
Running the code end to end requires openmim, mmsegmentation, mmengine, mmdetection, and mmcv. Setting up mmcv on Kaggle is fairly troublesome and needs prebuilt packages, so I have bundled all of them into the dataset frozen-packages-mmdetection (details page).
import IPython.display as display

!pip install -U openmim
!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .
!pip install "mmdet>=3.0.0rc4"
!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl
!pip install wandb
display.clear_output()
Tested in practice: running the code above in Kaggle meets the project's requirements without errors (as of July 13, 2023). Next, import the common base packages:
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Config
import matplotlib.pyplot as plt
%matplotlib inline
from mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()
from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot
import warnings
warnings.filterwarnings('ignore')
display.clear_output()
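As a quick sanity check (a minimal sketch, not part of the original notebook), print the versions that were actually installed:
# Sketch: confirm the key packages match the expected versions
print('mmcv:', mmcv.__version__)          # expect 2.0.1 (the frozen wheel)
print('mmengine:', mmengine.__version__)
print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())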
Create folders for the datasets, the pretrained model weights, and the model inference outputs:
os.mkdir('checkpoint')
os.mkdir('outputs')
os.mkdir('data')
Download the Cityscapes pretrained weights for pspnet, segformer, and mask2former, and save them in the checkpoint folder:
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
Download some images and videos for testing the models and store them in the data folder.
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()
Image Inference
Command-Line Inference
Run image inference from the command line and visualize the results with PIL. Inference is performed with both the pspnet model and the segformer model.
!python demo/image_demo.py \
    data/street_uk.jpeg \
    configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
    checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
    --out-file outputs/B1_uk_pspnet.jpg \
    --device cuda:0 \
    --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')
!python demo/image_demo.py \
    data/street_uk.jpeg \
    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
    checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
    --out-file outputs/B1_uk_segformer.jpg \
    --device cuda:0 \
    --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')
You can see that segformer actually performs better than pspnet; it can basically separate the different objects.
API Inference
Run image inference through mmsegmentation's Python API: use the mask2former model for inference and visualize the result with matplotlib.
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
model = init_model(config_file, checkpoint_file, device='cuda:0')
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)
result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()
display.clear_output()
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.imshow(pred_mask, alpha=0.55)
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()
As a SOTA model, mask2former really does deliver excellent results!
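To see exactly which Cityscapes classes appear in the prediction, a small sketch reusing pred_mask from above:
# Sketch: list the Cityscapes class names present in the predicted mask
classes = CityscapesDataset.METAINFO['classes']
for class_id in np.unique(pred_mask):
    print(int(class_id), classes[int(class_id)])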
Video Inference
Command-Line Inference
!python demo/video_demo.py \
    data/street_20220330_174028.mp4 \
    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
    checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
    --device cuda:0 \
    --output-file outputs/B3_video.mp4 \
    --opacity 0.5
API Inference
Use the mask2former model to run video inference through the API.
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
model = init_model(config_file, checkpoint_file, device='cuda:0')
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)
display.clear_output()

input_video = 'data/street_20220330_174028.mp4'
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('Created temporary folder {} for per-frame predictions'.format(temp_out_dir))

classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']

def predict_single_frame(img, opacity=0.2):
    result = inference_model(model, img)
    # Convert the predicted label map to a paletted image, then blend with the frame
    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))
    show_img = (np.array(seg_img.convert('RGB'))) * (1 - opacity) + img * opacity
    return show_img

imgs = mmcv.VideoReader(input_video)
prog_bar = mmengine.ProgressBar(len(imgs))
for frame_id, img in enumerate(imgs):
    show_img = predict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg'
    cv2.imwrite(temp_path, show_img)
    prog_bar.update()
mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')
shutil.rmtree(temp_out_dir)
print('Deleted temporary folder', temp_out_dir)
Fine-Tuning mask2former on a Small Dataset
Download the Dataset
!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip
!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null
!rm -rf Watermelon87_Semantic_Seg_Mask.zip
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'
!for i in `find . -iname '__MACOSX'`; do rm -rf $i; done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i; done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i; done
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'
display.clear_output()
Visually Explore the Semantic Segmentation Dataset
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'
img = cv2.imread(img_path)
mask = cv2.imread(mask_path)
plt.figure(figsize=(8, 8))
plt.imshow(img[:, :, ::-1])
plt.imshow(mask[:, :, 0], alpha=0.6)
plt.axis('off')
plt.show()
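Before defining the dataset class, it helps to confirm which label IDs the mask actually contains (a small sketch; the IDs should match the six classes defined next):
# Sketch: the watermelon masks should contain label IDs 0-5
print(np.unique(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)))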
Define the Dataset and Pipeline
In the Dataset part you can specify which class each label value corresponds to, the annotation color of each class, the image format, and whether class 0 should be ignored. In the Pipeline part you can define the data-processing steps for training and validation, as well as the image crop size.
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
# Classes and their corresponding RGB palette
METAINFO = {
'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
}
# Specify the image suffix and the annotation suffix
def __init__(self,
seg_map_suffix='.png', # file format of the annotation masks
reduce_zero_label=False, # whether to drop the class with ID 0
**kwargs) -> None:
super().__init__(
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
"""
with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
Add the custom dataset to the __init__.py file:
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate,
RandomRotFlip, Rerange, ResizeShortestEdge,
ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'MyCustomDataset'
]
"""
with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
custom_pipeline = """
# Dataset paths
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'Watermelon87_Semantic_Seg_Mask/' # dataset path (relative to the mmsegmentation root)
# Input crop size, usually a multiple of 128; smaller values use less GPU memory
crop_size = (640, 640)
# Training preprocessing
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
# Test preprocessing
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# Test-time augmentation (TTA)
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
# Training dataloader
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='img_dir/train', seg_map_path='ann_dir/train'),
pipeline=train_pipeline))
# Validation dataloader
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='img_dir/val', seg_map_path='ann_dir/val'),
pipeline=test_pipeline))
# Test dataloader
test_dataloader = val_dataloader
# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
# Test evaluator
test_evaluator = val_evaluator
"""
with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)
Modify the Config File
The main modifications are the number of classes, the pretrained-weight path, the initial image size (generally a multiple of 128), the batch_size, scaling the learning rate (by the ratio base_lr_default * (your_bs / default_bs), as the sketch below shows), and the learning-rate decay schedule. Regarding the learning rate: modify lr in optimizer; optim_wrapper does not need to change. Freezing the model's backbone speeds up training for mask2former.
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
NUM_CLASS = 6
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader
cfg.optimizer.lr = cfg.optimizer.lr / 8
cfg.work_dir = './work_dirs'
cfg.train_cfg.max_iters = 4000
cfg.train_cfg.val_interval = 50
cfg.default_hooks.logger.interval = 50
cfg.default_hooks.checkpoint.interval = 50
cfg.default_hooks.checkpoint.max_keep_ckpts = 2
cfg.default_hooks.checkpoint.save_best = 'mIoU'
cfg.param_scheduler[0].end = cfg.train_cfg.max_iters
cfg['randomness'] = dict(seed=0)
cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
cfg.dump('custom_mask2former.py')
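One practical note: the WandbVisBackend requires an authenticated wandb session. A minimal sketch for Kaggle, assuming the API key is stored in a Kaggle secret named 'wandb_api' (the secret name is an assumption for illustration):
# Sketch: authenticate wandb from a Kaggle secret ('wandb_api' is a hypothetical name)
from kaggle_secrets import UserSecretsClient
wandb.login(key=UserSecretsClient().get_secret('wandb_api'))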
!python tools/train.py custom_mask2former.py
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
!python tools/test.py custom_mask2former.py '{best_pth}'
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  | Dice  | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 | 97.1  | 97.1   |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 | 90.1  | 90.1   |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+
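With the fine-tuned weights you can also run a quick qualitative check on the watermelon test image downloaded earlier (a minimal sketch; the output filename is just an example):
# Sketch: run the fine-tuned mask2former on the watermelon test image
model = init_model('custom_mask2former.py', best_pth, device='cuda:0')
result = inference_model(model, 'data/watermelon_test1.jpg')
show_result_pyplot(model, 'data/watermelon_test1.jpg', result,
                   opacity=0.6, out_file='outputs/B4_watermelon.jpg')  # example filename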
Visualize Training Metrics
Fine-Tuning on the Glomerulus Dataset
Fine-tune the mask2former model on a single-class dataset (histopathology glomerulus sections). First, empty the working directory, the data folder, and the outputs folder:
!rm -r work_dirs/*
!rm -r data/*
!rm -r outputs/*
Visually Explore the Semantic Segmentation Dataset
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'
mask = cv2.imread('/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024/VUHSK_1762_29.png')
np.unique(mask)
array([0, 1], dtype=uint8)
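Only the label values 0 and 1 appear, so it is worth estimating how rare the positive class is before training (a sketch over a sample of masks):
# Sketch: rough foreground (label ID 1) pixel ratio over a sample of masks
sample = os.listdir(PATH_MASKS)[:100]
ratios = [(cv2.imread(os.path.join(PATH_MASKS, f), cv2.IMREAD_GRAYSCALE) == 1).mean()
          for f in sample]
print(f'mean foreground ratio: {np.mean(ratios):.4f}')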
n = 5
opacity = 0.65
fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12, 12))
for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0] + '.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    axes[i//n, i%n].imshow(img[:, :, ::-1])
    axes[i//n, i%n].imshow(mask[:, :, 0], alpha=opacity)
    axes[i//n, i%n].axis('off')
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()
Split the Training and Test Sets
!mkdir -p data/images/train
!mkdir -p data/images/val
!mkdir -p data/masks/train
!mkdir -p data/masks/val
Randomly shuffle the data and split it 90% / 10%; the held-out 10% goes into the val folders and serves as the validation/test set:
def copy_file(og_images, og_masks, tr_images, tr_masks, thor):
    # thor: fraction of files that goes into the training split
    file_names = os.listdir(og_images)
    random.shuffle(file_names)
    split_index = int(thor * len(file_names))
    for file_name in file_names[:split_index]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'train', file_name)
        tr_mask = os.path.join(tr_masks, 'train', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)
    for file_name in file_names[split_index:]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'val', file_name)
        tr_mask = os.path.join(tr_masks, 'val', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'
tr_images = 'data/images'
tr_masks = 'data/masks'
copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)
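A quick sanity check on the split (sketch): the image and mask counts in each folder should match.
# Sketch: verify that image and mask counts match in each split
for split in ['train', 'val']:
    print(split,
          len(os.listdir(f'data/images/{split}')),
          len(os.listdir(f'data/masks/{split}')))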
Redefine the Dataset and Pipeline
The main changes are the classes and their corresponding RGB palette, plus the path information for the dataloaders.
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
# Classes and their corresponding RGB palette
METAINFO = {
'classes':['normal','sclerotic'],
'palette':[[127,127,127],[251,189,8]]
}
# Specify the image suffix and the annotation suffix
def __init__(self,img_suffix='.png',
seg_map_suffix='.png', # file format of the annotation masks
reduce_zero_label=False, # whether to drop the class with ID 0
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
"""
with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate,
RandomRotFlip, Rerange, ResizeShortestEdge,
ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'MyCustomDataset'
]
"""
with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
custom_pipeline = """
# Dataset paths
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'data/' # dataset path (relative to the mmsegmentation root)
# Input crop size, usually a multiple of 128; smaller values use less GPU memory
crop_size = (640, 640)
# Training preprocessing
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
# Test preprocessing
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# Test-time augmentation (TTA)
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
# Training dataloader
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/train', seg_map_path='masks/train'),
pipeline=train_pipeline))
# Validation dataloader
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/val', seg_map_path='masks/val'),
pipeline=test_pipeline))
# Test dataloader
test_dataloader = val_dataloader
# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
# Test evaluator
test_evaluator = val_evaluator
"""
with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)
Modify the Config File
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
NUM_CLASS = 2
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader
cfg.optimizer.lr = cfg.optimizer.lr / 8
cfg.work_dir = './work_dirs'
cfg.train_cfg.max_iters = 40000
cfg.train_cfg.val_interval = 500
cfg.default_hooks.logger.interval = 50
cfg.default_hooks.checkpoint.interval = 2500
cfg.default_hooks.checkpoint.max_keep_ckpts = 2
cfg.default_hooks.checkpoint.save_best = 'mIoU'
cfg['randomness'] = dict(seed=0)
cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py
Visualize Training Metrics
Evaluate the Model and Measure Inference Speed
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
!python tools/test.py custom_mask2former.py '{best_pth}'
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  | Dice  | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
Done image [ 50/200], fps: 2.24 img/s
Done image [100/200], fps: 2.24 img/s
Done image [150/200], fps: 2.24 img/s
Done image [200/200], fps: 2.24 img/s
Overall fps: 2.24 img/s
Average fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0
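For reference, 2.24 img/s works out to roughly 0.45 s per 1024×1024 tile on the P100:
# 2.24 img/s -> about 0.45 s per image
print(f'{1 / 2.24:.2f} s per image')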