深度学习目标检测项目实战(五)—基于mobilenetv2和resnet的图像背景抠图及其界面封装
该项目很有意思,也是比较前沿,项目主要参考了开源代码:
https://github.com/PeterL1n/BackgroundMattingV2
环境搭建
kornia==0.4.1
tensorboard==2.3.0
torch==1.7.0
torchvision==0.8.1
tqdm==4.51.0
opencv-python==4.4.0.44
onnxruntime==1.6.0
数据集
https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets
下载比较小的数据集
训练
主要看readme的介绍
有两个训练的代码:
train_base.py
train_refine.py
要用gpu训练,不然顶不住,比如
CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
--dataset-name videomatte240k \
--model-backbone resnet50 \
--model-name mattingrefine-resnet50-videomatte240k \
--model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
--epoch-end 1
配置data_path.pth以指向您的数据集。原始论文使用train_base.pth只训练基本模型直到收敛,然后使用train_refine.pth端到端训练整个网络。更多细节将在论文中详细说明:
https://arxiv.org/abs/2012.07810
使用
inference_images.py:在图像目录上执行抠图。
inference_video.py:对视频进行抠图处理。
inference_webcam.py:一个使用网络摄像头的交互式抠图演示。
将结果保存为pth文件:
pytorch_mobilenetv2.pth
界面效果
封装成web界面,有模有样。
需要界面代码可以私聊。
这个开源代码的写法很值得学习。