Paddle实现单目标检测

news2025/1/16 1:38:41

单目标检测

单目标检测(Single Object Detection)是人工智能领域中的一个重要研究方向,旨在通过计算机视觉技术,识别和定位图像中的特定目标物体。单目标检测可以应用于各种场景,如智能监控、自动驾驶、医疗影像分析等。

简单来说,单目标检测就是在确定一个目标在图片中的位置:

检测亮起的信号灯在图像中的位置

 本文将以信号灯检测为例,介绍单目标检测的方法

环境准备

这个案例需要安装以下两个库:

pip install paddlepaddle-gpu
pip install lxml

数据集准备

本文采用如下数据集:红绿灯检测_练习_训练集(非比赛数据)_数据集-飞桨AI Studio星河社区 (baidu.com)

这个数据集共有2000张信号灯的照片,其中1000张绿灯,1000张红灯。每张照片都对应着一个xml文件,标注着信号灯在图片中的位置:

<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<annotation>
    <folder>Images</folder>
    <filename>green_0.jpg</filename>
    <source>
        <database>Unknown</database>
    </source>
    <size>
        <width>424</width>
        <height>240</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>green</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <occluded>0</occluded>
        <bndbox>
            <xmin>247</xmin>
            <ymin>147</ymin>
            <xmax>301</xmax>
            <ymax>190</ymax>
        </bndbox>
    </object>
</annotation>

这里面,<width>和<height>标签分别定义了宽和高,<name>定义了样本的类别(red或者green),<bndbox>里的标签则是定义了信号灯的位置(矩形框)

接下来我们编写dataset.py,用于定义数据集类:

import paddle
import glob
from lxml import etree
from PIL import Image  
import numpy as np 
  
# 定义一个字典,将颜色名称映射到ID  
name_to_id = {'red': 0, 'green': 1}  
  
# 将绝对坐标转换为相对坐标  
def to_labels(path):  
    # 读取XML文件内容  
    text = open(f'{path}').read().encode('utf8')  
    # 解析XML内容  
    xml = etree.HTML(text)  
    # 提取图像的宽度和高度  
    width = int(xml.xpath('//size/width/text()')[0])  
    height = int(xml.xpath('//size/height/text()')[0])  
    # 提取边界框的坐标  
    xmin = int(xml.xpath('//bndbox/xmin/text()')[0])  
    xmax = int(xml.xpath('//bndbox/xmax/text()')[0])  
    ymin = int(xml.xpath('//bndbox/ymin/text()')[0])  
    ymax = int(xml.xpath('//bndbox/ymax/text()')[0])  
    # 将绝对坐标转换为相对坐标  
    return xmin / width, ymin / height, xmax / width, ymax / height  
  
  
# 定义一个PaddlePaddle数据集类  
class Dataset(paddle.io.Dataset):  
    def __init__(self, pos='training_data'):  
        super().__init__()  # 调用父类构造函数  
        # 查找指定目录下的所有.jpg图片和.xml标签文件  
        self.imgs = glob.glob(f'{pos}/*.jpg')  
        self.labels = glob.glob(f'{pos}/*.xml')  
  
    def __getitem__(self, idx):  
        # 根据索引获取图片和标签  
        img = self.imgs[idx]  
        label = to_labels(self.labels[idx])  
        # 打开图片并转换为RGB模式  
        pil_img = Image.open(img).convert('RGB')  
        # 将PIL图片转换为numpy数组,并转换为float32类型  
        # 同时将通道顺序从HWC转换为CHW(PaddlePaddle默认输入格式)  
        t = paddle.to_tensor(np.array(pil_img, dtype=np.float32).transpose((2, 0, 1)))  
        # 返回图片张量和标签张量  
        return t, paddle.to_tensor(label[:4])  
  
    def __len__(self):  
        # 返回数据集中图片的数量  
        return len(self.imgs)

训练脚本

单目标检测可以看作一个回归问题,输出4个值,用于确定目标的坐标,因此我们可以使用resnet,并指定其类别数量为4(即输出4个值),并采用MSE损失函数(因为这是回归问题),据此,可以写出训练脚本的代码:

import paddle  
from dataset import Dataset  
  
# 初始化Dataset实例,设置数据位置为'training_data'  
dataset = Dataset(pos='training_data')  
  
# 使用ResNet18网络结构,并设置输出类别数为4  
net = paddle.vision.resnet18(num_classes=4)  
# 将网络封装为PaddlePaddle的Model对象  
model = paddle.Model(net)  
  
# 准备模型训练,包括优化器(Adam)和损失函数(均方误差损失)  
model.prepare(  
    paddle.optimizer.Adam(parameters=model.parameters()),  
    paddle.nn.MSELoss(),  
)  
  
# 训练模型,设置训练轮数为160,批处理大小为16 
model.fit(dataset, epochs=160, batch_size=16, verbose=1)  
  
# 保存模型到'output/model'路径  
model.save('output/model')

可以看到,训练脚本还是非常简单的。

简单使用

使用脚本也很简单:

import matplotlib.pyplot as plt  
import matplotlib.patches as patches  
import numpy as np  
from PIL import Image  
import paddle  
  
# 图片路径  
img_path = 'testing_data/red_1003.jpg' 
# 打开图片并转换为RGB格式  
pil_img = Image.open(img_path).convert('RGB')  
# 将PIL图片转换为Paddle Tensor,并调整通道顺序  
t = paddle.to_tensor([np.array(pil_img, dtype=np.float32).transpose((2, 0, 1))])  
  
# 加载ResNet18模型,并设置为4个类别  
net = paddle.vision.resnet18(num_classes=4)  
model = paddle.Model(net)  
# 加载训练好的模型权重  
model.load('output/model')  
  
# 预测图片  
pred = model.predict_batch(t)[0][0]  
print(f'预测结果:{pred}')  
  
# 根据预测结果计算边界框坐标  
xmin = float(pred[0]) * 424  
ymin = float(pred[1]) * 240  
xmax = float(pred[2]) * 424  
ymax = float(pred[3]) * 240  
  
# 显示原始图片  
plt.imshow(np.array(t[0], dtype=np.int32).transpose((1, 2, 0)))  
  
# 定义多边形的顶点坐标(这里是预测的边界框)  
vertices = np.array([[xmin, ymin], [xmin, ymax], [xmax, ymax], [xmax, ymin]])  
# 创建一个多边形对象,用于绘制边界框  
polygon = patches.Polygon(vertices, closed=True, edgecolor='black', facecolor='none')  
# 将多边形添加到当前坐标轴上  
plt.gca().add_patch(polygon)  
# 显示图片和边界框  
plt.show()

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1796801.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

03_初识Spring Cloud Gateway

文章目录 一、网关简介1.1 网关提出的背景1.2 网关在微服务中的位置1.3 网关的技术选型1.4 补充 二、Spring Cloud Gateway的简介2.1 核心概念&#xff1a;路由&#xff08;Route&#xff09;2.2 核心概念&#xff1a;断言&#xff08;Predicate&#xff09;2.3 核心概念&#…

Python怎么发邮件不会被拦?如何设置信息?

Python发邮件的注意事项&#xff1f;Python发邮件需要哪些库&#xff1f; 使用Python发送电子邮件是一个常见的需求。然而&#xff0c;有时候邮件可能会被拦截&#xff0c;要确保发送的邮件不被拦截&#xff0c;需要一些技巧和注意事项。AokSend将介绍如何使用Python发送邮件&…

stm32中如何实现EXTI线 0 ~ 15与对应IO口的配置呢?

STM32的EXTI控制器支持19 个外部中断/ 事件请求。每个中断设有状态位&#xff0c;每个中断/ 事件都有独立的触发和屏蔽设置。 STM32的19个外部中断对应着19路中断线&#xff0c;分别是EXTI_Line0-EXTI_Line18&#xff1a; 线0~15&#xff1a;对应外部 IO口的输入中断。 线16&…

十年数据分析经验分享

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

小熊家务帮day10-day12 门户管理(缓存,主页,定时任务)

门户管理 1 门户介绍1.1 介绍1.2 常用技术方案 2 缓存技术方案2.1 需求分析2.1.1 C端用户界面原型2.1.2 缓存需求2.1.3 使用的工具 2.2 项目基础使用2.2.1 项目集成SpringCache2.2.2 测试Cacheable需求Service测试 2.1.3 缓存管理器&#xff08;设置过期时间&#xff09;2.1.4 …

我的python管理

目前环境 Anaconda&#xff1a;python3.9 python2.7 IDA&#xff1a;python3.8 pycharm&#xff1a;&#xff1f;&#xff1f; 以后应该会补吧… 因为某些文件似乎用的python2决定整个python2 安装python2.7 打开anaconda命令行输入 conda create --name python27 python2…

RuoYi 使用达梦数据库 MySQL迁移达梦数据库

达梦数据库使用 达梦数据库安装路径&#xff1a;/home/aite/dmdbms 达梦数据库版本 RuoYi-Vue V3.8.7版本达梦数据库目录说明 cd /home/aite/dmdbms ls -l总用量 80 drwxr-xr-x 10 aite aite 12288 5月 31 14:41 bin drwxr-xr-x 2 aite aite 4096 5月 31 14:37 bin2 drwx…

计算机网络-OSI七层参考模型与数据封装

目录 一、网络 1、网络的定义 2、网络的分类 3、网络的作用 4、网络的数据传输方式 5、网络的数据通讯方式 二、OSI七层参考模型 1、网络参考模型定义 2、分层的意义 3、分层与功能 4、TCP\IP五层模型 三、参考模型的协议 1、物理层 2、数据链路层 3、网络层 4…

【Endnote】如何在word界面加载Endnote

如何在word界面加载Endnote 方法1&#xff1a;方法2&#xff1a;从word入手方法3&#xff1a;从CWYW入手参考 已下载EndNote,但Word中没有显示EndNote&#xff0c;应如何加载显示呢&#xff1f; 方法1&#xff1a; 使用EndNote的Configure EndNote.exe 。 具体步骤为&#x…

使用MFC DLL

本文仅供学习交流&#xff0c;严禁用于商业用途&#xff0c;如本文涉及侵权请及时联系本人将于及时删除 应用程序与DLL链接后&#xff0c;DLL才能通过应用程序调用运行。应用程序与DLL链接的方式主要有如下两种&#xff1a;隐式链接和显式链接。 隐式链接又称为静态加载&…

pw备份问题

1、手动build&#xff0c;dn gs_ctl build -D /database/panweidb/data 或 gs_ctl build -D /database/panweidb/data -b full 2、拉起2节点cm cm_ctl start -n 2 3、启动cm_server服务 cm_ctl start 4、 netstat -anop|grep 17700 5、

代码随想录——删除二叉搜索树中的节点(Leetcode450)

题目链接 递归 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right) {* …

在k8s中部署Logstash多节点示例(超详细讲解)

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《洞察之眼&#xff1a;ELK监控与可视化》&#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、引言 1、Logstash简介 2、在K8s中部署Logstash多节点实例…

每日题库:Huawe数通HCIA——13

所有资料均来源自网络&#xff0c;但个人亲测有效&#xff0c;特来分享&#xff0c;希望各位能拿到好成绩&#xff01; PS&#xff1a;别忘了一件三连哈&#xff01; 今日题库&#xff1a; 186. 下列协议中属于动态IGP路由协议的是&#xff1f;-单选 A.stA.tiC. B.OSPF c…

GraphQL(3):参数类型与参数传递

1 基本参数类型 &#xff08;1&#xff09;基本类型:String,Int,Float,Boolean和ID。可以在shema声明的时候直接使用。 &#xff08;2&#xff09;[类型]代表数组&#xff0c;例如:[int]代表整型数组 2 参数传递 &#xff08;1&#xff09;和js传递参数一样&#xff0c;小括…

Sentinel1.8.6更改配置同步到nacos(项目是Gateway)

本次修改的源码在&#xff1a;https://gitee.com/stonic-open-source/sentinel-parent 一 下载源码 地址&#xff1a;https://github.com/alibaba/Sentinel/releases/tag/1.8.6 二 导入idea&#xff0c;等待maven下载好各种依赖 三 打开sentile-dashboard这个模块&#xf…

Linux 服务查询命令(包括 服务器、cpu、数据库、中间件)

Linux 服务查询命令&#xff08;包括 服务器、cpu、数据库、中间件&#xff09; Linux获取当前服务器ipLinux使用的是麒麟版本还是cenos版本Linux获取系统信息Linux获取CPU 的详细信息Linux查询nignx版本(非容器) Linux获取当前服务器ip hostname -ILinux使用的是麒麟版本还是…

SFML 小demo

文章目录 项目搭建代码实现main.cppobject.hsnake.hcommon.h 使用 demo 做到最后的话其实就只是验证了以前自己的一个想法&#xff0c;但是没有做成一个真正的游戏&#xff0c;可以算是一个 demo 而已吧&#xff0c;没做游戏的界面和关卡&#xff0c;不过完成了核心显式机制和功…

Day32 实现登录注册接口服务

​ 本章节,实现登录和注册接口服务 一.完善登录注册接口 完善登录和注册接口,对登录明文密码获取到MD5 字符串后,进行对比校验或注册明文密码进行MD5 加密后再插入到数据库。在MyToDo.Shared 项目中创建一个Extensions 文件夹,并创建一个 StringExtensions 静态扩展类,其中…

Redis学习(十二)Redis的三种删除策略

目录 一、背景二、Redis 的三种删除策略2.1 定时删除&#xff08;用CPU换内存空间&#xff09;2.2 定期删除2.3 惰性删除&#xff08;用内存换CPU性能&#xff09; 三、总结 一、背景 我们都知道 Redis 是一种内存数据&#xff0c;所有的数据均存储在内存中&#xff0c;可以通…