基于 PyTorch 的目标检测和跟踪(无敌版)

news2024/10/8 19:45:16

一个不知名大学生,江湖人称菜狗
original author: jacky Li
Email : 3435673055@qq.com

 Time of completion:2023.2.1
Last edited: 2023.2.1

目录

图像中的目标检测

视频中的目标跟踪

作者有言


在文章《基于 PyTorch 的图像分类器》中,介绍了如何在 PyTorch 中使用您自己的图像来训练图像分类器,然后使用它来进行图像识别。本篇文章中,我将向您展示如何使用预训练的分类器检测图像中的多个对象,然后在视频中跟踪它们。

图像分类(识别)和目标检测分类之间有什么区别?在分类中,识别图像中的主要对象,然后通过单个类对整个图像进行分类。在检测中,在图像中识别多个对象,并对其进行分类,同时确定一个位置

图像中的目标检测

目标检测有几种算法,其中 YOLO 和 SSD 是最流行的。对于本篇文章,我们将尝试使用 YOLOv3。

这里的 YOLO 检测代码是基于 Erik Lindernoren 对 Joseph Redmon 和 Ali Farhadi 论文(论文链接:https://pjreddie.com/media/files/papers/YOLOv3.pdf)的实现。我们首先导入所需的模块:


from models import *
from utils import *import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variableimport matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

然后我们加载预训练模型的配置参数和权重,以及训练 Darknet 模型的 COCO 数据集的类名。和在 PyTorch 中一样,不要忘记在加载后将模型设置为 eval 模式。


config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor

上面还有一些预定义的值:图像尺寸(416px) ,Confidenc 阈值和 NMS 阈值。

下面的基本函数将返回对指定图像的检测。请注意,它需要一个 Pillow 读取的图像作为输入。大多数代码处理图像大小调整到416 * 416 大小的正方形,同时保持它的长宽比和填充。实际的检测是在最后4行。


def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80,
                        conf_thres, nms_thres)
    return detections[0]

最后,让我们通过加载一个图像,获得检测值,然后在检测到的对象周围显示它与边框一起。同样,这里的大多数代码处理缩放和填充图像,以及为每个检测到的类获取不同的颜色。

# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_xif detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)],
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()

视频中的目标跟踪

现在你已经知道如何检测图像中的不同物体了。当你在一个视频中一帧一帧地做这个动作时,你会看到那些跟踪框在移动,这样的可视化效果可能会非常酷。但是,如果这些视频帧中有多个对象,你怎么知道一帧中的对象是否与前一帧中的对象相同呢?这就是所谓的目标跟踪,并使用多个检测,以确定一个特定的对象随着时间的推移。

有几个算法可以做到这一点,我决定使用 SORT,这是非常容易使用并且速度相当快。SORT(Simple Online and Realtime Tracking)是2017年由 Alex Bewley,Zongyuan Ge,Lionel Ott,Fabio Ramos,Ben Upcroft 发表的一篇论文,该论文建议使用 Kalman 过滤器来预测先前确定的物体的轨迹,并将它们与新的检测结果进行匹配。作者 Alex Bewley 还编写了一个多功能的 Python 实现,我将在本文中使用它。

现在说到代码,前3个代码段将与单幅图像检测中的代码段相同,因为它们处理的是在单帧上获得 YOLO 检测。区别在于最后一部分,对于每个检测,我们调用 Sort 对象的 Update 函数,以获得对图像中对象的引用。因此,与前面示例中的常规检测(包括边界框的坐标和类预测)不同,我们将获得跟踪的对象,除了上面的参数,还包括一个对象 ID。然后我们以几乎相同的方式显示,但添加 ID 并使用不同的颜色,这样你就可以很容易地在视频帧中看到对象。

我还使用 OpenCV 来读取视频并显示视频帧。注意,Jupyter 处理视频的速度相当慢。您可以使用它进行测试和简单的可视化,但是我还提供了一个独立的 Python 脚本,它将读取源视频,并用跟踪的对象输出一个副本。在笔记本上播放 OpenCV 视频并不容易,所以你可以保留这段代码用于其他实验。

videopath = 'video/intersection.mp4'%pylab inline
import cv2
from IPython.display import clear_outputcmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) *
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) *
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)),
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (255,255,255), 3)    fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)

您可以使用常规的 Python 脚本进行实时处理(您可以从摄像机获取输入)和保存视频。下面是我用这个程序生成的视频样本。

视频链接:https://youtu.be/1BY2CxiMYvQ

结果如视频所示,你现在可以尝试自己检测图像中的多个对象,并通过视频帧跟踪这些对象。

作者有言

如果需要代码,请私聊博主,博主看见回。
如果感觉博主讲的对您有用,请点个关注支持一下吧,将会对此类问题持续更新……

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/191261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网卡ID简要说明

一、概述 网卡ID标识着网卡的具体类型,由五个ID共同确认。根据这五个ID可以在公示网站查到具体的网卡型号。 1. Class id (1) 区分不同的PCI(外设)设备 (2) 网卡类型是:0200 (3) 查询网址:http://pci-ids.ucw.cz/read/PD 2. Vendor id: …

15_open_basedir绕过

open_basedir绕过 一、了解open_basedir 1. 搭建环境 在test目录下存在一个open_basedir.php的文件 里面的php代码就是简单的文件包含或者ssrf,利用的是file_get_contents函数 open_basedir也就是在这种文件包含或者ssrf访问其它文件的时候生效 然后在www目录下再新建一个t…

(隐私计算)联邦学习概述

一、是什么 概念 联邦学习(Federated Learning,FELE)是一种打破数据孤岛、释放 AI 应用潜能的分布式机器学习技术,能够让联邦学习各参与方在不披露底层数据和底层数据加密(混淆)形态的前提下,…

Unity-TCP-网络聊天功能(一): API、客户端服务器、数据格式、粘包拆包

1.TCP相关API介绍与服务端编写TCP是面向连接的。因此需要创建监听器,监听客户端的连接。当连接成功后,会返回一个TcpClient对象。通过TcpClient可以接收和发送数据。VS创建C# .net控制台应用项目中创建文件夹Net,Net 下添加TCPServer.cs类&am…

界面组件DevExtreme v22.2亮点——UI模板库升级换代!

DevExtreme拥有高性能的HTML5 / JavaScript小部件集合,使您可以利用现代Web开发堆栈(包括React,Angular,ASP.NET Core,jQuery,Knockout等)构建交互式的Web应用程序。从Angular和Reac&#xff0c…

高频链表算法

1.从尾到头打印链表值 输入一个链表的头节点,从尾到头反过来返回每个节点的值(用数组返回) 思路 (1)如果使用数组来保存反转之后的链表数据,这样只需要使用到队列或栈的知识,关键是unshif和push,reverse函数 &…

【vue2】vuex基础与五大配置项

🥳博 主:初映CY的前说(前端领域) 🌞个人信条:想要变成得到,中间还有做到! 🤘本文核心:vuex基础认识、state、getters、mutations actions、modules使用 目录(文末原素材) 一、…

【JavaEE初阶】第九节.多线程 (基础篇)定时器(案例三)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 前言 一、定时器概述、 二、定时器的实现 2.1 Java标准库 定时器的使用 2.2 自己模拟实现一个定时器 2.3 对自己实现的定时器的进一步优化 2.3.1 为何需要再进行优化 2…

CMOS图像传感器——了解光圈

在之前有提到传感器英寸,也提到了曝光三要素之一的ISO,这里主要说明另外一个曝光三要素——光圈。在本文中,我们将介绍光圈及其工作原理。 一、什么是光圈 光圈可以定义为镜头中的开口,光线通过该开口进入相机。类比眼睛是的工作原理,就容易理解了:当人在明亮和黑暗的环…

【链表之单链表】

前言:链表是什么? 链表的操作 1.单链表的结构 2.头文件的包含 3.动态申请一个节点 4.单链表打印 5.单链表尾插 6.单链表头插 7.单链表尾删 8.单链表头删 9.单链表查找 10.单链表在pos位置之后插入x 11.单链表在pos位置之前插入x 12. 单链表…

【数据挖掘】基于粒子群算法优化支持向量机PSO-SVM对葡萄酒数据集进行分类

1.粒子群算法的概念 PSO是粒子群优化算法(Particle Swarm Optimization)的英文缩写,是一种基于种群的随机优化技术,由Eberhart和Kennedy于1995年提出。粒子群算法是模仿昆虫、兽群、鸟群和鱼群等的群集行为,这些群体按…

中国电子学会2021年03月份青少年软件编程Scratch图形化等级考试试卷三级真题(含答案)

2021-03Scratch三级真题 分数:100题数:38 一、单选题(共25题,每题2分,共50分) 1.在《采矿》游戏中,当角色捡到黄金时财富值加1分,捡到钻石时财富值加2分,下面哪个程序实现这个功能&#xff1…

【软件测试】资深测试总结的测试必备8点,堪称测试人的好莱坞大片......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 测试8板斧。测试8板…

Task8:Excel的数据可视化

目录一 条形图二 条件单元格格式三 迷你图四 练习题一 条形图 【例子】直观的展示销售额之间的差别 方法:【开始】–>【条件格式】–>【数据条】 【只想显示条形图,不想显示金额】 1.条形图区域—>条件格式—>管理规则 2.选择设置的规则&a…

单应性Homography梳理,概念解释,传统方法,深度学习方法

Homography 这篇博客比较清晰准确的介绍了关于刚性变换,仿射变换,透视投影变换的理解 单应性变换 的 条件和表示 用 [无镜头畸变] 的相机从不同位置拍摄 [同一平面物体] 的图像之间存在单应性,可以用 [透视变换] 表示 。 opencv单应性变换求…

Active Directory计算机备份和恢复

在Active Directory(AD)环境中,用户通过域中的计算机认证他们自身。从AD中删除这些计算机账户时,系统也会自动从域中删除它们。于是,用户不能再通过些计算机登录网络。为允许用户访问域资源,必须恢复这些已…

聚集千百个企业管理系统的API资产,打造API资产全生命周期一站式集成体验

API——接口,作为软件世界中的连接服务和传输数据的重要管道,已经成为数字时代的新型基础设施,是各领域驱动数字变革的重要力量之一。传统企业集成主要采用点对点或ESB集成方式,基于全新API战略中台的API新型集成方式通过解耦系统…

SpringBoot跨域请求解决方案详细分析

跨域的定义 跨域是指不同域名之间的相互访问,这是由浏览器的同源策略决定的,是浏览器对JavaScript施加的安全措施,防止恶意文件破坏。同源策略:同源策略是一种约定,它是浏览器最核心的也是最基本的安全策略&#xff0…

【数据产品】缓存设计

背景:为什么需要做缓存? 我所做的产品的指标设计越来越复杂,查询性能也随之下降。因此需要增加缓存层, 以提高接口查询效率。 哪些层需要做缓存? 随着指标系统的应用,该产品的查询逻辑也越来越简单&…

二分查找核心思路--单调性--极值

在最初的二分查找中,我们将一组数据按大小排序,然后根据arr[mid]与要查找的k的大小比较,从而每次去掉一半的数字,使时间复杂度简化为O(logN)。 排序本质上是让数据的单调性统一,变为单增或单减…