【PyTorch】基于YOLO的多目标检测项目(一)

news2025/1/18 17:14:25

【PyTorch】基于YOLO的多目标检测项目(一)

【PyTorch】基于YOLO的多目标检测项目(二)

目标检测是对图像中的现有目标进行定位和分类的过程。识别的对象在图像中显示有边界框。一般的目标检测方法有两种:基于区域提议的和基于回归/分类的。这里使用一种基于回归/分类的方法,称为YOLO。

目录

准备COCO数据集

创建自定义数据集

转换数据

定义数据加载器


准备COCO数据集

COCO是一个大规模的对象检测,分割和字幕数据集。它包含80个对象类别用于对象检测。

下载以下GitHub存储库

https://github.com/pjreddie/darkneticon-default.png?t=N7T8https://github.com/pjreddie/darknet

创建一个名为config的文件夹,将darknet/cfg/coco.data、darknet/cfg/yolov3.cfg文件复制到config文件夹中。

创建一个名为data的文件夹,从以下链接获取coco.names文件,并将其放入data文件夹,coco.names文件包含COCO数据集中80个对象类别的列表。

darknet/data/coco.names at master · pjreddie/darknet · GitHubConvolutional Neural Networks. Contribute to pjreddie/darknet development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/pjreddie/darknet/blob/master/data/coco.names将darknet/scripts/get_coco_dataset.sh文件复制到data文件夹中,并复制get_coco_cocoet.sh到data文件夹。接下来,打开一个终端并执行get_coco_cocoet.sh,该脚本将把完整的COCO数据集下载到名为coco的子文件夹中。也可通过以下链接下载coco数据集。

COCO2014_数据集-飞桨AI Studio星河社区 (baidu.com)icon-default.png?t=N7T8https://aistudio.baidu.com/datasetdetail/165195

在images文件夹中,有两个名为train 2014和val 2014的文件夹,分别包含82783和40504个图像。在labels文件夹中,有两个名为train 2014和val 2014的标签,分别包含82081和40137文本文件。这些文本文件包含图像中对象的边界框坐标。此外,trainvalno5k.txt文件是一个包含117264张图像的列表,这些图像将用于训练模型。此列表是train2014和val2014中图像的组合,5000个图像除外。5k.txt文件包含将用于验证的5000个图像的列表。

创建自定义数据集

完成数据集下载后,使用PyTorch的Dataset和Dataloader类创建训练和验证数据集和数据加载器。

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as TF
import os
import numpy as np

import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
#定义CocoDataset类,并展示来自训练和验证数据集的一些示例图像
class CocoDataset(Dataset):
    def __init__(self, path2listFile, transform=None, trans_params=None):
        with open(path2listFile, "r") as file:
            self.path2imgs = file.readlines()
        
        self.path2labels = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.path2imgs]

        self.trans_params = trans_params
        self.transform = transform

    def __len__(self):
        return len(self.path2imgs)
    
    def __getitem__(self, index):
        path2img = self.path2imgs[index % len(self.path2imgs)].rstrip()

        img = Image.open(path2img).convert('RGB')

        path2label = self.path2labels[index % len(self.path2imgs)].rstrip()

        labels= None
        if os.path.exists(path2label):
            labels = np.loadtxt(path2label).reshape(-1, 5)
            
        if self.transform:
            img, labels = self.transform(img, labels, self.trans_params)

        return img, labels, path2img    
root_data="./data/coco"
path2trainList=os.path.join(root_data, "trainvalno5k.txt")

coco_train = CocoDataset(path2trainList)
print(len(coco_train))

 

# 从coco_train中获取图像、标签和图像路径
img, labels, path2img = coco_train[1] 
print("image size:", img.size, type(img))
print("labels shape:", labels.shape, type(labels))
print("labels \n", labels)

path2valList=os.path.join(root_data, "5k.txt")
coco_val = CocoDataset(path2valList, transform=None, trans_params=None)
print(len(coco_val))

img, labels, path2img = coco_val[7] 
print("image size:", img.size, type(img))
print("labels shape:", labels.shape, type(labels))
print("labels \n", labels)

import matplotlib.pylab as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image
import random
%matplotlib inline
path2cocoNames="./data/coco.names"
fp = open(path2cocoNames, "r")
coco_names = fp.read().split("\n")[:-1]
print("number of classese:", len(coco_names))
print(coco_names)

def rescale_bbox(bb,W,H):
    x,y,w,h=bb
    return [x*W, y*H, w*W, h*H]
COLORS = np.random.randint(0, 255, size=(80, 3),dtype="uint8")
# fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 16)
fnt = ImageFont.truetype('arial.ttf', 16)
def show_img_bbox(img,targets):
    if torch.is_tensor(img):
        img=to_pil_image(img)
    if torch.is_tensor(targets):
        targets=targets.numpy()[:,1:]
        
    W, H=img.size
    draw = ImageDraw.Draw(img)
    
    for tg in targets:
        id_=int(tg[0])
        bbox=tg[1:]
        bbox=rescale_bbox(bbox,W,H)
        xc,yc,w,h=bbox
        
        color = [int(c) for c in COLORS[id_]]
        name=coco_names[id_]
        
        draw.rectangle(((xc-w/2, yc-h/2), (xc+w/2, yc+h/2)),outline=tuple(color),width=3)
        draw.text((xc-w/2,yc-h/2),name, font=fnt, fill=(255,255,255,0))
    plt.imshow(np.array(img))        
np.random.seed(1)
rnd_ind=np.random.randint(len(coco_train))
img, labels, path2img = coco_train[rnd_ind] 
print(img.size, labels.shape)

plt.rcParams['figure.figsize'] = (20, 10)
show_img_bbox(img,labels)

np.random.seed(1)
rnd_ind=np.random.randint(len(coco_val))
img, labels, path2img = coco_val[rnd_ind] 
print(img.size, labels.shape)

plt.rcParams['figure.figsize'] = (20, 10)
show_img_bbox(img,labels)

转换数据

定义一个转换函数和传递给CocoDataset类的参数

def pad_to_square(img, boxes, pad_value=0, normalized_labels=True):
    w, h = img.size
    w_factor, h_factor = (w,h) if normalized_labels else (1, 1)
    
    dim_diff = np.abs(h - w)
    pad1= dim_diff // 2
    pad2= dim_diff - pad1
    
    if h<=w:
        left, top, right, bottom= 0, pad1, 0, pad2
    else:
        left, top, right, bottom= pad1, 0, pad2, 0
    padding= (left, top, right, bottom)

    img_padded = TF.pad(img, padding=padding, fill=pad_value)
    w_padded, h_padded = img_padded.size
            
    x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
    y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
    x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
    y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)    
    
    x1 += padding[0] # 左
    y1 += padding[1] # 上
    x2 += padding[2] # 右
    y2 += padding[3] # 下
            
    boxes[:, 1] = ((x1 + x2) / 2) / w_padded
    boxes[:, 2] = ((y1 + y2) / 2) / h_padded
    boxes[:, 3] *= w_factor / w_padded
    boxes[:, 4] *= h_factor / h_padded

    return img_padded, boxes    
def hflip(image, labels):
    image = TF.hflip(image)
    labels[:, 1] = 1.0 - labels[:, 1]
    return image, labels

def transformer(image, labels, params):
    if params["pad2square"] is True:
        image,labels= pad_to_square(image, labels)
    
    image = TF.resize(image,params["target_size"])

    if random.random() < params["p_hflip"]:
        image,labels=hflip(image,labels)

    image=TF.to_tensor(image)
    targets = torch.zeros((len(labels), 6))
    targets[:, 1:] = torch.from_numpy(labels)
    
    return image, targets
trans_params_train={
    "target_size" : (416, 416),
    "pad2square": True,
    "p_hflip" : 1.0,
    "normalized_labels": True,
}
coco_train=CocoDataset(path2trainList,transform=transformer,trans_params=trans_params_train)

np.random.seed(100)
rnd_ind=np.random.randint(len(coco_train))
img, targets, path2img = coco_train[rnd_ind] 
print("image shape:", img.shape)
print("labels shape:", targets.shape) 

plt.rcParams['figure.figsize'] = (20, 10)
COLORS = np.random.randint(0, 255, size=(80, 3),dtype="uint8")
show_img_bbox(img,targets)

通过传递 transformer 函数来定义 CocoDataset 的一个对象来验证数据 

trans_params_val={
    "target_size" : (416, 416),
    "pad2square": True,
    "p_hflip" : 0.0,
    "normalized_labels": True,
}
coco_val= CocoDataset(path2valList,
                      transform=transformer,
                      trans_params=trans_params_val)

np.random.seed(55)
rnd_ind=np.random.randint(len(coco_val))
img, targets, path2img = coco_val[rnd_ind] 
print("image shape:", img.shape)
print("labels shape:", targets.shape) 

plt.rcParams['figure.figsize'] = (20, 10)
COLORS = np.random.randint(0, 255, size=(80, 3),dtype="uint8")
show_img_bbox(img,targets)

 

定义数据加载器

定义两个用于训练和验证数据集的数据加载器,从coco_train和coco_val中获取小批量数据。

from torch.utils.data import DataLoader

batch_size=8
def collate_fn(batch):
    imgs, targets, paths = list(zip(*batch))
    
    targets = [boxes for boxes in targets if boxes is not None]
    
    for b_i, boxes in enumerate(targets):
        boxes[:, 0] = b_i
    targets = torch.cat(targets, 0)
    imgs = torch.stack([img for img in imgs])
    return imgs, targets, paths

train_dl = DataLoader(
        coco_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
        )

torch.manual_seed(0)
for imgs_batch,tg_batch,path_batch in train_dl:
    break
print(imgs_batch.shape)
print(tg_batch.shape,tg_batch.dtype)

 

val_dl = DataLoader(
        coco_val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
        )

torch.manual_seed(0)
for imgs_batch,tg_batch,path_batch in val_dl:
    break
print(imgs_batch.shape)
print(tg_batch.shape,tg_batch.dtype)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1949425.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

javaEE-02-servlet

文章目录 Servlet 技术servlet程序示例通过实现Servlet接口实现Servlet程序通过继承 HttpServlet 实现 Servlet 程序 Servlet的声明周期 ServletConfig 类ServletContext 类HttpServletRequest 类请求的转发 HttpServletResponse 类请求重定向 HTTP 协议GET 请求Post请求常用请…

三维影像系统PACS源码,图像存储与传输系统,应用于医院中管理医疗设备如CT,MR等产生的医学图像的信息系统

PACS&#xff0c;即图像存储与传输系统&#xff0c;是应用于医院中管理医疗设备如CT&#xff0c;MR等产生的医学图像的信息系统。目标是支持在医院内部所有关于图像的活动&#xff0c;集成了医疗设备&#xff0c;图像存储和分发&#xff0c;数字图像在重要诊断和会诊时的显示&a…

unity ui toolkit的使用

UIToolkitExamples (github)样例 GitHub - ikewada/UIToolkitExamples: チュートリアル動画「使ってみようUI Toolkit」のためのサンプルプロジェクトです官网 Unity - Manual: UI Toolkit视频教程 使用 UI Toolkit - 上集_哔哩哔哩_bilibili 使用 UI Toolkit - 下集_哔哩哔哩_…

vue3前端开发-小兔鲜项目-使用pinia插件完成token的本地存储

vue3前端开发-小兔鲜项目-使用pinia插件完成token的本地存储&#xff01;实际业务开发中&#xff0c;token是一个表示着用户登录状态的重要信息&#xff0c;它有自己的生命周期。因此&#xff0c;这个参数值必须实例化存储在本地中。不能跟着pinia。因为pinia是基于内存设计的模…

go语言day18 reflect反射

Golang-100-Days/Day16-20(Go语言基础进阶)/day19_Go语言反射.md at master rubyhan1314/Golang-100-Days (github.com) 一、interface接口 接口类型内部存储了一对pair(value,Type) type interface { type *Type // 类型信息 data unsafe.Pointer // 指向具体数据 } 1)创建R…

Git基本原理讲解、常见命令、Git版本回退、Git抛弃本地分支拉取仓库最新分支

借此机会写篇博客汇总一下自己去公司实习之后遇到的一些常见关于Git的操作。 Git基本认识 Git把数据看作是对小型文件系统的一组快照&#xff0c;每次提交更新&#xff0c;或在Git中保存项目状态时&#xff0c;Git主要对当时的全部文件制作一个快照并保存这个快照的索引。同时…

嵌入式C++、MQTT、数据库、Grafana、机器学习( Scikit-learn):智能建筑大数据管理平台(代码示例)

项目概述 智能建筑管理系统&#xff08;Intelligent Building Management System, IBMS&#xff09;是一个集成多种技术的复杂系统&#xff0c;旨在通过智能化手段提升建筑的管理效率、节能效果和居住舒适度。该系统涉及嵌入式系统、物联网&#xff08;IoT&#xff09;、大数据…

数据库-触发器,存储过程

按照题目要求完成下列题目&#xff1a; 1.触发器 mysql> use mydb16_trigger; Database changed mysql> create table goods(-> gid char(8) primary key,-> name varchar(10),-> price decimal(8,2),-> num int); Query OK, 0 rows affected (0.01 sec)my…

01 Redis引入和概述

Redis引入和概述 一、Redis的历史和发展过程 ​ Redis是在2008年由意大利的一家创业公司Merzia的创始人Salvatore Sanfilippo(萨尔瓦托.圣菲利波)创造的。 ​ 当时&#xff0c;Salvatore 正在开发一款基于MySQL的网站实时统计系统LLOOGG&#xff0c;然而他发现MySQL的性能并…

VAE、GAN与Transformer核心公式解析

VAE、GAN与Transformer核心公式解析 VAE、GAN与Transformer&#xff1a;三大深度学习模型的异同解析 【表格】VAE、GAN与Transformer的对比分析 序号对比维度VAE&#xff08;变分自编码器&#xff09;GAN&#xff08;生成对抗网络&#xff09;Transformer&#xff08;变换器&…

计算机网络(四)数字签名和CA认证

什么是数字签名和CA认证&#xff1f; 数字签名 数字签名的过程通常涉及以下几个步骤&#xff1a; 信息哈希&#xff1a;首先&#xff0c;发送方使用一个哈希函数&#xff08;如SHA-256&#xff09;对要发送的信息&#xff08;如电子邮件、文件等&#xff09;生成一个固定长度…

GIS场景升级:支持多种影像协议与天气效果

在GIS场景编辑领域&#xff0c;升级视效的需求日益增加。有一款名为山海鲸可视化的免费工具&#xff0c;本人亲测能够完美满足这一需求。山海鲸可视化不仅支持多种GIS影像协议&#xff08;如TMS、WMS、WMTS等&#xff09;&#xff0c;还能一键添加天气效果&#xff0c;瞬间提升…

【Unity】 HTFramework框架(五十三)使用 Addressables 可寻址系统

更新日期&#xff1a;2024年7月25日。 Github源码&#xff1a;[点我获取源码] Gitee源码&#xff1a;[点我获取源码] 索引 Addressables 可寻址系统使用 Addressables 可寻址系统一、导入 Addressables二、切换到 Addressables 加载模式三、切换资源加载助手四、加载资源五、注…

【全面介绍Python多线程】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步! 🦇目录 1. 🦇前言2. 🦇threading 模块的基本用法3. 🦇Thre…

编程类精品GPTs

文章目录 编程类精品GPTs前言种类ChatGPT - GrimoireProfessional-coder-auto-programming 总结 编程类精品GPTs 前言 代码类的AI, 主要看以下要点: 面对含糊不清的需求是否能引导出完整的需求面对完整的需求是否能分步编写代码完成需求编写的代码是否具有可读性和可扩展性 …

【个人亲试最新】WSL2中的Ubuntu 22.04安装Docker

文章目录 Wsl2中的Ubuntu22.04安装Docker其他问题wsl中执行Ubuntu 报错&#xff1a;System has not been booted with systemd as init system (PID 1). Can‘t operate. 参考博客 &#x1f60a;点此到文末惊喜↩︎ Wsl2中的Ubuntu22.04安装Docker 确定为wsl2ubuntu22.04&#…

57 数据链路层

用于两个设备&#xff08;同一种数据链路节点&#xff09;之间传递 目录 对比理解“数据链路层” 和 “网络层”以太网 2.1 认识以太网 2.2 以太网帧格式MAC地址 3.1 认识MAC地址 3.2 对比理解MAC地址和IP地址局域网通信MTU 5.1 认识MTU 5.2 MTU对ip协议的影响 5.3 MTU对UDP的…

vue elementui 在table里使用el-switch

<el-table-columnprop"operationStatus"label"状态"header-align"center"align"center"><template slot-scope"scope"><el-switch active-value"ENABLE" inactive-value"DISABLE" v-mod…

Java OpenCV 图像处理40 图形图像 图片裁切ROI

Java OpenCV 图像处理40 图形图像 图片裁切 在 OpenCV 中&#xff0c;Rect 类是用来表示矩形的数据结构&#xff0c;通常用于定义图像处理中的感兴趣区域&#xff08;Region of Interest&#xff0c;ROI&#xff09;&#xff0c;或者指定图像中的某个区域的位置和大小。Rect 类…

【深度学习】大模型GLM-4-9B Chat ,微调与部署(3) TensorRT-LLM、TensorRT量化加速、Triton部署

文章目录 获取TensorRT-LLM代码&#xff1a;构建docker镜像并安装TensorRT-LLM&#xff1a;运行docker镜像&#xff1a;安装依赖魔改下部分package代码&#xff1a;量化&#xff1a;构建图&#xff1a;全局参数插件配置常用配置参数 测试推理是否可以代码推理CLI推理 性能测试小…