要复现Unbiased Mean Teacher for Cross-domain Object Detection(UMT),首先要正确运行CycleGAN。
1. CycleGAN
CycleGAN的github链接:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
1.1 CycleGAN环境配置
git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
cd pytorch-CycleGAN-and-pix2pix
conda env create -f environment.yml
1.2 CycleGAN数据集准备
1.2.1 如果使用官方数据集
不要运行1.2.2的内容。
bash ./datasets/download_cyclegan_dataset.sh maps
1.2.2 如果使用自己的数据集
不要运行1.2.1的内容。
以下将自己准备的两个图片数据集分别称为A和B。自己划分好训练集、测试集和验证集。
将自己的数据集放到./datasets/maps下,目录结构如下图所示。
maps
├── trainA
├── trainB
├── testA
├── testB
├── valA
└── valB
1.3 CycleGAN训练
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
训练好的模型保存在./checkpoints/maps_cyclegan/文件夹内的latest_net_G_A.pth和latest_net_G_B.pth文件内。
1.4 CycleGAN测试
将latest_net_G_A.pth存放至./checkpoints/generateA/路径下,并改名为latest_net_G.pth。
generateA也可以随便改成其他名字,只要和下面的–name一致即可。
运行下面指令,即可得到B风格的A数据集,结果保存在./results/generateA/test_latest/images/内。
python test.py --dataroot [自己的数据集路径] --name generateA--model test --num_test [自己数据集的图片总数] --no_dropout
同理,可得到A风格的B数据集。
2. UMT
2.1 环境配置
UMT环境继承了SW_DA模型,使用的包版本极低,CUDA版本使用9.0或9.1可避免CUDA报错。
torch=0.4.0
torchvision=0.2.0
其他包遵循SW_DA模型。链接:https://github.com/VisionLearningGroup/DA_Detection
2.2 数据集配置
UMT需要使用4个数据集,A,B,B风格的A数据集,A风格的B数据集。
2.2.1 数据集内部结构
SCUT
├── SCUT_A
│ └── VOC2007
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages # A数据集图片
├── AlikeB
│ └── VOC2007
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages # B风格的A数据集图片
├── SCUT_B
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages # B数据集图片
└── BlikeA
├── Annotations
├── ImageSets
└── JPEGImages # A风格的B数据集图片
2.2.2 数据集路径配置
本文名义上使用的是VOC2007和clipart,实际上将VOC2007和clipart内部包含的图片内容换成了自己的A和B,这样就可以跑自己的数据集了。
需要修改lib/datasets/config_dataset.py的内容:
__D.PASCAL改为自己的A数据集的路径
__D.PASCAL_CYCLECLIPART改为自己的B风格的A数据集的路径
__D.CLIPART改为自己的B数据集的路径
__D.CLIPART_CYCLEVOC改为自己的A风格的B数据集的路径
# with regard to pascal, the directories under the path will be ./VOC2007, ./VOC2012"
__D.PASCAL = "/home/lch1999/SCUT/SCUT_A/"
__D.PASCAL_CYCLECLIPART = (
"/home/lch1999/SCUT/AlikeB/"
)
__D.CLIPART = "/home/lch1999/SCUT/SCUT_B/"
__D.CLIPART_CYCLEVOC = (
"/home/lch1999/SCUT/BlikeA/"
)
2.3 训练
python umt_train.py --dataset pascal_voc --net vgg16
2.4 测试
./test.sh 0 models/vgg16/pascal_voc/conf_True_conf_gamma_0.1_source_like_True_aug_True_target_like_True_pe_0_pl_True_thresh_0.8_lambda_0.01_lam2_0.1_student_target_clipart_session_1_epoch_8_step_10000.pth
./test.sh 0 models/vgg16/pascal_voc/conf_True_conf_gamma_0.1_source_like_True_aug_True_target_like_True_pe_0_pl_True_thresh_0.8_lambda_0.01_lam2_0.1_teacher_target_clipart_session_1_epoch_8_step_10000.pth
2.5 测试结果
自己的数据集只标注了person,所以UMT也只检测了person,所以其他物体的AP=0是正常的。