1.确认任务
经过mydataset文件处理后 - > 在train_res50_fpn文件内应用
# load train data set
# VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
train_dataset = VOCDataSet(VOC_root, "2012", data_transform["train"], "train.txt")
train_sampler = None
在经过mydataset处理后,框出各项位置。
2.原mydataset内容
主要要做的就是在每个xml文件内提取出 类别+类别所在区域(xmin xmax ymin ymax)
2.1 split_data.py分类出训练集 and 验证集
得到结果:
2.2 构造函数 def_init
索引每一个xml文件
xml_list = 每一个训练集中的xml文件集合
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# 增加容错能力
if "VOCdevkit" in voc_root:
self.root = os.path.join(voc_root, f"VOC{year}")
else:
self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
# read train.txt or val.txt file
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
with open(txt_path) as read:
xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines() if len(line.strip()) > 0]
self.xml_list = []
按行索引classes文件
class_dict 匹配类别对应的序号
# read class_indict
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
self.class_dict = json.load(f)
parse_xml_to_dict方法 把每一个xml文件检测到的类别提取出来<object>
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
提取每个类别的各项信息
位置信息放入boxes中
boxes = []
labels = []
iscrowd = []
assert "object" in data, "{} lack of object information.".format(xml_path)
for obj in data["object"]:
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
continue
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
if "difficult" in obj:
iscrowd.append(int(obj["difficult"]))
else:
iscrowd.append(0)
转为tensor
# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
提取图像的高宽
def get_height_and_width(self, idx):
# read xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
data_height = int(data["size"]["height"])
data_width = int(data["size"]["width"])
return data_height, data_width
被调用过的函数 --- 将xml解析为字典模式
def parse_xml_to_dict(self, xml):
"""
将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
Args:
xml: xml tree obtained by parsing XML file contents using lxml.etree
Returns:
Python dictionary holding XML contents.
"""
if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
3.关于json文件
如何在Python中优雅地处理JSON文件 - 知乎
JSON结构看起来和Python中的字典非常类似。需要注意的是,JSON格式通常是由key:<value> 结对组成,其中key是字符串形式,value是字符串、数字、布尔值、数组、对象或null。
为了更直观的进行说明,在下图中我们以蓝色突出显示了所有的key,同时以橙色突出显示了所有的value。
请注意,以下每组key/value间均使用逗号进行区分。
首先我们需要导入 json库, 接着我们使用open函数来读取JSON文件,最后利用json.load()函数将JSON字符串转化为Python字典形式.
4.提取出相应json文件的每个类别以及对应区域
注:中文的时候encoding=‘gbk’
import json
import torch
with open('test.json',encoding="gbk") as f:
json_dict = json.load(f)
#print(type(json_dict))
data = json_dict['shapes']
for data_ in data:
#print(data_)
#print(data[0]['label'])
#print(data[0]['points'])
label=data_['label']
xmin=float(data_['points'][0][0])
xmax=float(data_['points'][1][0])
ymin=float(data_['points'][0][1])
ymax=float(data_['points'][1][1])
print(label , xmin , xmax , ymin , ymax)
4.不支持png格式预测
使用了几乎一样的
jpg文件基本都没问题
png文件没有成功的。