目录
一、如何生成类似pascal voc一样结构的文件(split_data.py)
二、如何创建属于自己的数据集(my_dataset.py)
2.1 代码
2.2 代码解释
2.2.1 初始化函数__init__
2.2.2 parse_xml_to_dict函数(解析xml文件)
2.2.3 __getitem__方法:传入索引值,返回索引值对应的图片信息
2.2.4 get_height_and_width方法
三、图片处理类(transform.py)
四、测试本节代码
一、如何生成类似pascal voc一样结构的文件(split_data.py)
如图,如何生成像上图一样的train、test文件。
我们来看代码:
import os import random """ 作用:如何生成自己数据集的目录,将数据集分为训练集和验证集 """ def main(): random.seed(0) # 设置随机种子,保证随机结果可复现 #标注的xml的根目录 files_path = "./VOCdevkit/VOC2012/Annotations" assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path) val_rate = 0.5 #os.listdir(files_path)可以遍历整个目录下的文件,我们可以得到类似annotation一样 #文件的名称 2007_000027.xml 2007_000028.xml ... #通过.这个字符进行分割 分割之后就分割为 2007_000027 xml 取第0维度 即图片的名称 files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)]) files_num = len(files_name) #随机采样 范围是0-file_num 采样个数为files_num*val_rate 即二分之一的数据集 val_index = random.sample(range(0, files_num), k=int(files_num*val_rate)) train_files = [] val_files = [] #对于每一个在files_name中的文件名称进行遍历 for index, file_name in enumerate(files_name): if index in val_index: val_files.append(file_name) else: train_files.append(file_name) try: train_f = open("train.txt", "x") eval_f = open("val.txt", "x") train_f.write("\n".join(train_files)) eval_f.write("\n".join(val_files)) except FileExistsError as e: print(e) exit(1) if __name__ == '__main__': main()
我们首先给定文件路径files_path,这个目录是标注的xml文件的目录,标注的xml文件如下:
val_rate指的是验证集的比率,这里设置为50%。
files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
这行代码的含义是将files_path中的所有文件遍历,即一个个xml文件,将这些xml文件用.隔开,以2007_000032.xml为例,我们将其分割成2007_000032 和 xml两部分,取第一部分(索引为0)的部分,即2007_000032,我们将这个文件夹中的所有的xml的文件名抽取出来并排序,将文件名放入files_name这个变量中。
val_index = random.sample(range(0, files_num), k=int(files_num*val_rate))
用这行代码进行随机采样,采样的范围是(0-files_name的数量),采样个数为(files_name的数量*0.5即一半),里面存放的是图片的索引。
最后我们遍历file_name,若其索引在val_index中则将图片名称放入val_files列表,否则放入train_files列表中。
>>> seq = ['one', 'two', 'three'] >>> for i, element in enumerate(seq): ... print i, element ... 0 one 1 two 2 three
随后创建文件train.txt,val.txt。都是如下格式:
二、如何创建属于自己的数据集(my_dataset.py)
2.1 代码
import numpy as np from torch.utils.data import Dataset import os import torch import json from PIL import Image from lxml import etree class VOCDataSet(Dataset): """读取解析PASCAL VOC2007/2012数据集""" #@voc_root 训练集所在的根目录 def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"): #断言 year必须在2007和2012 否则就报错 assert year in ["2007", "2012"], "year must be in ['2007', '2012']" # 增加容错能力 定义root目录 #root目录 VOCdevkit/VOC2012/(Annotations、ImageSets、JEPGImages、..) #self.root = VOCdevkit/VOC2012/ if "VOCdevkit" in voc_root: self.root = os.path.join(voc_root, f"VOC{year}") else: self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}") #self.img_root = VOCdevkit/VOC2012/JPEGImages self.img_root = os.path.join(self.root, "JPEGImages") #self.annotations_root = VOCdevkit/VOC2012/Annotations self.annotations_root = os.path.join(self.root, "Annotations") # 图片的索引目录,我们上一步 #txt_path = VOCdevkit/VOC2012/ImageSets/Main/train.txt txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name) #如果txt_path为空(这个目录下面没有东西),则抛出异常not found {txt_name} file assert os.path.exists(txt_path), "not found {} file.".format(txt_name) #按行读取文件 读取VOCdevkit/VOC2012/Annotations/去掉换行符图片名称.xml文件 #将所有的xml文件名称传入到xml_list中 with open(txt_path) as read: xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml") for line in read.readlines() if len(line.strip()) > 0] #每行读取 self.xml_list = [] # 遍历xml_list for xml_path in xml_list: #如果没能找到该信息报错 if os.path.exists(xml_path) is False: print(f"Warning: not found '{xml_path}', skip this annotation file.") continue # check for targets #该方法是将xml格式转化为Element 对象,Element 对象代表 XML 文档中的一个元素。元素可以包含属性、其他元素或文本。 with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] if "object" not in data: print(f"INFO: no objects in {xml_path}, skip this annotation file.") continue self.xml_list.append(xml_path) assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path) # read 分类器 放入self.class_dict中 json_file = './pascal_voc_classes.json' assert os.path.exists(json_file), "{} file not exist.".format(json_file) with open(json_file, 'r') as f: self.class_dict = json.load(f) #读取图像变换 self.transforms = transforms #返回数据集文件的个数 def __len__(self): return len(self.xml_list) #传入索引值,返回索引值对应的图片信息 def __getitem__(self, idx): # 获取xml文件 xml_path = self.xml_list[idx] with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] #data是一个字典,里面包含了xml信息 有一条是<filename>2007_000063.jpg</filename> #于是 self.img_root, data["filename"] = VOCdevkit/VOC2012/JPEGImages/2007_000063.jpg img_path = os.path.join(self.img_root, data["filename"]) image = Image.open(img_path) if image.format != "JPEG": raise ValueError("Image '{}' format not JPEG".format(img_path)) boxes = [] labels = [] iscrowd = [] """ <object> <name>dog</name> <pose>Unspecified</pose> <truncated>o</truncated> <difficult>0</difficult> <bndbox> <xmin>123</xmin> <ymin>115</ymin> <xmax>379</xmax> <ymax>275</ymax> </bndbox> </object> """ assert "object" in data, "{} lack of object information.".format(xml_path) #遍历object中的每一个信息 可能是桌子/狗/猫.... for obj in data["object"]: xmin = float(obj["bndbox"]["xmin"]) xmax = float(obj["bndbox"]["xmax"]) ymin = float(obj["bndbox"]["ymin"]) ymax = float(obj["bndbox"]["ymax"]) # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan if xmax <= xmin or ymax <= ymin: print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path)) continue boxes.append([xmin, ymin, xmax, ymax]) labels.append(self.class_dict[obj["name"]]) if "difficult" in obj: iscrowd.append(int(obj["difficult"])) else: iscrowd.append(0) # convert everything into a torch.Tensor boxes = torch.as_tensor(boxes, dtype=torch.float32) labels = torch.as_tensor(labels, dtype=torch.int64) iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64) image_id = torch.tensor([idx]) area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) target = {} target["boxes"] = boxes target["labels"] = labels target["image_id"] = image_id target["area"] = area target["iscrowd"] = iscrowd if self.transforms is not None: image, target = self.transforms(image, target) return image, target def get_height_and_width(self, idx): # read xml xml_path = self.xml_list[idx] with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] data_height = int(data["size"]["height"]) data_width = int(data["size"]["width"]) return data_height, data_width #xml文件里面包含很多信息,包括标注框体 def parse_xml_to_dict(self, xml): """ 将xml文件解析成字典形式 Args: xml: xml tree obtained by parsing XML file contents using lxml.etree Returns: Python dictionary holding XML contents. """ if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息 return {xml.tag: xml.text} result = {} for child in xml: #遍历annotations下面的子目录,递归调用 child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息 #子目录的tag名称是否是object if child.tag != 'object': #在result字典中存入 键为child.tag 值为child.tag中的值 result[child.tag] = child_result[child.tag] else: if child.tag not in result: # 因为object可能有多个,所以需要放入列表里 result[child.tag] = [] """ <object> <name>dog</name> <pose>Unspecified</pose> <truncated>0</truncated> <diffcult>0</difficult> <bndbox> <xmin>123</xmin> <ymin>115</ymin> <xmax>379</xmax> <ymax>275</ymax> </bndbox> </object> """ result[child.tag].append(child_result[child.tag]) return {xml.tag: result} def coco_index(self, idx): """ 该方法是专门为pycocotools统计标签信息准备,不对图像和标签作任何处理 由于不用去读取图片,可大幅缩减统计时间 Args: idx: 输入需要获取图像的索引 """ # read xml xml_path = self.xml_list[idx] with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] data_height = int(data["size"]["height"]) data_width = int(data["size"]["width"]) # img_path = os.path.join(self.img_root, data["filename"]) # image = Image.open(img_path) # if image.format != "JPEG": # raise ValueError("Image format not JPEG") boxes = [] labels = [] iscrowd = [] for obj in data["object"]: xmin = float(obj["bndbox"]["xmin"]) xmax = float(obj["bndbox"]["xmax"]) ymin = float(obj["bndbox"]["ymin"]) ymax = float(obj["bndbox"]["ymax"]) boxes.append([xmin, ymin, xmax, ymax]) labels.append(self.class_dict[obj["name"]]) iscrowd.append(int(obj["difficult"])) # convert everything into a torch.Tensor boxes = torch.as_tensor(boxes, dtype=torch.float32) labels = torch.as_tensor(labels, dtype=torch.int64) iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64) image_id = torch.tensor([idx]) area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) target = {} target["boxes"] = boxes target["labels"] = labels target["image_id"] = image_id target["area"] = area target["iscrowd"] = iscrowd return (data_height, data_width), target @staticmethod def collate_fn(batch): return tuple(zip(*batch)) # import transforms # from draw_box_utils import draw_objs # from PIL import Image # import json # import matplotlib.pyplot as plt # import torchvision.transforms as ts # import random # # # read class_indict # category_index = {} # try: # json_file = open('./pascal_voc_classes.json', 'r') # class_dict = json.load(json_file) # category_index = {str(v): str(k) for k, v in class_dict.items()} # except Exception as e: # print(e) # exit(-1) # # data_transform = { # "train": transforms.Compose([transforms.ToTensor(), # transforms.RandomHorizontalFlip(0.5)]), # "val": transforms.Compose([transforms.ToTensor()]) # } # # # load train data set # train_data_set = VOCDataSet(os.getcwd(), "2012", data_transform["train"], "train.txt") # print(len(train_data_set)) # for index in random.sample(range(0, len(train_data_set)), k=5): # img, target = train_data_set[index] # img = ts.ToPILImage()(img) # plot_img = draw_objs(img, # target["boxes"].numpy(), # target["labels"].numpy(), # np.ones(target["labels"].shape[0]), # category_index=category_index, # box_thresh=0.5, # line_thickness=3, # font='arial.ttf', # font_size=20) # plt.imshow(plot_img) # plt.show()
2.2 代码解释
2.2.1 初始化函数__init__
我们传入了训练集所在的根目录voc_root(VOCdevkit的目录)、transform(图像预处理方法)、txt_name: str = "train.txt"。(刚才生成的)
初始化了几个类内变量:根目录、图像根目录、标注根目录、图片索引根目录
self.root:VOCdevkit/VOC2012/
self.img_root = VOCdevkit/VOC2012/JPEGImages
self.annotations_root = VOCdevkit/VOC2012/Annotations
txt_path = VOCdevkit/VOC2012/ImageSets/Main/train.txt
with open(txt_path) as read: xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml") for line in read.readlines() if len(line.strip()) > 0] #每行读取
读取VOCdevkit/VOC2012/ImageSets/Main/train.txt的每一行(即一张图片名称),并做一个拼接VOCdevkit/VOC2012/Annotations/2008_xxxxxx.xml。得到了每一个图片文件所对应的xml文件。存储到xml_list 中。
遍历xml_list :
for xml_path in xml_list: #如果没能找到该信息报错 if os.path.exists(xml_path) is False: print(f"Warning: not found '{xml_path}', skip this annotation file.") continue # check for targets #该方法是将xml格式转化为Element 对象,Element 对象代表 XML 文档中的一个元素。元素可以包含属性、其他元素或文本。 with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] if "object" not in data: print(f"INFO: no objects in {xml_path}, skip this annotation file.") continue self.xml_list.append(xml_path)
如果其中一个没有xml文件则报错,如果有该条目则将xml格式转化为Element 对象,Element 对象代表 XML 文档中的一个元素。元素可以包含属性、其他元素或文本。
最终得到标注信息存放进self.xml_list中。
最后读取分类器索引放入self.class_dict中。
{ "aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4, "bottle": 5, "bus": 6, "car": 7, "cat": 8, "chair": 9, "cow": 10, "diningtable": 11, "dog": 12, "horse": 13, "motorbike": 14, "person": 15, "pottedplant": 16, "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20 }
将图像变换放入self.transforms中。
因此,初始化函数中,我们初始化了如下类内变量:
self.root、self.img_root、self.annotations_root、txt_path、self.xml_list、self.class_dict、self.transforms。
2.2.2 parse_xml_to_dict函数(解析xml文件)
#xml文件里面包含很多信息,包括标注框体 def parse_xml_to_dict(self, xml): """ 将xml文件解析成字典形式 Args: xml: xml tree obtained by parsing XML file contents using lxml.etree Returns: Python dictionary holding XML contents. """ if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息 return {xml.tag: xml.text} result = {} for child in xml: #遍历annotations下面的子目录,递归调用 child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息 #子目录的tag名称是否是object if child.tag != 'object': #在result字典中存入 键为child.tag 值为child.tag中的值 result[child.tag] = child_result[child.tag] else: if child.tag not in result: # 因为object可能有多个,所以需要放入列表里 result[child.tag] = [] """ <object> <name>dog</name> <pose>Unspecified</pose> <truncated>0</truncated> <diffcult>0</difficult> <bndbox> <xmin>123</xmin> <ymin>115</ymin> <xmax>379</xmax> <ymax>275</ymax> </bndbox> </object> """ result[child.tag].append(child_result[child.tag]) return {xml.tag: result}
我们看XML文件的格式:
传进来的时候,我们先判断顶层annotations还有没有子目录(source、size...),返回他们的数量,若是底层则返回底层的信息(xmin....)。
若下层还有东西,则定义一个result字典,然后我们遍历我们的xml文件,存储xml文件。
2.2.3 __getitem__方法:传入索引值,返回索引值对应的图片信息
def __getitem__(self, idx): # 获取xml文件 xml_path = self.xml_list[idx] with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] #data是一个字典,里面包含了xml信息 有一条是<filename>2007_000063.jpg</filename> #于是 self.img_root, data["filename"] = VOCdevkit/VOC2012/JPEGImages/2007_000063.jpg img_path = os.path.join(self.img_root, data["filename"]) image = Image.open(img_path) if image.format != "JPEG": raise ValueError("Image '{}' format not JPEG".format(img_path)) boxes = [] labels = [] iscrowd = [] """ <object> <name>dog</name> <pose>Unspecified</pose> <truncated>o</truncated> <difficult>0</difficult> <bndbox> <xmin>123</xmin> <ymin>115</ymin> <xmax>379</xmax> <ymax>275</ymax> </bndbox> </object> """ assert "object" in data, "{} lack of object information.".format(xml_path) #遍历object中的每一个信息 可能是桌子/狗/猫.... for obj in data["object"]: xmin = float(obj["bndbox"]["xmin"]) xmax = float(obj["bndbox"]["xmax"]) ymin = float(obj["bndbox"]["ymin"]) ymax = float(obj["bndbox"]["ymax"]) # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan if xmax <= xmin or ymax <= ymin: print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path)) continue boxes.append([xmin, ymin, xmax, ymax]) labels.append(self.class_dict[obj["name"]]) if "difficult" in obj: iscrowd.append(int(obj["difficult"])) else: iscrowd.append(0) # convert everything into a torch.Tensor boxes = torch.as_tensor(boxes, dtype=torch.float32) labels = torch.as_tensor(labels, dtype=torch.int64) iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64) image_id = torch.tensor([idx]) area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) target = {} target["boxes"] = boxes target["labels"] = labels target["image_id"] = image_id target["area"] = area target["iscrowd"] = iscrowd if self.transforms is not None: image, target = self.transforms(image, target) return image, target
首先获取对应idx的xml文件的信息,用data读取xml文件的信息(data是以字典进行存储的)。
data['filename"]存储的是图片的名称,img_path存储的是图片的绝对路径,用img读取图片信息。
用assert声明判断data字典里是否有object字段:
对每一个object字段,获取其bndbox信息(obj["bndbox"]["xmin"])、标签信息对应的索引值信息(self.class_dict[obj["name"]])、是否为难分辨样本信息。
最后将这些信息封装进target字典中,target字典中包含如下项目:
target = {} 声明target字典
target["boxes"] = boxes 框体信息(可能有多个)
target["labels"] = labels 标签信息(可能有多个)
target["image_id"] = image_id 照片索引值信息
target["area"] = area 框体面积信息(可能有多个)
target["iscrowd"] = iscrowd 难分辨信息(可能有多个)image 图片信息
2.2.4 get_height_and_width方法
def get_height_and_width(self, idx): # read xml xml_path = self.xml_list[idx] with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] data_height = int(data["size"]["height"]) data_width = int(data["size"]["width"]) return data_height, data_width
同理,我们解析xml文件,获取xml文件的data字典信息。
三、图片处理类(transform.py)
import random from torchvision.transforms import functional as F class Compose(object): """组合多个transform函数""" def __init__(self, transforms): self.transforms = transforms def __call__(self, image, target): for t in self.transforms: image, target = t(image, target) return image, target class ToTensor(object): """将PIL图像转为Tensor""" def __call__(self, image, target): image = F.to_tensor(image) return image, target #target的信息 # target = {} # target["boxes"] = boxes # target["labels"] = labels # target["image_id"] = image_id # target["area"] = area # target["iscrowd"] = iscrowd class RandomHorizontalFlip(object): """随机水平翻转图像以及bboxes""" def __init__(self, prob=0.5): self.prob = prob def __call__(self, image, target): if random.random() < self.prob: height, width = image.shape[-2:] image = image.flip(-1) # 水平翻转图片 bbox = target["boxes"] # bbox: xmin, ymin, xmax, ymax bbox[:, [0, 2]] = width - bbox[:, [2, 0]] # 翻转对应bbox坐标信息 target["boxes"] = bbox return image, target
ToTensor就是利用pytorch官方的方法将图片转化成一个tensor格式。
随机水平翻转在调用时我们会传入image和target。然后生成一个随机数,如果这个随机数小于0.5(prob)时我们对图像进行一个随机翻转。
翻转之后如右图所示,水平翻转后y值是不会变化的,变化的只有x值。
第一维度是有多少个bndbox不进行改动。翻转后的长度如上。
xmin ([0]) = width - xmax ([2]) xmax ([2])= width - xmin([0])
因此bndbox信息也更改了....返回新的image和target信息。
四、测试本节代码
import transforms from draw_box_utils import draw_objs from PIL import Image import json import matplotlib.pyplot as plt import torchvision.transforms as ts import random #read class_indict category_index = {} try: json_file = open('./pascal_voc_classes.json', 'r') class_dict = json.load(json_file) category_index = {str(v): str(k) for k, v in class_dict.items()} except Exception as e: print(e) exit(-1) # data_transform = { "train": transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(0.5)]), "val": transforms.Compose([transforms.ToTensor()]) } # # load train data set train_data_set = VOCDataSet(os.getcwd(), "2012", data_transform["train"], "train.txt") print(len(train_data_set)) for index in random.sample(range(0, len(train_data_set)), k=5): img, target = train_data_set[index] img = ts.ToPILImage()(img) plot_img = draw_objs(img, target["boxes"].numpy(), target["labels"].numpy(), np.ones(target["labels"].shape[0]), category_index=category_index, box_thresh=0.5, line_thickness=3, font='arial.ttf', font_size=20) plt.imshow(plot_img) plt.show()
我们先创建自己训练集的dataset,取出五张图片的img与target展示,如下: