Pytorch实现R-CNN系列目标检测网络

news2024/11/24 13:56:43

在PyTorch提供的已经训练好的图像目标检测中,均是R-CNN系列的网络,并且针对目标检测和人体关键点检测分别提供了容易调用的方法。针对目标检测的网络,输入图像均要求使用相同的预处理方式,即先将每张图像的像素值预处理到0 ~1之间,且输入的图像尺寸不是很小即可直接调用。已经预训练的可供使用的网络模型如下表所示。

网络类描述
detection.fasterrcnn_resnet50_fpn具有Resnet-50-FPN的Fast R-CNN网络模型
detection.maskrcnn_resnet50_fpn具有Resnet-50-FPN结构的Mask R-CNN网络模型
detection.keypointrcnn_resnet50_fpn具有Resnet-50-FPN结构的Keypoint R-CNN网络模型

这些网络同样是在COCO 2017数据集上进行训练的。

1.图像目标检测

在进行图像目标检测时,使用已经预训练好的具有ResNet-50-FPN结构的FastR-CNN网络模型,该网络同样是通过COCO数据集进行预训练,导入已预训练的网络,程序如下所示:

import numpy as np
import torchvision
import torch
import torchvision.transforms as transforms
from PIL import Image,ImageDraw,ImageFont
import matplotlib.pyplot as plt

model=torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

下面从文件夹中读取一张照片,并将其转化为张量,像素值在0~1之间,然后使用导入模型对其进行预测,程序如下:

image=Image.open(r'C:\Users\zex\Downloads\VOCdevkit\VOC2012\JPEGImages\2012_001460.jpg')
transform_d=transforms.Compose([transforms.ToTensor()])
image_t=transform_d(image)#对图像进行变换
pred=model([image_t])

 在pred输出的结果中主要包括三种值,分别是检测到每个目标的边界框( boxes坐标)、目标所属的类别(labels),以及属于相应类别的得分( scores )。从上面的输出结果中可以发现,找到的目标约有21个,但仅前5个目标得分大于0.5。下面将检测到的目标可视化,并观察检测的具体结果。

首先定义每个类别所对应的标签COCO_INSTANCE_CATEGORY_NAMES,程序如下:

COCO_INSTANCE_CATEGORY_NAMES=[
    '__background__','person','bicycle','car','motorcycle',
    'airplane','bus','train','truck','boat','traffic light',
    'fire hydrant','N/A','stop sign','parking meter','bench',
    'bird','cat','dog','horse','sheep','cow','elephant',
    'bear','zebra','giraffe','N/A','backpack','umbrella','N/A',
    'N/A','handbag','tie','suitcase','frisbee','skis','snowboard',
    'surfboard','tennis racket','bottle','N/A','wine glass',
    'cup','fork','knife','spoon','bowl','banana','apple',
    'sandwich','orange','broccoli','carrot','hot dog','pizza',
    'donut','cake','chair','couch','potted plant','bed','N/A',
    'dining table','N/A','N/A','toilet','N/A','tv','laptop',
    'mouse','remote','keyboard','cell phone','microwave','oven',
    'toaster','sink','refrigerator','N/A','book','clock',
    'vase','scissors','teddy bear','hair drier','toothbrush'
]

针对预测的结果,在可视化之前,需要分别将有效的预测目标数据解读出来,需要提取的信息有每个目标的位置、类别和得分,然后将得分大于0.5的目标作为检测到的有效目标,并将检测到的目标在图像上显示出来,程序如下:

#检测出目标的类别和得分
pred_class=[COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
pred_score=list(pred[0]['scores'].detach().numpy())
#检测出目标的边界框
pred_boxes=[[i[0],i[1],i[2],i[3]] for i in list(pred[0]['boxes'].detach().numpy())]
#只保留识别的概率大于0.5的结果
pred_index=[pred_score.index(x) for x in pred_score if x > 0.5]
#设置图像显示的字体
fontsize=np.int16(image.size[1] / 30)
font1=ImageFont.truetype(r'E:\PythonWorkSpace\pytorch_project\pytorch_demo\SegmentDetection\华文细黑.ttf',fontsize)
#可视化图像
draw=ImageDraw.Draw(image)
for index in pred_index:
    box=pred_boxes[index]
    draw.rectangle(box,outline='red')
    texts=pred_class[index]+':'+str(np.round(pred_score[index],2))
    draw.text((box[0],box[1]),texts,fill='red',font=font1)
image.show()

 上面的程序在可视化图像时,使用ImageDraw.Draw(image)方法,表示要在原始的image图像上相应的位置添加一些元素,draw.rectangle()表示要添加矩形框,draw.text()表示在图像上指定位置添加文本。运行程序后,可得到下图所示的目标检测结果。

2.人体关键点检测

人体骨骼关键点检测主要检测人体的一些关键点,如关节、五官等,通过关键点描述人体骨骼信息。MS COCO数据集是多人人体关键点检测数据集,具有关键点个数为17,图像的样本数多于30万张,也是目前的相关研究中最常用的数据集。在torchvision库中,提供了已经在MS COCO数据集上预训练的keypointrcnn_resnet50_fpn()网络模型,该网络可以用于人体的关键点检测。先导入预训练好的网络模型,程序如下所示:

import torch
import torchvision

model=torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
model.eval()

 因为该网络的预测输出结果中会有目标检测的结果,即每个人的关键点检测结果。下面先导入目标类别标签和17个关键点的标签,程序如下:

COCO_INSTANCE_CATEGORY_NAMES=[
    '__background__','person','bicycle','car','motorcycle',
    'airplane','bus','train','truck','boat','traffic light',
    'fire hydrant','N/A','stop sign','parking meter','bench',
    'bird','cat','dog','horse','sheep','cow','elephant',
    'bear','zebra','giraffe','N/A','backpack','umbrella','N/A',
    'N/A','handbag','tie','suitcase','frisbee','skis','snowboard',
    'surfboard','tennis racket','bottle','N/A','wine glass',
    'cup','fork','knife','spoon','bowl','banana','apple',
    'sandwich','orange','broccoli','carrot','hot dog','pizza',
    'donut','cake','chair','couch','potted plant','bed','N/A',
    'dining table','N/A','N/A','toilet','N/A','tv','laptop',
    'mouse','remote','keyboard','cell phone','microwave','oven',
    'toaster','sink','refrigerator','N/A','book','clock',
    'vase','scissors','teddy bear','hair drier','toothbrush'
]
COCO_PERSON_KEYPOINT_NAMES=['nose','left_eye','right_eye','left_ear','right_ear',
                            'left_shoulder','right_shoulder','left_elbow','right_elbow',
                            'left_wrist','right_wrist','left_hip','right_hip','left_knee',
                            'right_knee','left_ankle','right_ankle']

17个关键点分别是鼻子、左眼、右眼、左耳朵、右耳朵、左肩、右肩、左胳膊肘、右胳膊肘、左手腕、右手腕、左臀、右臀、左膝、右膝、左脚踝和右脚踝,分别使用1~17标号表示。
下面从文件夹中读取一张图像,并对该图像中的人物目标和关键点进行预测,程序如下所示:

image=Image.open(r"C:\Users\zex\Desktop\3.29兼职\person.png")
transforms_d=transforms.Compose([transforms.ToTensor()])
image_t=transforms_d(image)
pred=model([image_t])
print(pred)

 上面的程序对图像进行预测后在pred的结果中包含以下内容:

(1)boxes:检测出目标的位置。

(2)labels:检测出目标的分类。

(3) scores:检测出目标为对应分类的得分

(4) keypoints:检测出N个实例中每个实例的K个关键位置,其中每个点的数据格式为[x,y, visibility],如果visibility =0,表示关键点不可见。

(5) keypoints__scores:表示每个关键点的相应得分。

从输出的检测结果中发现,图像中检测出了三个目标,但并不是每个目标得分都很高,下面先可视化得分高于0.5的目标,程序如下所示:

#检测出目标的类别和得分
pred_classes=[COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
pred_score=list(pred[0]['scores'].detach().numpy())
#检测出目标的边界框
pred_boxes=[[i[0],i[1],i[2],i[3]] for i in list(pred[0]['boxes'].detach().numpy())]
#只保留识别的概率大于0.5的结果
pred_index=[pred_score.index(x) for x in pred_score if x > 0.5]
#设置图像显示的字体
fontsize=np.int16(image.size[1] / 30)
font1=ImageFont.truetype(r'E:\PythonWorkSpace\pytorch_project\pytorch_demo\SegmentDetection\华文细黑.ttf')
#可视化图像
image2=image.copy()
draw=ImageDraw.Draw(image2)
for index in pred_index:
    box=pred_boxes[index]
    draw.rectangle(box,outline='red')
    texts=pred_classes[index]+':'+str(np.round(pred_score[index],2))
    draw.text((box[0],box[1]),texts,fill='red',font=font1)
image2.show()

下面可视化出该人物和网络检测到的关键点位置,程序如下所示:

pred_index=[pred_score.index(x) for x in pred_score if x >0.5]
pred_keypoint=pred[0]['keypoints']
#检测到实例的关节点
pred_keypoint=pred_keypoint[pred_index].detach().numpy()
#可视化出关键点的位置
fontsize=np.int16(image.size[1] /50)
r=np.int16(image.size[1] /150)#圆的半径
font1=ImageFont.truetype(r'E:\PythonWorkSpace\pytorch_project\pytorch_demo\SegmentDetection\华文细黑.ttf',fontsize)
#可视化图像
image3=image.copy()
draw=ImageDraw.Draw(image3)
#对实例数量索引
for index in range(pred_keypoint.shape[0]):
    keypoints=pred_keypoint[index]
    for i in range(keypoints.shape[0]):
        x=keypoints[i,0]
        y=keypoints[i,1]
        visi=keypoints[i,2]
        if visi>0:
            draw.ellipse(xy=(x-r,y-r,x+r,y+r),fill=(255,0,0))
            texts=str(i+1)
            draw.text((x+r,y-r),texts,fill='red',font=font1)
image3.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/425316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink 优化 (一) --------- 资源配置调优

目录一、内存设置1. TaskManager 内存模型2. 生产资源配置示例二、合理利用 cpu 资源1. 使用 DefaultResourceCalculator 策略2. 使用 DominantResourceCalculator 策略3 使用 DominantResourceCalculator 策略并指定容器 vcore 数三、并行度设置1. 全局并行度计算2. Source 端…

和猿辅导国奖选手的妈妈聊聊:数学新生代的成长之路

2023年第64届IMO中国国家队名单公布,来自猿辅导的学员王淳稷、孙启傲在此次国家队选拔赛中总成绩排名分列第一、第二,将于今年7月代表中国奔赴日本参加IMO竞赛。 值得一提的是,孙启傲同学继入选2022年IMO国家集训队、获阿里巴巴全球数学竞赛…

ubuntu(20.04)-shell脚本(2)echo-date-awk-sed-iptables-shell变量数组

1.echo 语法:echo [-ne][字符串]补充说明: 1、echo会将输入的字符串送往标准输出。 2、输出的字符串间以空白字符隔开,并在最后加上换行号。OPTIONS: -n 不要在最后自动换行 -e 若字符串中出现以下字符,则特别加以处理,而不会将它当成一般文…

【学习时序论文】

目录【2021 NeurIPS】Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting【2022 ICML】FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting【2023 ICLR】TIMESNET: TEMPORAL 2D-VARIATION …

deque,stack,quque容器

一、deque 1.基本概念 功能: 双端数组,可以对头端进行插入删除操作 deque与vector区别: vector对于头部的插入删除效率低,数据量越大,效率越低. deque相对而言,对头部的插入删除速度会比vector快 vector访问元素时的速度会比de…

NDK编译脚本

一、如何通过NDK进行编译。 1、新建jni文件夹,并将Android.mk、Applicatio n.mk、源文件都放入其中。 2、编写Android.mk文件。 LOCAL_PATH : $(call my-dir) include $(CLEAR_VARS) LOCAL_MODULE: test LOCAL_C_ALL_FILES : test.c LOCAL_SRC_FILES : $(LOCAL_C_…

centos7虚拟机在集群zookeeper上面配置hbase的具体操作步骤

系列文章目录 centos7配置静态网络常见问题归纳_centos7网络问题 centos7克隆虚拟机完成后的的一些配置介绍 虚拟机centos7配置Hadoop单节点伪分布配置教程 卸载centos7自带的jdk的操作步骤 centos7配置zookeeper本地模式与集群模式的详细教程 centos7虚拟机配置集群时间…

HTML引入Typescript编译JS文件 :Uncaught ReferenceError: exports is not defined

初学TypeScript,尝试在html引入ts编译出来的js文件: 报错:Uncaught ReferenceError: exports is not defined 以下是代码: 创建了TS:加入export {}形成独立的作用域,其他ts文件重复声明相同名称的变量。 export {} let str &…

Python和Java二选一该学啥?

首先我们需要了解Python和 Java分别是什么 根据IEEE Spectrum 2022年编程语言排名前十的分别是:Python,C,C,C#,Java,SQL,JavaScript,R,HTML,TypeScript。从该…

专访丨AWS量子网络中心科学家Antía Lamas谈量子计算

​ Anta Lamas Linares(图片来源:网络) 47岁的Anta Lamas Linares出生于西班牙西北部的圣地亚哥德孔波斯特拉。她在当地学习物理学,然后在牛津大学和加利福尼亚继续深造。后来,她在新加坡领导了亚马逊网络服务&#xf…

Java中线程的常用操作-后台线程、自定义线程工厂ThreadFactpry、join加入一个线程、线程异常捕获

场景 Java中Thread类的常用API以及使用示例: Java中Thread类的常用API以及使用示例_霸道流氓气质的博客-CSDN博客 上面讲了Thread的常用API,下面记录下线程的一些常用操作。 注: 博客:霸道流氓气质的博客_CSDN博客-C#,架构之…

Doris(4):建表

可以通过在mysql-client中执行以下 help 命令获得更多帮助: help create table 1 基本概念 在 Doris 中,数据都以表(Table)的形式进行逻辑上的描述。 1.1 Row & Column 一张表包括行(Row)和列&#…

从零开始:如何集成美颜SDK到你的应用中

现在,随着人们对于美的追求不断提升,美颜应用已经成为了人们生活中不可或缺的一部分。在应用中,美颜功能的实现离不开美颜SDK的支持。那么,如何集成美颜SDK到你的应用中呢?下面,我们就来一步步了解。 第一…

Linux复习 / 线程相关----线程互斥 QA梳理

文章目录前言线程互斥Q:什么是临界资源?临界区呢?Q:什么是互斥?Q:数据不一致的本质是什么?Q:用锁对共享资源进行保护的前提是:锁也要作为共享资源被其他线程使用。那么用…

独家 | 招商银行:玩转校园招聘新方式 挖掘金融科技新人才

数字经济时代,金融科技人才队伍的引进与培养是招商银行人才体系建设的关键任务。 01.金融科技校招2大核心课题 招商银行数字化转型过程中,线上化、生态化、平台化、智能化、数据化全面加速发展,对人才队伍能力提出新要求。 2大核心课题&am…

Git的一些使用

虽然说这也不是啥重要的内容,但是作为计算机人也得学学,了解了解。 一些预备内容 首先得下载git,这个就不多说了。 安装完了之后,首先要做的就是设置用户名称和邮箱地址,因为每次Git提交都会使用该信息,…

I.MX6ULL_Linux_驱动篇(33) pinctrl与gpio子系统

上一章我们编写了基于设备树的 LED 驱动,但是驱动的本质还是没变,都是配置 LED 灯所使用的 GPIO 寄存器,驱动开发方式和裸机基本没啥区别。 Linux 是一个庞大而完善的系统,尤其是驱动框架,像 GPIO 这种最基本的驱动不可…

Linux实战学习

文章目录一、Linux权限信息权限控制信息chmodifconfigpingnmap netstatps killzip unzip常用快捷键二、搭建Java环境yumJDKTomcatMysql三、部署Web项目到服务器一、Linux权限信息 Linux中,拥有最大权限的账户为: root(超级管理员),而普通用户在很多地方…

UWB成为智慧工厂时代的代表技术

UWB成为智慧工厂时代的代表技术 随着智慧工厂的到来,在人员安全问题较为重要的行业中,为了避免人员安全事故的出现,各家企业都逐步装备了UWB定位系统。UWB信号的辐射非常低,通常只有手机辐射的千分之一,因此在工业上应…

【 Spring MVC 核心功能(二) - 获取参数(上)】

文章目录一、获取单个参数二、获取多个参数三、获取对象四、后端参数重命名4.1 使用 RequestParam 重命名参数4.2 RequestParam 中参数必传4.3 设置非必传参数五、使用 PathVariable 获取URL中参数一、获取单个参数 在 Spring MVC 中可以直接⽤⽅法中的参数来实现传单个参&…