Faster-RCNN代码解读2:快速上手使用

news2025/1/16 5:46:56

Faster-RCNN代码解读2:快速上手使用

前言

​ 因为最近打算尝试一下Faster-RCNN的复现,不要多想,我还没有厉害到可以一个人复现所有代码。所以,是参考别人的代码,进行自己的解读。

代码来自于B站的UP主(大佬666),其把代码都放到了GitHub上了,我把链接都放到下面了(应该不算侵权吧,毕竟代码都开源了_):

b站链接:https://www.bilibili.com/video/BV1of4y1m7nj/?vd_source=afeab8b555e5eb1bfa1e7f267262cbf2

GitHub链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

目的

​ 其实UP主已经做了很好的视频讲解了他的代码,只是有时候我还是喜欢阅读博客来学习,另外视频很长,6个小时,我看的时候容易睡着_,所以才打算写博客记录一下学习笔记。

目前完成的内容

第一篇:VOC数据集详细介绍

第二篇:Faster-RCNN代码解读2:快速上手使用(本文)

目录结构

文章目录

    • Faster-RCNN代码解读2:快速上手使用
      • 1. 前言:
      • 2. 下载项目代码:
      • 3. 下载数据和权重文件:
      • 4. predict.py文件解读:
      • 5. pascal_voc_classes.json文件介绍:
      • 6. 快速上手:
      • 7. 总结:

1. 前言:

​ 本篇文章的作用是准备好一些必备的数据或权重文件,以实现直接快速使用代码的目的。

2. 下载项目代码:

​ 打开大佬的GitHub链接,然后,进入pytorch_object_detection文件内:

在这里插入图片描述

​ 然后,把Faster-RCNN文件夹下载下来即可。不过,GitHub本身不支持单个文件夹的下载,这时候推荐一下浏览器的插件GitZip for github ,把这个插件安装后,即可下载单独的文件夹,如下图所示:

在这里插入图片描述

​ 下载完成后的目录结构如下:

在这里插入图片描述

3. 下载数据和权重文件:

​ 打开README.md文件,里面说明了预训练权重文件和数据集的下载地址:

  • ResNet50+FPN权重文件下载:
官方的权重文件:https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth

up主自己训练后的权重地址:
https://pan.baidu.com/s/1ifilndFRtAV5RDZINSHj5w 提取码:dsz8
  • 数据集下载地址:
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

​ 完成上述下载后,可以得到下图的文件:

在这里插入图片描述

4. predict.py文件解读:

​ 打开predict.py文件,这个文件的作用就是加载已经训练过的模型,对一张图片进行目标检测。

main函数:

​ 看main函数,主要分为四个部分:

  • 设置权重文件路径(需要我们改的参数),并用模型加载:
# 选定GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# 创建模型:21=20个类别+1个背景
model = create_model(num_classes=21)

# 加载权重参数
# weights_path = "./save_weights/model.pth"   # 权重保存路径,作者自己定义的
weights_path = "./fasterrcnn_voc2012.pth"   # 权重保存路径,我们下载后自己的路径
assert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)
# 开始加载权重文件
weights_dict = torch.load(weights_path, map_location='cpu')  # 加载之前训练保存的字典
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict # 选定model参数
model.load_state_dict(weights_dict) # 加载
model.to(device) # 放入GPU
  • 读取“类别----数字值”的json文件,并生成一个字典,以方便后期将预测的类别(比如:1、2这样的数字)转为字符串(比如:person、bicycle等)
# 读取json文件
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
with open(label_json_path, 'r') as f:
class_dict = json.load(f)
# 将值转为字典
category_index = {str(v): str(k) for k, v in class_dict.items()}
  • 加载一张测试图片(需要改为我们自己的),并改为符合[batch,channel,w,h]的格式
# 加载一张测试图片
original_img = Image.open("./test.jpg") # 需要改为自己的路径

# 将PIL图像格式转为tensor格式
data_transform = transforms.Compose([transforms.ToTensor()])
img = data_transform(original_img)
# 增加一个batch维度,符合训练图片格式
img = torch.unsqueeze(img, dim=0)
  • 开始预测图片,并计算运行的时间(一般不计算第一次的时间,因为GPU调用需要时间)和画出对应图像
model.eval()  # 进入验证模式
with torch.no_grad():
    # init
    # 初始化,原始图像的宽、高
    img_height, img_width = img.shape[-2:]
    # 将图像放入GPU中,并变为model可以识别的格式[batch_size,channel,w,h]
    init_img = torch.zeros((1, 3, img_height, img_width), device=device)
    # 验证
    model(init_img)

    # 计算预测时间,不过不能直接计算第一次,因为需要启动gpu等
    t_start = time_synchronized()
    predictions = model(img.to(device))[0]
    t_end = time_synchronized()
    print("inference+NMS time: {}".format(t_end - t_start))

    # 得到预测的相关参数
    predict_boxes = predictions["boxes"].to("cpu").numpy()
    predict_classes = predictions["labels"].to("cpu").numpy()
    predict_scores = predictions["scores"].to("cpu").numpy()

    if len(predict_boxes) == 0:
    	print("没有检测到任何目标!")

    # 绘制图像
    plot_img = draw_objs(original_img,
                             predict_boxes,
                             predict_classes,
                             predict_scores,
                             category_index=category_index,
                             box_thresh=0.5,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=20)
    plt.imshow(plot_img)
    plt.show()
    # 保存预测的图片结果
    plot_img.save("test_result.jpg")

create_model函数

​ 了解了main函数后,我们再看看create_model函数,这个函数的作用就是创建模型。作者在该项目中采取了很多模型,比如VGG16、mobilenetv2、resnet等等,而这里我们用的是刚刚下载的权重文件对应的模型,即resNet50+fpn+faster-rcnn,因此需要把其它的模型代码注释掉:

def create_model(num_classes):
    # mobileNetv2+faster_RCNN
    # backbone = MobileNetV2().features
    # backbone.out_channels = 1280
    #
    # anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
    #                                     aspect_ratios=((0.5, 1.0, 2.0),))
    #
    # roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
    #                                                 output_size=[7, 7],
    #                                                 sampling_ratio=2)
    #
    # model = FasterRCNN(backbone=backbone,
    #                    num_classes=num_classes,
    #                    rpn_anchor_generator=anchor_generator,
    #                    box_roi_pool=roi_pooler)

    # resNet50+fpn+faster_RCNN
    # 注意,这里的norm_layer要和训练脚本中保持一致
    backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d)
    model = FasterRCNN(backbone=backbone, num_classes=num_classes, rpn_score_thresh=0.5)

    return model

5. pascal_voc_classes.json文件介绍:

​ 我们再看看上面涉及json文件,这个文件就是voc数据集的类别和数字值的对应关系,比如:

{
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}

​ 需要注意的是,这里的值是从1开始的,是因为0一般是留给背景的。

6. 快速上手:

​ 有了上面的解读后,我们可以快速上手看看效果。

​ 这里再次声明一下predict.py文件需要修改**权重文件路径和自己搞一张测试图片并修改路径。**完成修改后,直接运行该文件即可,我测试了几张图片,结果如下图:

在这里插入图片描述
在这里插入图片描述

7. 总结:

​ 上面主要简单介绍了如何快速上手,看到结果,给自己一种这个很简单的错觉。后面,主要就是对一些主要的文件进行解读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/415718.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中国电子学会2023年03月份青少年软件编程Scratch图形化等级考试试卷四级真题(含答案)

2023-03 Scratch四级真题 分数:100 题数:24 测试时长:90min 一、单选题(共10题,共30分) 1.编写一段程序,从26个英文字母中,随机选出10个加入列表a。空白处应填入的代码是?(C&am…

Flink (十二) --------- Flink CEP

目录一、基本概念1. CEP 是什么2. 模式 (Pattern)3. 应用场景二、快速上手1. 需要引入的依赖2. 一个简单实例三、模式 API(Pattern API)1. 个体模式2. 组合模式3. 模式组4. 匹配后跳过策略四、模式的检测处理1. 将模式应用到流上2. 处理匹配事件3. 处理超…

【高项】项目整体管理、范围管理与进度管理(十大管理)

【高项】项目整体管理与范围管理 文章目录1、项目整体管理1.1 整体管理的过程1.2 制定项目章程(启动)1.3 制订项目管理计划(规划)1.4 指导与管理项目执行(执行)1.5 监控项目工作与实施整体变更控制&#xf…

Systemverilog中operators和expression的记录

1. Equality operators Equality operators有三种: Logical equality:, !,该运算符中如果运算数包含有x/z态,那么结果就是x态。只有在两边的bit都不包含x/z态,最终结果才会为0(False)或1(True)Case equality&#xf…

中云盾DDoS云防护系统

中云盾 DDoS 防护系统作为公司级网络安全产品,为各类业务提供专业可靠的 DDoS/CC 攻击防护。在黑客攻防对抗日益激烈的环境下, DDoS 对抗不仅需要 “降本” 还需要 “增效”。 为什么上云? 云原生作为近年来相当热门的概念,无论…

RHCE-NTP、SSH服务器

1.配置ntp时间服务器,确保客户端主机能和服务主机同步时间​ 服务器端: (1)首先安装chrony软件: dnf install -y chrony (2)配置时间同步源: 进入vim /etc/chrony.conf &#xf…

引用和指针

总结 引用: 因为引用是变量的别名,所以引用必须初始化 因为引用不存在自己的地址,所以指针不能指向引用,即不能定义引用的指针 因为引用不是对象,但是引用又要绑定一个对象,所以不能定义引用的引用 in…

一篇文章看懂C++三大特性——多态的定义和使用

目录 前文 一,什么是多态? 1.1 多态的概念 二, 多态的定义及实现 2.1 多态的构成条件 2.2 虚函数 2.3 虚函数的重写 2.3.1 虚函数重写的两个例外 2.4 C override 和 final 2.5 重载,重写(覆盖),隐藏(重定义)的区别 三,抽…

代码随想录刷题-双指针总结篇

文章目录双指针移除元素习题我的解法双指针优化反转字符串习题我的解法剑指 Offer 05. 替换空格习题我的解法正确解法反转字符串里的单词习题我的解法反转链表习题我的解法删除链表的倒数第 N 个节点习题我的解法相交链表习题我的解法环形链表 II习题我的解法三数之和习题我的解…

Unity VFX -- (3)创建环境粒子系统

粒子系统中最常用也最重要的一种使用场景是实现天气效果。只需要做很少修改,场景就能很快从蓝天白云变成雪花飘舞。 和之前看到的粒子系统从一个源头发出粒子的情况不同,天气效果完全围绕着场景。 新增和放置一个新的粒子系统 为了创建下雨或下雪的天气…

【从零开始学Skynet】基础篇(三):服务模块常用API

1、服务模块 Skynet提供了开启服务和发送消息的API,必须要先掌握它们。列出了Skynet中8个最重要的API,PingPong程序会用到它们。 Lua API说明newservice(name, ...) 启动一个名为 name 的新服务,并返回服务的地址。 start(func) …

【学习笔记】unity脚本学习(二)(Time时间体系、Random随机数、Mathf数学运算)

目录Time时间体系timeScalemaximumDeltaTimefixedDeltaTimecaptureDeltaTimedeltaTime整体展示Random随机数Mathf数学运算IMathf.Round()Mathf.Ceil() Mathf.CeilToInt()Mathf.SignMathf.ClampMathf数学运算II-曲线变换Lerp 线性插值LerpAngleSmoothDamp疑问:Smooth…

自己动手写编译器:DFA跳转表的压缩算法

在编译器开发体系中有两套框架,一个叫"lex && yacc", 另一个名气更大叫llvm,这两都是开发编译器的框架,我们只要设置好配置文件,那么他们就会生成相应的编译器代码,通常是c或者c代码,然后…

AI自动寻路AStar算法【图示讲解原理】

文章目录AI自动寻路AStar算法背景AStar算法原理AStar寻路步骤AStar具体寻路过程AStar代码实现运行结果AI自动寻路AStar算法 背景 AI自动寻路的算法可以分为以下几种: 1、A*算法:A*算法是一种启发式搜索算法,它利用启发函数(heu…

Jmeter接口测试和性能测试

目前最新版本发展到5.0版本,需要Java7以上版本环境,下载解压目录后,进入\apache-jmeter-5.0\bin\,双击ApacheJMeter.jar文件启动JMemter。 1、创建测试任务 添加线程组,右击测试计划,在快捷菜单单击添加-…

STM32F103RCT6驱动SG90舵机-完成正反转角度控制

一、SG90舵机介绍 SG90是一种微型舵机,也被称为伺服电机。它是一种小型、低成本的直流电机,通常用于模型和机器人控制等应用中。SG90舵机可以通过电子信号来控制其精确的位置和速度。它具有体积小、重量轻、响应快等特点,因此在各种小型机械…

亚马逊测评只能下单上好评?卖家倾向养号测评还有这些骚操作

亚马逊测评这对于绝大部分亚马逊卖家来说都不陌生,如今的亚马逊市场也很多卖家都在用测评科技来打造爆款。不过很多对于亚马逊测评的认知只停留在简单的刷销量,上好评。殊不知亚马逊养号测评还有其它强大的骚操作。 亚马逊自养号测评哪些功能呢&#xf…

PyTorch 深度学习实战 |用 TensorFlow 训练神经网络

为了更好地理解神经网络如何解决现实世界中的问题,同时也为了熟悉 TensorFlow 的 API,本篇我们将会做一个有关如何训练神经网络的练习,并以此为例,训练一个类似的神经网络。我们即将看到的神经网络,是一个预训练好的用…

【深度学习】【分布式训练】Collective通信操作及Pytorch示例

相关博客 【深度学习】【分布式训练】Collective通信操作及Pytorch示例 【自然语言处理】【大模型】大语言模型BLOOM推理工具测试 【自然语言处理】【大模型】GLM-130B:一个开源双语预训练语言模型 【自然语言处理】【大模型】用于大型Transformer的8-bit矩阵乘法介…

第02章_变量与运算符

第02章_变量与运算符 讲师:尚硅谷-宋红康(江湖人称:康师傅) 官网:http://www.atguigu.com 本章专题与脉络 1. 关键字(keyword) 定义:被Java语言赋予了特殊含义,用做专门…