Faster-RCNN代码解读3:制作自己的数据加载器

news2025/1/16 4:49:12

Faster-RCNN代码解读3:制作自己的数据加载器

前言

​ 因为最近打算尝试一下Faster-RCNN的复现,不要多想,我还没有厉害到可以一个人复现所有代码。所以,是参考别人的代码,进行自己的解读。

代码来自于B站的UP主(大佬666),其把代码都放到了GitHub上了,我把链接都放到下面了(应该不算侵权吧,毕竟代码都开源了_):

b站链接:https://www.bilibili.com/video/BV1of4y1m7nj/?vd_source=afeab8b555e5eb1bfa1e7f267262cbf2

GitHub链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

目的

​ 其实UP主已经做了很好的视频讲解了他的代码,只是有时候我还是喜欢阅读博客来学习,另外视频很长,6个小时,我看的时候容易睡着_,所以才打算写博客记录一下学习笔记。

目前完成的内容

第一篇:VOC数据集详细介绍

第二篇:Faster-RCNN代码解读2:快速上手使用

第三篇:Faster-RCNN代码解读3:制作自己的数据加载器(本文)

目录结构

文章目录

    • Faster-RCNN代码解读3:制作自己的数据加载器
      • 1. 前言:
      • 2. my_dataset.py文件解读:
        • 2.1 init方法:
        • 2.2 len方法:
        • 2.3 getitem方法:
        • 2.4 辅助方法:get_height_and_width
        • 2.5 辅助方法:parse_xml_to_dict
        • 2.6 辅助方法:coco_index
      • 3. 总结:

1. 前言:

​ 其实这个部分还是比较简单的(如果你看过我前面的图像分类加载器实现或者自己实现过),就是定义一个dataset类。

2. my_dataset.py文件解读:

​ 我们知道,想要定义自己的dataset类,首先需要继承于torch的Dataset类,并且至少需要定义三个方法,即__init____len____getitem__

​ 那么,可以写出大体框架:

class VOCDataSet(Dataset):
    """读取解析PASCAL VOC2007/2012数据集"""

    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
		pass

​ 好的,下面我们来一一实现。

2.1 init方法:

​ 首先,需要定义我们的输入参数,这里如果是自己从头实现的话,估计需要想到什么参数用参数。但是,我们解读的话,就直接看作者定义了哪些参数:

  • voc_root: 数据集所在的根目录
  • year: 指定读取2007还是2012的数据集,默认为2012
  • transforms: 预处理方法,默认为None
  • txt_name: 指定加载训练集还是测试集,默认为训练集,即train.txt

​ 接下来,第一步,增加一下代码的容错能力,就是判断一下传入的参数正不正确,并拼接出需要的路径:

# 判断是不是2007或2012,否则报错
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# 增加容错能力
if "VOCdevkit" in voc_root:
    # 如果传入的参数为:.\VOCdevkit,那么直接拼接为.\VOCdevkit\VOC2012
    self.root = os.path.join(voc_root, f"VOC{year}")
else:
    # 如果传入的参数为:. ,那么直接拼接为.\VOCdevkit\VOC2012
    self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
# 拼接路径,即图片路径和注释路径
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")

​ 第二步,读取数据集.\VOCdevkit\VOC2012\ImageSets\Main里面的训练集或测试集txt文件(如果你不知道这里面为什么的话,可以看第一篇文章,VOC数据集介绍),并将里面的值和后缀xml拼接为训练集或测试集的注释文件:

# 读取train或者val文件
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
# 然后,将文件名(2007_000027)和后缀拼接在一起,这样才是真实的文件
with open(txt_path) as read:
    xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

​ 第三步,需要一一读取xml文件,并将里面的内容转为字典值,主要目的是检查一下xml文件是否有问题:

# 定义真正的xml列表
self.xml_list = []
# 检测所有xml文件是否存在并读取内容
for xml_path in xml_list:
    if os.path.exists(xml_path) is False:
        print(f"Warning: not found '{xml_path}', skip this annotation file.")
        continue
    # 如果xml文件存在,继续下面的代码
    # check for targets
    # 读取xml文件
    with open(xml_path) as fid:
    	xml_str = fid.read()
    # 构建xml对象
    xml = etree.fromstring(xml_str)
    # 获取节点的内容,并转为字典值
    data = self.parse_xml_to_dict(xml)["annotation"] # 获取annotation节点下的所有内容
    if "object" not in data: # 判断object节点是否存在,如果不存在说明xml文件其实有问题,所以需要跳过
        print(f"INFO: no objects in {xml_path}, skip this annotation file.")
        continue
    # 添加
    self.xml_list.append(xml_path)

​ 第四步,加载类别json文件,并读取里面的内容:

# 读取类别文件,一共20个类,从1开始是因为0留给背景
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
	self.class_dict = json.load(f)

​ 最后,将预处理函数放入一个变量中:

self.transforms = transforms

​ **总结一下:**经过上面的处理,我们得到了几个主要的变量:

  • self.xml_list:里面的值为一个个训练集或测试集的xml文件,里面的值为文件路径值
  • self.transforms:里面为我们的预处理方法
  • self.class_dict:为我们的类别字典,里面的值为{‘preson’:2}这样的形式

​ 给大家看看,debug下的值的内容:

在这里插入图片描述

2.2 len方法:

​ len方法,这个是最简单的方法,其作用就是返回长度值:

def __len__(self):
    # len函数就是返回长度
    return len(self.xml_list)

2.3 getitem方法:

​ 这个方法和init方法一样十分重要,其作用就是获取图像和图像对应的标签等信息。

def __getitem__(self, idx):
	pass

​ 其中idx是这个方法必备的一个参数,其是随机返回一个索引值,来方便你取你之前在init方法定义的变量里的值。

​ 那么,首先,获取一个xml文件,并打开它获取根节点里面的内容:

# 随机读取一个xml文件
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
	xml_str = fid.read()
# 创建xml对象
xml = etree.fromstring(xml_str)
# 获取根节点,转为字典值
data = self.parse_xml_to_dict(xml)["annotation"]

​ 这里解释一下上面的data值为啥。其实就是xml文件annotation节点里的所有内容,如下图框出来的内容:

在这里插入图片描述

​ 当然,同样用debug看看里面真实情况下的值:

在这里插入图片描述

​ 然后,**我们知道xml文件名和图片名是对应的,**因此通过xml文件获取图片名字并打开这个图像:

# 获取xml文件对应的图像路径
img_path = os.path.join(self.img_root, data["filename"])
# 打开图像
image = Image.open(img_path)
# 判断图像是否为jpeg格式,主要作者防止别人插入了其它的文件
if image.format != "JPEG":
	raise ValueError("Image '{}' format not JPEG".format(img_path))

​ 接着,初始化一些变量:

# 初始化一些变量
boxes = []		# 边界框
labels = []		# 标签值
iscrowd = []	# 是否为难以识别的图像

下面开始是最重要的内容

​ 首先,迭代读取xml文件object节点下的内容:

# 读取xml文件中object节点下的内容
for obj in data["object"]:

​ 其中的,obj为下图中的值:

在这里插入图片描述

​ 或者可以从xml文件中对应查看:

在这里插入图片描述

​ 接着,获取对象的真实边界框的坐标值(左上角,右下角):(ps:下面的代码都是放在上面的for循环里面的)

# 获取bbox框的坐标
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])

​ 检测一下,边界框是否有问题:

# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
    print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
    continue

​ 然后,把坐标值加入boxes变量中,把标签加入labels变量中,并判断图像是否为难以识别的,然后加入iscrowd变量中:

boxes.append([xmin, ymin, xmax, ymax])
# 添加标签  obj["name"]=person,  self.class_dict[obj["name"]] = 15
labels.append(self.class_dict[obj["name"]])
# 判断是否为difficult类型
if "difficult" in obj:
    iscrowd.append(int(obj["difficult"]))
    else:
        iscrowd.append(0)

​ 然后,把所有的变量类型都转为tensor格式(此时已经结束了循环):

# 将所有的类型转为tensor类型
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])

​ 接着,根据边框框的四个坐标,计算一下边界框的面积,主要方便后期计算IOU:

#  boxes =[[,,,],[,,,],。。。。。。]
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# (ymax - ymin) * (xmax - xmin) ,即框的面积

​ 最后,把上面的所有值放入一个字典变量中即可:

# 把这些东西放入一个字典中
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

​ 然后,对图像进行预处理并返回图像和其对应的值即可:

# 变换,此时为自己实现的方法,不是官方的方法
if self.transforms is not None:
	image, target = self.transforms(image, target)
return image, target

​ 最后,我们在debug下看看变量的值:

在这里插入图片描述

2.4 辅助方法:get_height_and_width

​ 作用:获取图像的宽和高。

​ 这个十分简单,就是通过xml文件来获取的,还不需要我们自己通过坐标计算:

def get_height_and_width(self, idx):
    # 获取图像的宽和高
    # 读取xml
    xml_path = self.xml_list[idx]
    with open(xml_path) as fid:
		xml_str = fid.read()
    # 构建xml对象
    xml = etree.fromstring(xml_str)
    # 获取根节点
    data = self.parse_xml_to_dict(xml)["annotation"]
    # 获取宽和高
    data_height = int(data["size"]["height"])
    data_width = int(data["size"]["width"])
    return data_height, data_width

2.5 辅助方法:parse_xml_to_dict

​ 主要作用:将xml格式的数据解析为字典格式,即将节点-----节点的值,转为{‘节点’:‘节点的值’}。

​ 这个方法是通过递归来实现的,这个没什么好说的,如果你想搞清楚如何运行的,可以自己一步一步的推导:

def parse_xml_to_dict(self, xml):
    """
    将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
    """

    if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
        # xml.tag节点名字
        # xml.text里面的值
        return {xml.tag: xml.text}

    result = {}
    # 对于每个xml中的子节点
    for child in xml:
        child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
        if child.tag != 'object':
	        result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

2.6 辅助方法:coco_index

这个方法与getitem方法是相同的作用,只是不读取图片,流程都是一样的,我就不细说了。

3. 总结:

​ my_dataset.py文件主要实现了数据加载器的类,实现思路很简单,但是代码量还是比较大的。

​ 另外,作者在该文件的末尾展示了一下这个类的使用示例代码,大家可以直接把注释取消运行看看结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/415723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Node【三】Buffer 与 Stream

文章目录&#x1f31f;前言&#x1f31f;Buffer&#x1f31f; Buffer结构&#x1f31f; 什么时候用Buffer&#x1f31f; Buffer的转换&#x1f31f; Buffer使用&#x1f31f; 创建Buffer&#x1f31f; 字符串转Buffer&#x1f31f; Buffer转字符串&#x1f31f; 拼接Buffer&…

python 理解BN、LN、IN、GN归一化、分析torch.nn.LayerNorm()和torch.var()工作原理

目录 前言&#xff1a; 简言之BN、LN、IN、GN等归一化的区别&#xff1a; 批量归一化(Batch Normalization&#xff0c;BN) 优点 缺点 计算过程 层归一化(Layer Normalization&#xff0c;LN) 优点 计算过程 总结 分析torch.nn.LayerNorm()工作原理 分析torch.var(…

Vue2-黑马(十一)

目录&#xff1a; &#xff08;1&#xff09;vue2-联调准备 &#xff08;2&#xff09;vue2-登录实战-国际化 &#xff08;3&#xff09;vue2实战-登录-login-index.vue &#xff08;1&#xff09;vue2-联调准备 登录这个请求&#xff0c;并不是发给后台的&#xff0c;现在还…

浙大MBA提面申请材料的三六九等……

每年浙大MBA项目提前批面试申请的每个批次中都会有部分材料因为某些原因而被淘汰&#xff0c;无缘面试资格。考生们由最初的不理解到逐渐隐约的理解&#xff0c;行至今日也可以大体接受材料被刷这个结果&#xff0c;当然其中含有一部分面上资质背景还可以的考生&#xff0c;等到…

Faster-RCNN代码解读2:快速上手使用

Faster-RCNN代码解读2&#xff1a;快速上手使用 前言 ​ 因为最近打算尝试一下Faster-RCNN的复现&#xff0c;不要多想&#xff0c;我还没有厉害到可以一个人复现所有代码。所以&#xff0c;是参考别人的代码&#xff0c;进行自己的解读。 ​ 代码来自于B站的UP主&#xff08;…

中国电子学会2023年03月份青少年软件编程Scratch图形化等级考试试卷四级真题(含答案)

2023-03 Scratch四级真题 分数&#xff1a;100 题数&#xff1a;24 测试时长&#xff1a;90min 一、单选题(共10题&#xff0c;共30分) 1.编写一段程序&#xff0c;从26个英文字母中&#xff0c;随机选出10个加入列表a。空白处应填入的代码是&#xff1f;&#xff08;C&am…

Flink (十二) --------- Flink CEP

目录一、基本概念1. CEP 是什么2. 模式 (Pattern)3. 应用场景二、快速上手1. 需要引入的依赖2. 一个简单实例三、模式 API&#xff08;Pattern API&#xff09;1. 个体模式2. 组合模式3. 模式组4. 匹配后跳过策略四、模式的检测处理1. 将模式应用到流上2. 处理匹配事件3. 处理超…

【高项】项目整体管理、范围管理与进度管理(十大管理)

【高项】项目整体管理与范围管理 文章目录1、项目整体管理1.1 整体管理的过程1.2 制定项目章程&#xff08;启动&#xff09;1.3 制订项目管理计划&#xff08;规划&#xff09;1.4 指导与管理项目执行&#xff08;执行&#xff09;1.5 监控项目工作与实施整体变更控制&#xf…

Systemverilog中operators和expression的记录

1. Equality operators Equality operators有三种&#xff1a; Logical equality&#xff1a;, !&#xff0c;该运算符中如果运算数包含有x/z态&#xff0c;那么结果就是x态。只有在两边的bit都不包含x/z态&#xff0c;最终结果才会为0(False)或1(True)Case equality&#xf…

中云盾DDoS云防护系统

中云盾 DDoS 防护系统作为公司级网络安全产品&#xff0c;为各类业务提供专业可靠的 DDoS/CC 攻击防护。在黑客攻防对抗日益激烈的环境下&#xff0c; DDoS 对抗不仅需要 “降本” 还需要 “增效”。 为什么上云&#xff1f; 云原生作为近年来相当热门的概念&#xff0c;无论…

RHCE-NTP、SSH服务器

1.配置ntp时间服务器&#xff0c;确保客户端主机能和服务主机同步时间​ 服务器端&#xff1a; &#xff08;1&#xff09;首先安装chrony软件&#xff1a; dnf install -y chrony &#xff08;2&#xff09;配置时间同步源&#xff1a; 进入vim /etc/chrony.conf &#xf…

引用和指针

总结 引用&#xff1a; 因为引用是变量的别名&#xff0c;所以引用必须初始化 因为引用不存在自己的地址&#xff0c;所以指针不能指向引用&#xff0c;即不能定义引用的指针 因为引用不是对象&#xff0c;但是引用又要绑定一个对象&#xff0c;所以不能定义引用的引用 in…

一篇文章看懂C++三大特性——多态的定义和使用

目录 前文 一&#xff0c;什么是多态&#xff1f; 1.1 多态的概念 二&#xff0c; 多态的定义及实现 2.1 多态的构成条件 2.2 虚函数 2.3 虚函数的重写 2.3.1 虚函数重写的两个例外 2.4 C override 和 final 2.5 重载,重写(覆盖),隐藏(重定义)的区别 三&#xff0c;抽…

代码随想录刷题-双指针总结篇

文章目录双指针移除元素习题我的解法双指针优化反转字符串习题我的解法剑指 Offer 05. 替换空格习题我的解法正确解法反转字符串里的单词习题我的解法反转链表习题我的解法删除链表的倒数第 N 个节点习题我的解法相交链表习题我的解法环形链表 II习题我的解法三数之和习题我的解…

Unity VFX -- (3)创建环境粒子系统

粒子系统中最常用也最重要的一种使用场景是实现天气效果。只需要做很少修改&#xff0c;场景就能很快从蓝天白云变成雪花飘舞。 和之前看到的粒子系统从一个源头发出粒子的情况不同&#xff0c;天气效果完全围绕着场景。 新增和放置一个新的粒子系统 为了创建下雨或下雪的天气…

【从零开始学Skynet】基础篇(三):服务模块常用API

1、服务模块 Skynet提供了开启服务和发送消息的API&#xff0c;必须要先掌握它们。列出了Skynet中8个最重要的API&#xff0c;PingPong程序会用到它们。 Lua API说明newservice(name, ...) 启动一个名为 name 的新服务&#xff0c;并返回服务的地址。 start(func) …

【学习笔记】unity脚本学习(二)(Time时间体系、Random随机数、Mathf数学运算)

目录Time时间体系timeScalemaximumDeltaTimefixedDeltaTimecaptureDeltaTimedeltaTime整体展示Random随机数Mathf数学运算IMathf.Round()Mathf.Ceil() Mathf.CeilToInt()Mathf.SignMathf.ClampMathf数学运算II-曲线变换Lerp 线性插值LerpAngleSmoothDamp疑问&#xff1a;Smooth…

自己动手写编译器:DFA跳转表的压缩算法

在编译器开发体系中有两套框架&#xff0c;一个叫"lex && yacc", 另一个名气更大叫llvm&#xff0c;这两都是开发编译器的框架&#xff0c;我们只要设置好配置文件&#xff0c;那么他们就会生成相应的编译器代码&#xff0c;通常是c或者c代码&#xff0c;然后…

AI自动寻路AStar算法【图示讲解原理】

文章目录AI自动寻路AStar算法背景AStar算法原理AStar寻路步骤AStar具体寻路过程AStar代码实现运行结果AI自动寻路AStar算法 背景 AI自动寻路的算法可以分为以下几种&#xff1a; 1、A*算法&#xff1a;A*算法是一种启发式搜索算法&#xff0c;它利用启发函数&#xff08;heu…

Jmeter接口测试和性能测试

目前最新版本发展到5.0版本&#xff0c;需要Java7以上版本环境&#xff0c;下载解压目录后&#xff0c;进入\apache-jmeter-5.0\bin\&#xff0c;双击ApacheJMeter.jar文件启动JMemter。 1、创建测试任务 添加线程组&#xff0c;右击测试计划&#xff0c;在快捷菜单单击添加-…