Faster-RCNN代码解读3:制作自己的数据加载器
前言
因为最近打算尝试一下Faster-RCNN的复现,不要多想,我还没有厉害到可以一个人复现所有代码。所以,是参考别人的代码,进行自己的解读。
代码来自于B站的UP主(大佬666),其把代码都放到了GitHub上了,我把链接都放到下面了(应该不算侵权吧,毕竟代码都开源了_):
b站链接:https://www.bilibili.com/video/BV1of4y1m7nj/?vd_source=afeab8b555e5eb1bfa1e7f267262cbf2
GitHub链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
目的
其实UP主已经做了很好的视频讲解了他的代码,只是有时候我还是喜欢阅读博客来学习,另外视频很长,6个小时,我看的时候容易睡着_,所以才打算写博客记录一下学习笔记。
目前完成的内容
第一篇:VOC数据集详细介绍
第二篇:Faster-RCNN代码解读2:快速上手使用
第三篇:Faster-RCNN代码解读3:制作自己的数据加载器(本文)
目录结构
文章目录
- Faster-RCNN代码解读3:制作自己的数据加载器
- 1. 前言:
- 2. my_dataset.py文件解读:
- 2.1 init方法:
- 2.2 len方法:
- 2.3 getitem方法:
- 2.4 辅助方法:get_height_and_width
- 2.5 辅助方法:parse_xml_to_dict
- 2.6 辅助方法:coco_index
- 3. 总结:
1. 前言:
其实这个部分还是比较简单的(如果你看过我前面的图像分类加载器实现或者自己实现过),就是定义一个dataset
类。
2. my_dataset.py文件解读:
我们知道,想要定义自己的dataset类,首先需要继承于torch的Dataset类,并且至少需要定义三个方法,即__init__
、__len__
和__getitem__
。
那么,可以写出大体框架:
class VOCDataSet(Dataset):
"""读取解析PASCAL VOC2007/2012数据集"""
def __init__(self):
pass
def __len__(self):
pass
def __getitem__(self, idx):
pass
好的,下面我们来一一实现。
2.1 init方法:
首先,需要定义我们的输入参数,这里如果是自己从头实现的话,估计需要想到什么参数用参数。但是,我们解读的话,就直接看作者定义了哪些参数:
- voc_root: 数据集所在的根目录
- year: 指定读取2007还是2012的数据集,默认为2012
- transforms: 预处理方法,默认为None
- txt_name: 指定加载训练集还是测试集,默认为训练集,即train.txt
接下来,第一步,增加一下代码的容错能力,就是判断一下传入的参数正不正确,并拼接出需要的路径:
# 判断是不是2007或2012,否则报错
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# 增加容错能力
if "VOCdevkit" in voc_root:
# 如果传入的参数为:.\VOCdevkit,那么直接拼接为.\VOCdevkit\VOC2012
self.root = os.path.join(voc_root, f"VOC{year}")
else:
# 如果传入的参数为:. ,那么直接拼接为.\VOCdevkit\VOC2012
self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
# 拼接路径,即图片路径和注释路径
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
第二步,读取数据集.\VOCdevkit\VOC2012\ImageSets\Main
里面的训练集或测试集txt文件(如果你不知道这里面为什么的话,可以看第一篇文章,VOC数据集介绍),并将里面的值和后缀xml
拼接为训练集或测试集的注释文件:
# 读取train或者val文件
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
# 然后,将文件名(2007_000027)和后缀拼接在一起,这样才是真实的文件
with open(txt_path) as read:
xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines() if len(line.strip()) > 0]
第三步,需要一一读取xml文件,并将里面的内容转为字典值,主要目的是检查一下xml文件是否有问题:
# 定义真正的xml列表
self.xml_list = []
# 检测所有xml文件是否存在并读取内容
for xml_path in xml_list:
if os.path.exists(xml_path) is False:
print(f"Warning: not found '{xml_path}', skip this annotation file.")
continue
# 如果xml文件存在,继续下面的代码
# check for targets
# 读取xml文件
with open(xml_path) as fid:
xml_str = fid.read()
# 构建xml对象
xml = etree.fromstring(xml_str)
# 获取节点的内容,并转为字典值
data = self.parse_xml_to_dict(xml)["annotation"] # 获取annotation节点下的所有内容
if "object" not in data: # 判断object节点是否存在,如果不存在说明xml文件其实有问题,所以需要跳过
print(f"INFO: no objects in {xml_path}, skip this annotation file.")
continue
# 添加
self.xml_list.append(xml_path)
第四步,加载类别json文件,并读取里面的内容:
# 读取类别文件,一共20个类,从1开始是因为0留给背景
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
self.class_dict = json.load(f)
最后,将预处理函数放入一个变量中:
self.transforms = transforms
**总结一下:**经过上面的处理,我们得到了几个主要的变量:
- self.xml_list:里面的值为一个个训练集或测试集的xml文件,里面的值为文件路径值
- self.transforms:里面为我们的预处理方法
- self.class_dict:为我们的类别字典,里面的值为{‘preson’:2}这样的形式
给大家看看,debug下的值的内容:
2.2 len方法:
len方法,这个是最简单的方法,其作用就是返回长度值:
def __len__(self):
# len函数就是返回长度
return len(self.xml_list)
2.3 getitem方法:
这个方法和init方法一样十分重要,其作用就是获取图像和图像对应的标签等信息。
def __getitem__(self, idx):
pass
其中idx是这个方法必备的一个参数,其是随机返回一个索引值,来方便你取你之前在init方法定义的变量里的值。
那么,首先,获取一个xml文件,并打开它获取根节点里面的内容:
# 随机读取一个xml文件
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
# 创建xml对象
xml = etree.fromstring(xml_str)
# 获取根节点,转为字典值
data = self.parse_xml_to_dict(xml)["annotation"]
这里解释一下上面的data值为啥。其实就是xml文件annotation节点里的所有内容,如下图框出来的内容:
当然,同样用debug看看里面真实情况下的值:
然后,**我们知道xml文件名和图片名是对应的,**因此通过xml文件获取图片名字并打开这个图像:
# 获取xml文件对应的图像路径
img_path = os.path.join(self.img_root, data["filename"])
# 打开图像
image = Image.open(img_path)
# 判断图像是否为jpeg格式,主要作者防止别人插入了其它的文件
if image.format != "JPEG":
raise ValueError("Image '{}' format not JPEG".format(img_path))
接着,初始化一些变量:
# 初始化一些变量
boxes = [] # 边界框
labels = [] # 标签值
iscrowd = [] # 是否为难以识别的图像
下面开始是最重要的内容。
首先,迭代读取xml文件object
节点下的内容:
# 读取xml文件中object节点下的内容
for obj in data["object"]:
其中的,obj为下图中的值:
或者可以从xml文件中对应查看:
接着,获取对象的真实边界框的坐标值(左上角,右下角):(ps:下面的代码都是放在上面的for循环里面的)
# 获取bbox框的坐标
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
检测一下,边界框是否有问题:
# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
continue
然后,把坐标值加入boxes变量中,把标签加入labels变量中,并判断图像是否为难以识别的,然后加入iscrowd变量中:
boxes.append([xmin, ymin, xmax, ymax])
# 添加标签 obj["name"]=person, self.class_dict[obj["name"]] = 15
labels.append(self.class_dict[obj["name"]])
# 判断是否为difficult类型
if "difficult" in obj:
iscrowd.append(int(obj["difficult"]))
else:
iscrowd.append(0)
然后,把所有的变量类型都转为tensor格式(此时已经结束了循环):
# 将所有的类型转为tensor类型
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
接着,根据边框框的四个坐标,计算一下边界框的面积,主要方便后期计算IOU:
# boxes =[[,,,],[,,,],。。。。。。]
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# (ymax - ymin) * (xmax - xmin) ,即框的面积
最后,把上面的所有值放入一个字典变量中即可:
# 把这些东西放入一个字典中
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
然后,对图像进行预处理并返回图像和其对应的值即可:
# 变换,此时为自己实现的方法,不是官方的方法
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
最后,我们在debug下看看变量的值:
2.4 辅助方法:get_height_and_width
作用:获取图像的宽和高。
这个十分简单,就是通过xml文件来获取的,还不需要我们自己通过坐标计算:
def get_height_and_width(self, idx):
# 获取图像的宽和高
# 读取xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
# 构建xml对象
xml = etree.fromstring(xml_str)
# 获取根节点
data = self.parse_xml_to_dict(xml)["annotation"]
# 获取宽和高
data_height = int(data["size"]["height"])
data_width = int(data["size"]["width"])
return data_height, data_width
2.5 辅助方法:parse_xml_to_dict
主要作用:将xml格式的数据解析为字典格式,即将节点-----节点的值,转为{‘节点’:‘节点的值’}。
这个方法是通过递归来实现的,这个没什么好说的,如果你想搞清楚如何运行的,可以自己一步一步的推导:
def parse_xml_to_dict(self, xml):
"""
将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
"""
if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息
# xml.tag节点名字
# xml.text里面的值
return {xml.tag: xml.text}
result = {}
# 对于每个xml中的子节点
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
2.6 辅助方法:coco_index
这个方法与getitem方法是相同的作用,只是不读取图片,流程都是一样的,我就不细说了。
3. 总结:
my_dataset.py文件主要实现了数据加载器的类,实现思路很简单,但是代码量还是比较大的。
另外,作者在该文件的末尾展示了一下这个类的使用示例代码,大家可以直接把注释取消运行看看结果: