deeplab v3+源码 慢慢解析系列
- 本带着一些孩子们做,但本硕能独立看下来的学生不多。
- 和孩子们一起再学一遍吧。
- 希望孩子们和我自己都能坚持写下去吧。
- 网上资料太多了,但不够慢,都是速成,没有足够的解释和补充,希望这次够慢,够清楚吧。
前期准备和说明
提示:源码众多,此次选这个版本pytorch版
- 已经会python了(有编写代码的基础)。
- 了解深度学习和语义分割都是什么,特别是卷积核、填充等基本概念都已明确。
- 本次尽量不讲原理,尽量只说代码。
- 每次只说一个函数,进度足够慢,尽量足够简单。
总体目录
提示:下载解压后,总体结构如下,计划是一次说一个代码的一个函数。
- 总体上readme.md说的挺详细的,没必要就翻译一事浪费言语,请自行解决。
- requirements.txt的环境,也可自行选择版本吧,具备基本工程经验可以调整即可。
- 第一篇从main.py开始。拿到代码,想运行,下载完数据,配置好datasets文件夹,按readme.md操作即可(网络或者b站视频众多,不重复进行这些了)。直接进入代码。
main.py导入
提示:你过去写得最好的一段代码是什么? 请用代码块贴出来
例如:
#以下是基本操作
from tqdm import tqdm
import network
import utils
import os
import random
import argparse #这个是本篇所讲重点
import numpy as np
#以下是数据部分所需
from torch.utils import data
from datasets import VOCSegmentation, Cityscapes
from utils import ext_transforms as et
from metrics import StreamSegMetrics
#以下是神经网络所需
import torch
import torch.nn as nn
#以下是可视化和图片操作所需
from utils.visualizer import Visualizer
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
- 导入部分,大致分块注释,按此理解即可。
- 关于argparse这个包,这篇说的很清楚,建议先看,不再赘述。
- 之后按源码顺序每个函数分开说。
参数解析,get_argparser函数
提示:看完上个部分所说的argparse包的相关介绍,再看下面的内容。
def get_argparser():
parser = argparse.ArgumentParser()
# Datset Options
#指定数据集位置,'./datasets/data',后期可以根据自己数据集位置修改。
parser.add_argument("--data_root", type=str, default='./datasets/data',
help="path to Dataset")
#指定使用的数据集数据集,此处是['voc', 'cityscapes'],默认'voc',可以改为自己的数据集
parser.add_argument("--dataset", type=str, default='voc',
choices=['voc', 'cityscapes'], help='Name of dataset')
#数据类别,如VOC默认21类
parser.add_argument("--num_classes", type=int, default=None,
help="num classes (default: None)")
# Deeplab Options
#在network.modeling中提供可选择的模型,如本套代码提供deeplab V3和V3+每种6个具体模型,都是基于4类骨干网的选择(hrnetv2的2个,resnet的2个,mobilenet的1个,xception的1个)。
available_models = sorted(name for name in network.modeling.__dict__ if name.islower() and \
not (name.startswith("__") or name.startswith('_')) and callable(
network.modeling.__dict__[name])
)
# 具体特征提取的(骨干网)模型,此处可选'model' (choose from 'deeplabv3_hrnetv2_32', 'deeplabv3_hrnetv2_48', 'deeplabv3_mobilenet', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'deeplabv3_xception',
#'deeplabv3plus_hrnetv2_32', 'deeplabv3plus_hrnetv2_48', 'deeplabv3plus_mobilenet', 'deeplabv3plus_resnet101', 'deeplabv3plus_resnet50', 'deeplabv3plus_xception')
parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
choices=available_models, help='model name')
#部署剪枝会用到,一开始不用看。ASPP层卷积用
parser.add_argument("--separable_conv", action='store_true', default=False,
help="apply separable conv to decoder and aspp")
parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])
# Train Options
#要使用训练,不选
parser.add_argument("--test_only", action='store_true', default=False)
#保存则开启
parser.add_argument("--save_val_results", action='store_true', default=False,
help="save segmentation results to \"./results\"")
#迭代次数
parser.add_argument("--total_itrs", type=int, default=30e3,
help="epoch number (default: 30k)")
#学习率
parser.add_argument("--lr", type=float, default=0.01,
help="learning rate (default: 0.01)")
parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
help="learning rate scheduler policy")
parser.add_argument("--step_size", type=int, default=10000)
#裁剪验证集
parser.add_argument("--crop_val", action='store_true', default=False,
help='crop validation (default: False)')
parser.add_argument("--batch_size", type=int, default=16,
help='batch size (default: 16)')
parser.add_argument("--val_batch_size", type=int, default=4,
help='batch size for validation (default: 4)')
#限制图像输入的大小,此处裁剪为513.
parser.add_argument("--crop_size", type=int, default=513)
parser.add_argument("--ckpt", default=None, type=str,
help="restore from checkpoint")
parser.add_argument("--continue_training", action='store_true', default=False)
#focal_loss是一个动态缩放的交叉熵损失,通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将重心快速聚焦在那些难区分的样本
parser.add_argument("--loss_type", type=str, default='cross_entropy',
choices=['cross_entropy', 'focal_loss'], help="loss type (default: False)")
parser.add_argument("--gpu_id", type=str, default='0',
help="GPU ID")
#权重衰减
parser.add_argument("--weight_decay", type=float, default=1e-4,
help='weight decay (default: 1e-4)')
parser.add_argument("--random_seed", type=int, default=1,
help="random seed (default: 1)")
parser.add_argument("--print_interval", type=int, default=10,
help="print interval of loss (default: 10)")
parser.add_argument("--val_interval", type=int, default=100,
help="epoch interval for eval (default: 100)")
parser.add_argument("--download", action='store_true', default=False,
help="download datasets")
# PASCAL VOC Options
parser.add_argument("--year", type=str, default='2012',
choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'], help='year of VOC')
# Visdom options #可视化选项,默认不用,但挺好用的。
parser.add_argument("--enable_vis", action='store_true', default=False,
help="use visdom for visualization")
parser.add_argument("--vis_port", type=str, default='13570',
help='port for visdom')
parser.add_argument("--vis_env", type=str, default='main',
help='env for visdom')
parser.add_argument("--vis_num_samples", type=int, default=8,
help='number of samples for visualization (default: 8)')
return parser
Tips
- 解析参数函数很有用,尤其是初期只运行代码时,让你能快速理解readme中第3部分的那些命令。逐一对照参数名字理解即可,如下:
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16
- 函数中很多参数,在验证性尝试中可能用不上,但没关系,入门学习,要有耐心,第一个函数也就结束了。
- 下一个函数是数据集函数get_dataset。