1. Problem Description
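Today we need to train an AI detection model that finds people in images or videos. Collecting training data ourselves is slow and laborious, so instead we extract the data we need from the public COCO dataset.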
2. Data Preparation
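2.1 Download the COCO2017 dataset
train: http://images.cocodataset.org/zips/train2017.zip
val: http://images.cocodataset.org/zips/val2017.zip
annotations: http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Extract the two image archives into <coco>/images/, which gives images/train2017 and images/val2017. Extract the annotation archive into <coco>/, which gives annotations/instances_train2017.json and annotations/instances_val2017.json. These are the paths the extraction script in section 3 reads.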
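If you prefer to script the downloads, here is a minimal sketch that uses only the standard library. The helper name `download_and_extract` is hypothetical, and so is the `keep_zip` parameter. The target folder in the usage note assumes the `<coco>/images/<folder>/` layout that the extraction script reads.

```python
import os
import urllib.request
import zipfile


def download_and_extract(url, extract_to, keep_zip=False):
    """Download a zip archive from `url` and extract it into `extract_to` (hypothetical helper)."""
    os.makedirs(extract_to, exist_ok=True)
    zip_path = os.path.join(extract_to, os.path.basename(url))
    if not os.path.exists(zip_path):
        # The COCO image archives are many GB, so this call can take a long time
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(extract_to)
    if not keep_zip:
        os.remove(zip_path)  # delete the archive once it has been extracted
```

For example, `download_and_extract("http://images.cocodataset.org/zips/val2017.zip", "../datasets/coco/images")` should leave the images in `../datasets/coco/images/val2017/`, assuming that is where your COCO root lives.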
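2.2 Download the matching YOLO-format label files
Because we are training a YOLO-series model, we also need the labels in YOLO format:
https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels.zip
The script expects these label files under <coco>/labels/train2017 and <coco>/labels/val2017.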
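For reference, each line of a YOLO label file has the form `class x_center y_center width height`. The class is a 0-based index, and the coordinates are normalized by the image width and height. COCO's JSON instead stores each box as `[x_min, y_min, width, height]` in pixels. The labels in coco2017labels.zip are already converted, so you do not need to run the sketch below. It only illustrates the relationship between the two formats, and both function names are hypothetical.

```python
def coco_bbox_to_yolo_line(class_index, bbox, img_w, img_h):
    """Convert a COCO [x_min, y_min, w, h] pixel box into one YOLO label line.

    class_index is the 0-based YOLO index, not the COCO category id.
    """
    x_min, y_min, w, h = bbox
    x_c = (x_min + w / 2) / img_w  # box center, normalized by image width
    y_c = (y_min + h / 2) / img_h  # box center, normalized by image height
    return "{} {:.6f} {:.6f} {:.6f} {:.6f}".format(class_index, x_c, y_c, w / img_w, h / img_h)


def parse_yolo_line(line):
    """Split a YOLO label line into (class_index, [x_c, y_c, w, h])."""
    parts = line.split()
    return int(parts[0]), [float(v) for v in parts[1:5]]
```

For example, a 200x100 px person box at (100, 50) in a 640x480 image becomes `0 0.312500 0.208333 0.312500 0.208333`.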
3. Python Code
With the data in place, the next step is the extraction itself. First, create a yaml file that lists the classes to extract.
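```yaml
# Classes to extract (classes.yaml)
# Keys under names are YOLO (COCO 80-class) indices; the script sets path in its output copy
path: /home/dataset/coco
train: train2017.txt
val: val2017.txt
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
```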
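The keys under `names` must be the YOLO class indices, because the script compares them against the first column of the label files. For that reason it can be easier to generate the file from class names than to look the indices up by hand. Below is a minimal sketch. The helper name `make_classes_yaml` is hypothetical. The name list is truncated to the first 12 classes of YOLOv5's 80-class COCO order, so extend it from your copy of YOLOv5's coco.yaml if you need other classes.

```python
# First 12 of the 80 COCO class names, in YOLO index order (truncated for brevity)
COCO80_NAMES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus",
                "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign"]


def make_classes_yaml(wanted, path="/home/dataset/coco"):
    """Return the text of a classes.yaml that keeps only the `wanted` class names."""
    lines = ["path: " + path, "train: train2017.txt", "val: val2017.txt", "names:"]
    # Look up each name's YOLO index (ValueError for unknown names), sorted by index
    indices = sorted(COCO80_NAMES.index(name) for name in wanted)
    for index in indices:
        lines.append("  {}: {}".format(index, COCO80_NAMES[index]))
    return "\n".join(lines) + "\n"
```

For example, `make_classes_yaml(["car", "person"])` produces a `names` section with `0: person` and `2: car`. Write the returned text to `classes.yaml` before running the extraction script.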
Next comes the extraction script:
```python
# create_sub_coco_dataset.py
import json
import os
import shutil
import sys

import tqdm
import yaml

# COCO category ids (1-90, with gaps) for YOLO class indices 0-79.
# A plain "index + 1" only matches for indices 0-10 (person ... fire hydrant).
COCO80_TO_91 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
                22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
                43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
                62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
                85, 86, 87, 88, 89, 90]


def make_dirs(path):
    # exist_ok must be passed by keyword: the second positional argument is `mode`
    os.makedirs(path, exist_ok=True)


def create_sub_coco_dataset(data_yaml="classes.yaml",
                            src_root_dir="../datasets/coco",
                            dst_root_dir="classes/coco",
                            folder="val2017"):
    make_dirs(dst_root_dir + "/annotations/")
    make_dirs(dst_root_dir + "/images/" + folder)
    make_dirs(dst_root_dir + "/labels/" + folder)

    with open(data_yaml) as f:
        data = yaml.safe_load(f)
    print(data["names"])
    keep_yolo_ids = set(data["names"].keys())                 # indices used in the label .txt files
    keep_coco_ids = {COCO80_TO_91[i] for i in keep_yolo_ids}  # category ids used in the COCO JSON

    # Filter the COCO JSON annotations and categories down to the kept classes
    ann_path = "/annotations/instances_{}.json".format(folder)
    with open(src_root_dir + ann_path) as f:
        all_annotations = json.load(f)
    keep_annotations = [x for x in all_annotations["annotations"] if x["category_id"] in keep_coco_ids]
    all_annotations["annotations"] = keep_annotations
    all_annotations["categories"] = [x for x in all_annotations["categories"] if x["id"] in keep_coco_ids]
    if not os.path.exists(dst_root_dir + ann_path):
        with open(dst_root_dir + ann_path, "w") as f:
            json.dump(all_annotations, f)

    # Visit each image containing a kept class once: copy it and keep only the kept label lines
    image_ids = sorted({x["image_id"] for x in keep_annotations})
    filelist = []
    for image_id in tqdm.tqdm(image_ids):
        img_path = "/images/{}/{:012d}.jpg".format(folder, image_id)
        label_path = "/labels/{}/{:012d}.txt".format(folder, image_id)
        if not os.path.exists(src_root_dir + label_path):
            continue  # skip images that have no YOLO label file in the source instead of crashing
        if not os.path.exists(dst_root_dir + img_path):
            shutil.copy(src_root_dir + img_path, dst_root_dir + img_path)
        if not os.path.exists(dst_root_dir + label_path):
            with open(src_root_dir + label_path) as f:
                keep_records = [x for x in f if x.strip() and int(x.split()[0]) in keep_yolo_ids]
            with open(dst_root_dir + label_path, "w") as f:
                f.writelines(keep_records)
        filelist.append("./images/{}/{:012d}.jpg\n".format(folder, image_id))

    # Image list for this split, e.g. val2017.txt (paths relative to dst_root_dir)
    with open(dst_root_dir + "/{}.txt".format(folder), "w") as f:
        f.writelines(filelist)

    # Copy of the class yaml with `path` pointing at the new dataset root
    data["path"] = dst_root_dir
    with open(dst_root_dir + "/coco.yaml", "w") as f:
        f.write(yaml.dump(data, allow_unicode=True))


if __name__ == "__main__":
    # Arguments: <class yaml> <source COCO root> <destination root> <train2017 | val2017>
    create_sub_coco_dataset(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])

# Usage examples
# python create_sub_coco_dataset.py car.yaml ../../Datasets/coco_extract/coco2017/coco car train2017
# python create_sub_coco_dataset.py person.yaml ../../Datasets/coco_extract/coco2017/coco person val2017
```
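The usage examples show one run that extracts into the train2017 folder and one that extracts into val2017; adjust the paths to your own setup. To get a complete dataset, run the script once for train2017 and once for val2017 with the same yaml and destination directory. This matters because the generated coco.yaml refers to both train2017.txt and val2017.txt.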
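Before training, a quick check of each extracted split can catch missing files. Below is a minimal sketch. It assumes the layout the script writes: `<folder>.txt` lists `./images/<folder>/<id>.jpg`, and each image has a matching `labels/<folder>/<id>.txt`. The function name `check_split` is hypothetical.

```python
import os


def check_split(root, folder, keep_ids):
    """Return (number of listed images, list of problems) for one extracted split."""
    problems = []
    with open(os.path.join(root, folder + ".txt")) as f:
        images = [line.strip() for line in f if line.strip()]
    for rel in images:
        img_path = os.path.join(root, rel)
        stem = os.path.splitext(os.path.basename(rel))[0]
        label_path = os.path.join(root, "labels", folder, stem + ".txt")
        if not os.path.isfile(img_path):
            problems.append("missing image: " + rel)
        if not os.path.isfile(label_path):
            problems.append("missing label: " + label_path)
            continue
        # Every class index left in the label file should be one of the kept YOLO indices
        with open(label_path) as f:
            for line in f:
                if line.strip() and int(line.split()[0]) not in keep_ids:
                    problems.append("unexpected class in " + label_path)
                    break
    return len(images), problems
```

For the car example above, this would be called as `check_split("car", "train2017", {2})`. An empty problem list means every listed image has its image file and a label file containing only the kept classes.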
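After extraction, the destination directory contains the following and can be used to train a YOLO model:
annotations/instances_<folder>.json: the COCO annotations filtered to the selected classes
images/<folder>/: copies of the images that contain the selected classes
labels/<folder>/: the matching YOLO label files, keeping only the selected classes
<folder>.txt: the list of selected images, one ./images/<folder>/<image_id>.jpg per line
coco.yaml: the class yaml with path set to the destination directory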
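One caveat applies to single-class yaml files such as the car.yaml in the examples. The script keeps the original indices in the label files, so car stays `2`. If your YOLO version derives the class count from the number of entries under `names`, a yaml with only `2: car` gives one class while the labels still say `2`. Recent YOLOv5 versions appear to behave this way, but please verify against your version. Below is a minimal post-processing sketch that renumbers the labels to `0..n-1` in key order and returns the matching `names` dict. The function name `remap_labels` is hypothetical, and so is its choice to rewrite files in place.

```python
import glob
import os


def remap_labels(label_dir, names):
    """Renumber YOLO class indices in every .txt under label_dir to 0..n-1, in place.

    names: the `names` dict from the class yaml, e.g. {2: "car"}.
    Returns the new names dict, e.g. {0: "car"}.
    """
    old_to_new = {old: new for new, old in enumerate(sorted(names))}
    for path in glob.glob(os.path.join(label_dir, "*.txt")):
        with open(path) as f:
            rows = [line.split() for line in f if line.strip()]
        with open(path, "w") as f:
            for row in rows:
                # Replace only the class column; keep the box coordinates as written
                f.write(" ".join([str(old_to_new[int(row[0])])] + row[1:]) + "\n")
    return {new: names[old] for old, new in old_to_new.items()}
```

Run it on both labels/train2017 and labels/val2017, then write the returned dict back into coco.yaml under `names`. Run it only once per folder: a second pass would find indices that are no longer keys of the original `names`.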
That's all for today's post. See you in the next one.