Lnton羚通关于【PyTorch】教程:torchvision 目标检测微调

news2025/1/16 3:43:43

torchvision 目标检测微调
本教程将使用Penn-Fudan Database for Pedestrian Detection and Segmentation 微调 预训练的Mask R-CNN 模型。 它包含 170 张图片,345 个行人实例。

定义数据集
用于训练目标检测、实例分割和人物关键点检测的参考脚本允许轻松支持添加新的自定义数据集。数据集应继承自标准的 torch.utils.data.dataset 类,并实现 __len__ 和 __getitem__ 。

__getitem__ 需要返回:

image: PIL 图像 (H, W)
target: 字典数据,需要包含字段
boxes (FloatTensor[N, 4]): N 个 Bounding box 的位置坐标 [x0, y0, x1, y1], 0~W, 0~H
labels (Int64Tensor[N]): 每个 Bounding box 的类别标签,0 代表背景类。
image_id (Int64Tensor[1]): 图像的标签 id,在数据集中是唯一的。
area (Tensor[N]): Bounding box 的面积,在 COCO 度量里使用,可以分别对不同大小的目标进行度量。
iscrowd (UInt8Tensor[N]): 如果 iscrowd=True 在评估时忽略。
(optionally) masks (UInt8Tensor[N, H, W]): 可选的 分割掩码
(optionally) keypoints (FloatTensor[N, K, 3]): 对于 N 个目标来说,包含 K 个关键点 [x, y, visibility], visibility=0 表示关键点不可见。
如果模型可以返回上述方法,可以在训练、评估都能使用,可以用 pycocotools 里的脚本进行评估。

pip install pycocotools 安装工具。

关于 labels 有个说明,模型默认 0 为背景。如果数据集没有背景类别,不需要在标签里添加 0 。 例如,假设有 cat 和 dog 两类,定义了 1 表示 cat , 2 表示 dog , 如果一个图像有两个类别,类别的 tensor 为 [1, 2] 。

此外,如果希望在训练时使用纵横比分组,那么建议实现 get_height_and_width 方法,该方法将返回图像的高度和宽度,如果未提供此方法,我们将通过 __getitem__ 查询数据集的所有元素,这会将图像加载到内存中,并且比提供自定义方法的速度慢。

为 PennFudan 写自定义数据集
文件夹结构如下:

PennFudanPed/
  PedMasks/
    FudanPed00001_mask.png
    FudanPed00002_mask.png
    FudanPed00003_mask.png
    FudanPed00004_mask.png
    ...
  PNGImages/
    FudanPed00001.png
    FudanPed00002.png
    FudanPed00003.png
    FudanPed00004.png

这是图像的标注信息,包含了 mask 以及 bounding box 。每个图像都有对应的分割掩码,每个颜色代表不同的实例。

import os 
import numpy as np 
import torch 
from PIL import Image

class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
    
        ## 加载所有图像,sort 保证他们能够对应起来
        self.images = list(sorted(os.listdir(os.path.join(self.root, 'PNGImages'))))
        self.masks = list(sorted(os.listdir(os.path.join(self.root, 'PedMasks'))))
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root, 'PNGImages', self.images[idx])
        mask_path = os.path.join(self.root, 'PedMasks', self.masks[idx])
        image = Image.open(img_path).convert('RGB')
    
        ## mask 图像并没有转换为 RGB,里面存储的是标签,0表示的是背景
        mask = Image.open(mask_path)
    
        # 转换为 numpy
        mask = np.array(mask) 
    
        # 实例解码成不同的颜色
        obj_ids = np.unique(mask)
    
        # 移除背景
        obj_ids = obj_ids[1:]
    
        masks = mask == obj_ids[:, None, None]
    
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])
        
        # 转换为 tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
    
        target = {}
    
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
    
        if self.transforms is not None:
            image, target = self.transforms(image, target)
    
        return image, target
  
    def __len__(self):
        return len(self.images)

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/896796.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3 个 ChatGPT 插件您需要立即下载3 ChatGPT Extensions You need to Download Immediately

在16世纪,西班牙探险家皮萨罗带领约200名西班牙士兵和37匹马进入了印加帝国。尽管印加帝国的军队数量达到了数万,其中包括5,000名精锐步兵和3,000名弓箭手,他们装备有大刀、长矛和弓箭等传统武器。但皮萨罗的军队中有100名火枪手,…

居然有这么好用的调试工具

居然有这么好用的调试工具 基本收发虚拟示波器GPIO操作PWM输出AD-DAIIC操作SPI操作GPS显示模块设置 基本收发 软件具备最常用的串口收发功能,可以在需要发送的数据最后选择添加一些常用的附加数据: 支持2通道COM口同时接收,目前自己最常用的…

ARM(实验二)

uart4.h #ifndef __H__ #define __H__#include "stm32mp1xx_rcc.h" #include "stm32mp1xx_gpio.h" #include "stm32mp1xx_uart.h"//RCC/GPIO/UART4章节初始化 void hal_uart4_init();//发送一个字符函数 void hal_put_char(const char str);//发…

Java进阶(4)——结合类加载JVM的过程理解创建对象的几种方式:new,反射Class,克隆clone(拷贝),序列化反序列化

目录 引出类什么时候被加载JVM中创建对象几种方式1.new 看到new : new Book()2.反射 Class.forName(“包名.类名”)如何获取Class对象【反射的基础】案例:连接数据库方法 3.克隆(拷贝)clone浅拷贝深拷贝案例 序列化和反序列化对象流-把对象存…

中大型无人机远程VHF语音电台系统方案

方案背景 中大型无人机在执行飞行任务时,特别是在管制空域飞行时地面航管人员需要通过语音与无人机通信。按《无人驾驶航空器飞行管理暂行条例》规定,中大型无人机应当进行适航管理。物流无人机和载人eVTOL都将进行适航管理,所以无人机也要有…

python ORM框架 sqlAlchemy

背景 最近在研究mysql的ORM框架,忽然看到了一个pip的包sqlalchemy,让我觉得很神奇,用下来的感觉和java的hibernate差不多,后边的链式查询又让我觉得和我很喜欢用的mybatis plus差不多,于是抱着好奇加上学习的态度&…

神经网络简单理解:机场登机

目录 神经网络简单理解:机场登机 ​编辑 激活函数:转为非线性问题 ​编辑 激活函数ReLU 通过神经元升维(神经元数量):提升线性转化能力 通过增加隐藏层:增加非线性转化能力​编辑 模型越大,…

OpenCV 玩转图像和视频

为什么学OpenCV? • OpenCV ⽀持对图像缩放、旋转、绘制⽂字图形等基础操作 • OpenCV 库包含了很多计算机视觉领域常⻅算法:⽬标检测、⽬标跟踪等 OpenCV 简介 • OpenCV (Open Source Computer Vision) 是计算机视觉和机器学习软件库 • Intel 1999…

特殊数字专题

特殊数字 1.奇数2.偶数3.完数4.素数5.回文数6.水仙花数7.中位数9.随机数11.求年份&#xff1a;闰年12.求数字&#xff1a;两个整数的最大公约数及最小公倍数 1.奇数 代码案例&#xff1a; //输出所有1-1000之间的奇数 #define _CRT_SECURE_NO_WARNINGS 1 #include<stdio.h&…

Java虚拟机(JVM):虚拟机栈溢出

一、概念 Java虚拟机栈溢出&#xff08;Java Virtual Machine Stack Overflow&#xff09;是指在Java程序中&#xff0c;当线程调用的方法层级过深&#xff0c;导致栈空间溢出的情况。 Java虚拟机栈是每个线程私有的&#xff0c;用于存储方法的调用和局部变量的内存空间。每当…

Java二分法查找

二分法&#xff1a;首先需要一个由小到大排序好的数组&#xff0c;先找到其中间值&#xff0c;然后进行比较如果比较中间值大的话则向前找。如果比要找的小&#xff0c;则向后找。 代码实现&#xff1a; //定义查询方法 public static int searchTarget(int[] nums, int targ…

用户新增预测(Datawhale机器学习AI夏令营第三期)

文章目录 简介任务1&#xff1a;跑通Baseline实操并回答下面问题&#xff1a;如果将submit.csv提交到讯飞比赛页面&#xff0c;会有多少的分数&#xff1f;代码中如何对udmp进行了人工的onehot&#xff1f; 任务2.1&#xff1a;数据分析与可视化编写代码回答下面的问题&#xf…

【CSS动画02--卡片旋转3D】

CSS动画02--卡片旋转3D 介绍代码HTMLCSS css动画02--旋转卡片3D 介绍 当鼠标移动到中间的卡片上会有随着中间的Y轴进行360的旋转&#xff0c;以下是几张图片的介绍&#xff0c;上面是鄙人自己录得一个供大家参考的小视频&#x1f92d; 代码 HTML <!DOCTYPE html>…

上半年营收19亿,金融壹账通第二增长曲线“加速上坡”

8月16日&#xff0c;壹账通金融科技有限公司&#xff08;下称“金融壹账通”&#xff09;发布了截至2023年6月30日中期业绩报告。 根据财报&#xff0c;2023年上半年&#xff0c;金融壹账通实现营收18.99亿元&#xff0c;毛利润为6.96亿元&#xff1b;归母净利润率从-26.1%提升…

(ElementPlus)操作(不使用 ts): Form表单检验、规则及案例分析(这一篇就够了)

Ⅰ、Form 表单检验简介&#xff1a; 1、什么是 Form 表单检验&#xff1f; 其一、属性&#xff1a; 表单验证是 javascript 中的高级选项之一&#xff1b; 其二、定义&#xff1a; JavaScript 在数据被送往服务器前对 HTML 中的 Form 表单中的这些输入数据进行验证的行为就…

MySQL索引介绍 为什么mysql使用B+树

什么是索引&#xff1f; 索引是一种用于快速查询和检索数据的数据结构&#xff0c;常见的索引结构有&#xff1a;B树&#xff0c;B树和Hash。 索引的作用就相当于目录。打个比方&#xff0c;我们在查字典的时候&#xff0c;如果没有目录&#xff0c;那我们就只能一页一页的去…

文本三剑客之sed编辑器

sed 一、sed简介1.1 什么是sed&#xff1f;1.2 sed原理1.3 sed核心功能 二、sed命令格式详解2.1 命令格式2.2 常用选项2.3 sed脚本语法2.3.1 基本语法结构2.3.2 地址部分-----指定匹配范围2.3.3 命令部分-----要执行的命令 三、sed查找替换3.1 基本语法3.2 分组后向引用3.3 变量…

.NET应用UI组件DevExpress XAF v23.1 - 全新的日程模块

DevExpress XAF是一款强大的现代应用程序框架&#xff0c;允许同时开发ASP.NET和WinForms。DevExpress XAF采用模块化设计&#xff0c;开发人员可以选择内建模块&#xff0c;也可以自行创建&#xff0c;从而以更快的速度和比开发人员当前更强有力的方式创建应用程序。 在新版中…

【爬虫练习之glidedsky】爬虫-基础2

题目 链接 爬虫往往不能在一个页面里面获取全部想要的数据&#xff0c;需要访问大量的网页才能够完成任务。 这里有一个网站&#xff0c;还是求所有数字的和&#xff0c;只是这次分了1000页。 思路 找到调用接口 可以看到后面有个参数page来控制页码 代码实现 import reques…

Python功能制作之简单的3D特效

需要导入的库&#xff1a; pygame: 这是一个游戏开发库&#xff0c;用于创建多媒体应用程序&#xff0c;提供了处理图形、声音和输入的功能。 from pygame.locals import *: 导入pygame库中的常量和函数&#xff0c;用于处理事件和输入。 OpenGL.GL: 这是OpenGL的Python绑定…