基于PyTorch搭建FasterRCNN实现目标检测

news2025/1/9 1:34:46

基于PyTorch搭建FasterRCNN实现目标检测

1. 图像分类 vs. 目标检测

图像分类是一个我们为输入图像分配类标签的问题。例如,给定猫的输入图像,图像分类算法的输出是标签“猫”。

在目标检测中,我们不仅对输入图像中存在的对象感兴趣。我们还对它们在输入图像中的位置感兴趣。从这个意义上说,目标检测超越了图像分类。

1.1 图像分类与目标检测:使用哪一个?

图像分类非常适合图像中只有一个对象的应用。可能有多个类(例如猫、狗等),但通常图像中只有该类的一个实例。

在大多数输入图像中有多个对象的应用中,我们需要找到对象的位置,然后对它们进行分类。在这种情况下,我们使用目标检测算法。

目标检测可能比图像分类慢数百倍。因此,在图像中对象的位置不重要的应用中,我们使用图像分类。

2. 目标检测

简单来说,目标检测是一个两步过程:

  • 查找包含对象的边界框,使得每个边界框仅包含一个对象。
  • 对每个边界框内的图像进行分类并为其分配标签。

在接下来的几节中,我们将介绍 Faster R-CNN 目标检测架构开发的步骤。

2.1 滑动窗口方法

大多数用于目标检测的经典计算机视觉技术(例如 HAAR 级联和 HOG + SVM)都使用滑动窗口方法来检测目标。

在这种方法中,滑动窗口在图像上移动。该滑动窗口内的所有像素都被裁剪掉并发送到图像分类器。

如果图像分类器识别出已知对象,则存储边界框和类标签。否则,将评估下一个窗口。

滑动窗口方法的计算量非常大。为了检测输入图像中的对象,需要在图像中的每个像素处评估不同尺度和纵横比的滑动窗口。

由于计算成本,仅当我们检测具有固定纵横比的单个对象类时才使用滑动窗口。例如,OpenCV 中基于 HOG + SVM 或 HAAR 的人脸检测器使用滑动窗口方法。有趣的是,著名的 Viola Jones 人脸检测使用滑动窗口。对于人脸检测器,复杂性是可控的,因为仅在不同尺度下评估方形边界框。

2.2 R-CNN目标检测

在基于 CNN 的方法赢得 2012 年 ImageNet 大规模视觉识别挑战赛 (ILSVRC) 后,基于卷积神经网络 (CNN) 的图像分类器开始流行。

由于每个目标检测器的核心都有一个图像分类器,因此基于 CNN 的目标检测器的发明就变得不可避免。

有两个挑战需要克服:

  • 与 HOG + SVM 或 HAAR 级联等传统技术相比,基于 CNN 的图像分类器的计算成本非常昂贵。
  • 计算机视觉社区变得越来越雄心勃勃。人们希望构建一个多类对象检测器,除了能够处理不同的尺度之外,还可以处理不同的纵横比。

研究人员开始研究训练机器学习模型的新想法,该模型可以提出包含对象的边界框的位置。这些边界框称为区域提议或对象提议。

Ross Girshick 等人提出的第一个使用区域提议的方法被称为 R-CNN(具有 CNN 特征的区域的缩写)。

他们使用一种称为“选择性搜索”的算法来检测 2000 个区域提案,并在这 2000 个边界框上运行基于 CNN + SVM 的图像分类器。

当时 R-CNN 的精度是最先进的,但速度仍然很慢(GPU 上每张图像 18-20 秒)

2.3 Fast R-CNN目标检测

在 R-CNN 中,每个边界框由图像分类器独立分类。有 2000 个区域提案,图像分类器为每个区域提案计算一个特征图。这个过程是昂贵的。

在 Ross Girshick 的后续工作中,他提出了一种名为 Fast R-CNN 的方法,可以显着加快目标检测速度。

这个想法是为整个图像计算单个特征图,而不是为 2000 个区域提案计算 2000 个特征图。对于每个候选区域,感兴趣区域(RoI)池化层从特征图中提取固定长度的特征向量。每个特征向量随后用于两个目的:

  • 将区域分类为某一类(例如狗、猫、背景)。
  • 使用边界框回归器提高原始边界框的准确性。

2.4 Faster R-CNN目标检测

在 Fast R-CNN 中,尽管共享了对 2000 个区域提案进行分类的计算,但生成区域提案的算法部分并未与执行图像分类的部分共享任何计算。

在名为 Faster R-CNN 的后续工作中,主要见解是计算区域提议和图像分类这两个部分可以使用相同的特征图,从而共享计算负载。

卷积神经网络用于生成图像的特征图,同时用于训练区域提议网络和图像分类器。由于这种共享计算,目标检测的速度有了显着的提高。

3. PyTorch实现目标检测

在本节中,我们将学习如何将 Faster R-CNN 目标检测器与 PyTorch 结合使用。我们将使用 torchvision 中包含的预训练模型。 PyTorch 中所有预训练模型的详细信息可以在 torchvision.models 中找到

3.1 输入和输出

我们将要使用的预训练 Faster R-CNN ResNet-50 模型期望输入图像张量采用 [n, c, h, w] 形式,最小尺寸为 800px,其中:

  • n 是图像数量
  • c 是通道数,对于 RGB 图像,其为 3
  • h 是图像的高度
  • w 是图像的宽度

模型将返回

  • 边界框 [x0, y0, x1, y1] 形状为 (N,4) 的所有预测类别,其中 N 是模型预测的图像中存在的类别数量。
  • 所有预测类别的标签。
  • 每个预测标签的分数。

3.2 预训练模型

使用以下代码从 torchvision 下载预训练模型:

import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

定义PyTorch官方文档给出的类名

	
COCO_INSTANCE_CATEGORY_NAMES = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

我们可以在列表中看到一些 N/A,因为后来的论文中删除了一些类。我们将使用 PyTorch 给出的列表。

3.3 模型预测

我们定义一个函数来获取图像路径并通过模型获得图像的预测。

from PIL import Image
from torchvision import transforms as T


def get_prediction(img_path, threshold):
    """
    get_prediction
      parameters:
        - img_path - path of the input image
        - threshold - threshold value for prediction score
    """
    img = Image.open(img_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
    pred_score = list(pred[0]['scores'].detach().numpy())
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_class = pred_class[:pred_t + 1]
    return pred_boxes, pred_class
  • 从图像路径获取图像
  • 使用 PyTorch 的 Transforms 将图像转换为图像张量
  • 图像通过模型来获得预测
  • 获得类、框坐标,但仅选择预测分数>阈值。

3.4 定义目标检测方法

接下来我们将定义一个方法来获取图像路径并获取输出图像。

import cv2
from matplotlib import pyplot as plt


def object_detection_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    boxes, pred_cls = get_prediction(img_path, threshold)
    # Get predictions
    img = cv2.imread(img_path)
    # Read image with cv2
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Convert to RGB
    for i in range(len(boxes)):
        cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
        # Draw Rectangle with the coordinates
        cv2.putText(img,pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
        # Write the prediction class
        plt.figure(figsize=(20,30))
        # display the output image
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        plt.show()
  • 预测是从 get_prediction 方法获得的
  • 对于每个预测,都会绘制边界框并写入文本
    与opencv
  • 显示最终图像

3.5 运行测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020475.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件工程之总体设计

总体设计是软件工程中的一个重要阶段,它关注整个系统的结构和组织,旨在将系统需求转化为可执行的软件解决方案。总体设计决定了系统的架构、模块划分、功能组织以及数据流和控制流等关键方面。 可行性研究 具体方面:经济可行性、技术可行性…

如何正确安装滚珠螺杆螺母?

在安装滚珠螺母时,相信很多人都遇到过装反这个问题,滚珠螺杆螺母是通过高精度的加工和配合实现传递转矩和运动的,如果将滚珠螺杆螺母反过来装,会导致导向槽和调整垫片位置不正确,使得整个螺杆系统的传动精度降低&#…

sketch for Mac快捷键大全

你可以在sketch中使用键盘快捷键来加快你的设计过程。要使用键盘快捷键,请同时按下下列列表的所有键。有些命令只能根据你在做什么或者你选择了什么才启用,所有把命令分成了下列不同的部分。 sketch下载地址:sketch 破解-Sketch for mac(专业…

Linux 信号集 及其 部分函数

这几个函数都是对自己自定义的信号集操作 int sigemptyset(sigset_t *set) 功能:清空信号集中的数据,将所有的标志位置为0 参数:set需要操作的信号集 返回值:成功0失败-1 int sigfillset(sigset_t *set) 功能:清空…

各个浏览器离线安装包获取方式

前言 我们一般去浏览器官网下载所谓的官方版浏览器,但是如今呢,下载的都是在线安装包,大小大约1~2MB,安装时电脑必须联网,每次都要从网络上下载。就像下面这样的: 在线安装包的运行必须有网络环境&#…

批量使用cdo 修改分辨率的方法

文件夹里有很多这种grib文件 怎么有2.3T啊,好大,一个一个改太浪费时间了 现在我想用cdo 批量他们的分辨率都降低一些,怎么做呢? find . -name *low.grib |xargs -I{} cdo remapbil,r144x72 {} {}_low.nc 思路:使用 find 命令找到…

小型网络实验组网

路漫漫其修远兮,吾将上下而求索 时隔多日,没有更新,今日一写,倍感教育的乐趣。如果让我每天发无意义的文章,我宁可不发。 实验拓扑 实验要求 (1)内网主机采用DHCP分配IP地址 (2&…

6-3 pytorch使用GPU训练模型

深度学习的训练过程常常非常耗时,一个模型训练几个小时是家常便饭,训练几天也是常有的事情,有时候甚至要训练几十天。 训练过程的耗时主要来自于两个部分,一部分来自数据准备,另一部分来自参数迭代。 当数据准备过程还…

分享一下微信公众号怎么添加砸金蛋链接

一、砸金蛋活动的优势 砸金蛋活动是一种非常有趣且吸引人的互动方式,在微信公众号中添加砸金蛋链接有以下优势: 提高用户参与度:砸金蛋活动能够激发用户的参与度和好奇心,让用户感到有乐趣和刺激。通过砸金蛋的方式,…

现在全国融资融券两融利率最低是多少?哪家证券公司券商费率低?

融资融券是指投资者通过向券商借入资金(融资)或借入证券(融券),以达到获得更高收益、降低交易风险、提高资金利用效率的目的。通过融资,投资者可以用借入的资金买入更多的证券;通过融券&#xf…

乐器商城小程序开发全攻略

随着互联网的普及和电子商务的快速发展,越来越多的人开始通过在线购物来满足自己的需求。而乐器作为一种特殊的商品,其在线销售市场也在不断扩大。为了满足这一需求,许多乐器商家开始开发自己的小程序商城,以提供更加便捷、高效的…

python使用SMTP发送邮件

SMTP是发送邮件的协议,Python内置对SMTP的支持,可以发送纯文本邮件、HTML邮件以及带附件的邮件。 Python对SMTP支持有smtplib和email两个模块,email负责构造邮件,smtplib负责发送邮件。 首先,我们来构造一个最简单的…

[BJDCTF2020]Mark loves cat foreach导致变量覆盖

这里我们着重了解一下变量覆盖 首先我们要知道函数是什么 foreach foreach (iterable_expression as $value)statement foreach (iterable_expression as $key > $value)statement第一种格式遍历给定的 iterable_expression 迭代器。每次循环中,当前单元的值被…

184_Python 在 Excel 和 Power BI 绘制堆积瀑布图

184_Python 在 Excel 和 Power BI 绘制堆积瀑布图 一、背景 在 2023 年 8 月 22 日 微软 Excel 官方宣布:在 Excel 原生内置的支持了 Python。博客原文 笔者第一时间就更新到了 Excel 的预览版,通过了漫长等待分发,现在可以体验了&#xf…

微信生态全场景方案

微信生态全场景方案 微信生态场景复杂,如何实现快速接入? 企业拥有跨平台数据,平台间数据割裂,如何实现各业务线数据整合? 借助身份云平台可快速接入微信生态全场景,轻松打通微信生态、电商平台、第三方平台…

prometheus 告警

prometheus 告警 1, prometheus 告警简介 告警能力在Prometheus的架构中被划分成两个独立的部分。如下所示,通过在Prometheus中定义AlertRule(告警规则),Prometheus会周期性的对告警规则进行计算,如果满足告警触发条件就会向Alertmanager发送告警信息。 在Prometheus中一…

基于Java建筑装修图纸管理平台设计实现(源码+lw+部署文档+讲解等)

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

(高阶)Redis 7 第13讲 数据双写一致性 canal篇

面试题 问题答案如何保证mysql改动后,立即同步到Rediscanal 简介 https://github.com/alibaba/canal/wikihttps://github.com/alibaba/canal/wiki 基于 MySQL 数据库增量日志解析,提供增量数据订阅和消费 业务 数据库镜像数据库实时备份多级索引 (卖家和买家各自分库索引…

【springMvc】自定义注解的使用方式

🎬 艳艳耶✌️:个人主页 🔥 个人专栏 :《Spring与Mybatis集成整合》 ⛺️ 生活的理想,为了不断更新自己 ! 1.前言 1.1.什么是注解 Annontation是Java5开始引入的新特征,中文名称叫注解。 它提供了一种安全…

【Java并发】聊聊死锁

什么是死锁 死锁出现的条件主要是资源互斥、占有并等待、非抢占、循环等待。 当出现两个线程对不同的资源进行获取的时候,A持有资源1,去获取资源2,B持有资源2,去获取资源1,就回出现死锁。 如何排查死锁 public cla…