联合目标检测与图像分类提升数据不平衡场景下的准确率

news2025/2/25 0:19:49

联合目标检测与图像分类提升数据不平衡场景下的准确率

在一些数据不平衡的场景下,使用单一的目标检测模型很难达到99%的准确率。为了优化这一问题,适当将其拆解为目标检测模型图像分类模型的组合,可以更有效地控制最终效果,尤其是在添加焦点损失(focal loss)、调整超参数和数据预处理无效的情况下。以下是具体的实现方式及联合两个模型的推理代码。

整体功能概述

这段代码的主要功能包括:

  1. 加载目标检测和分类模型:使用两个 Ultralytics YOLO(YOLOv8/YOLOv11均可) 模型进行目标检测和分类。
  2. 处理图像:遍历指定输入文件夹中的所有图像,进行目标检测和分类。
  3. 绘制检测框和分类标签:在图像上绘制检测到的对象的边界框,并在框上方添加分类名称和置信度。
  4. 可选保存裁剪的对象图像:根据设置,裁剪检测到的对象区域并保存为单独的图像文件,文件名包含类别名称、置信度和坐标信息(便于调试)。

实现细节

1. 加载模型

代码加载了两个 YOLO 模型:

  • 目标检测模型:一个单一类别的 YOLO 模型,用于检测主体对象。
  • 图像分类模型:一个多类别的 YOLO 模型,用于对检测到的对象进行分类。

2. 处理图像

脚本处理输入文件夹中的每一张图像,步骤如下:

  • 目标检测:使用目标检测模型检测图像中的对象。
  • 裁剪检测到的对象:根据检测到的边界框坐标,裁剪出感兴趣的区域。
  • 图像分类:对裁剪出的对象区域进行分类。
  • 数据增强或欠采样:根据任务需求,对裁剪出的子图像进行数据增强或欠采样,以平衡数据集。

3. 绘制检测框和标签

对于每一个检测到的对象,脚本会:

  • 在图像上绘制一个边界框。
  • 在边界框上方添加分类名称和置信度标签。

4. 保存裁剪的对象图像

可选地,脚本会保存裁剪出的对象图像,文件名包含以下信息:

  • 分类名称
  • 置信度
  • 边界框坐标

这对于调试和分析特定的检测结果非常有帮助。

推理代码

import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import random

def generate_random_color_from_name(name):
    """根据类别名生成可重复的颜色。"""
    random.seed(name)  # 使用类别名作为随机种子
    return tuple(random.randint(0, 255) for _ in range(3))

def generate_class_colors(names):
    """为每个类别生成一个固定的颜色。"""
    class_colors = {}
    for class_name in names:
        class_colors[class_name] = generate_random_color_from_name(class_name)
    return class_colors

def draw_box_on_image(image, box, color=(0, 255, 0), thickness=2):
    """在图像上绘制检测框。"""
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

def add_classification_to_box(image, box, class_name, confidence, color=(0, 255, 0)):
    """在边界框上方添加分类名称和置信度。"""
    x1, y1, x2, y2 = map(int, box)
    label = f"{class_name}: {confidence:.2f}"
    cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2, cv2.LINE_AA)

def save_cropped_object(image, box, cls_class_name, confidence, output_folder, image_name):
    """将裁剪的对象区域保存为图像到子文件夹中,文件名包含类别名、置信度和坐标。"""
    x1, y1, x2, y2 = map(int, box)
    cropped_img = image[y1:y2, x1:x2]
    
    # 为当前图像创建一个以图像文件名命名的子文件夹
    image_subfolder = Path(output_folder) / Path(image_name).stem
    image_subfolder.mkdir(parents=True, exist_ok=True)
    
    # 为裁剪的对象创建文件名(class_name_confidence_x1_y1_x2_y2.jpg)
    # 确保置信度格式安全,使用两位小数,并用下划线分隔
    cropped_img_name = f"{cls_class_name}_{confidence:.2f}_{x1}_{y1}_{x2}_{y2}.jpg"
    cropped_img_path = image_subfolder / cropped_img_name
    cv2.imwrite(str(cropped_img_path), cropped_img)
    print(f"已保存裁剪对象: {cropped_img_path}")

def process_image_with_detection_and_classification(model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped=False, detection_size=1280, classification_size=640):
    """
    处理单张图像:执行对象检测,分类每个对象,并返回处理后的图像。

    :param model_det: 检测模型
    :param model_cls: 分类模型
    :param img_path: 图像路径
    :param names: 类别名称列表
    :param class_colors: 类别颜色映射字典
    :param output_folder: 输出文件夹路径
    :param save_cropped: 是否保存裁剪的对象图像
    :param detection_size: 检测模型输入图像大小
    :param classification_size: 分类模型输入图像大小
    :return: 处理后的图像
    """
    img = cv2.imread(str(img_path))
    if img is None:
        print(f"无法读取图像: {img_path}")
        return None

    # 创建图像副本用于绘制(不修改原始图像)
    img_copy = img.copy()

    # 执行对象检测
    results_det = model_det.predict(str(img_path), imgsz=detection_size, conf=0.25, iou=0.45)

    # 处理每个检测结果(每个检测框)
    for r in results_det:
        boxes = r.boxes.xyxy.cpu().numpy()  # xyxy 格式
        classes = r.boxes.cls.cpu().numpy()
        confidences = r.boxes.conf.cpu().numpy()

        for box, cls_id, confidence in zip(boxes, classes, confidences):
            # 检测模型的类别名
            det_class_name = names[int(cls_id)]
            
            # 使用检测到的类别名对应的颜色(该颜色是全局唯一的)
            color = class_colors.get(det_class_name, (255, 255, 255))
            
            # 裁剪对象区域
            x1, y1, x2, y2 = map(int, box)
            object_region = img[y1:y2, x1:x2]
            # 将对象区域调整为分类模型的输入大小
            object_region = cv2.resize(object_region, (classification_size, classification_size))

            # 执行分类
            results_cls = model_cls.predict(object_region, imgsz=classification_size)

            for result in results_cls:
                try:
                    # 获取Top1预测结果
                    classification_confidence = result.probs.cpu().numpy().top1conf
                    top1_index = result.probs.top1
                    cls_class_name = names[top1_index]

                    # 根据分类结果的类别名设置颜色
                    final_color = class_colors.get(cls_class_name, color)
                    add_classification_to_box(img_copy, box, cls_class_name, classification_confidence, color=final_color)

                    # 如果启用了保存裁剪对象,则保存
                    if save_cropped:
                        save_cropped_object(img, box, cls_class_name, classification_confidence, output_folder, img_path.name)
                except Exception as e:
                    print(f"分类时出错: {e}")

            # 在图像副本上绘制检测框
            draw_box_on_image(img_copy, box, color=color)

    return img_copy

def process_images(model_det, model_cls, input_folder, output_folder, names, class_colors, save_cropped=False, detection_size=1280, classification_size=640):
    """
    处理输入文件夹中的图像,执行对象检测和分类,并保存处理后的图像。

    :param model_det: 检测模型
    :param model_cls: 分类模型
    :param input_folder: 输入文件夹路径
    :param output_folder: 输出文件夹路径
    :param names: 类别名称列表
    :param class_colors: 类别颜色映射字典
    :param save_cropped: 是否保存裁剪的对象图像
    :param detection_size: 检测模型输入图像大小
    :param classification_size: 分类模型输入图像大小
    """
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.webp']
    for ext in image_extensions:
        for img_path in Path(input_folder).glob(ext):
            print(f"正在处理: {img_path}")
            processed_img = process_image_with_detection_and_classification(
                model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped, detection_size, classification_size
            )

            if processed_img is not None:
                output_image_path = Path(output_folder) / f"{img_path.stem}_with_boxes_and_classification.jpg"
                cv2.imwrite(str(output_image_path), processed_img)
                print(f"已保存处理后的图像: {output_image_path}")
            else:
                print(f"跳过图像: {img_path} (无法处理)")

if __name__ == '__main__':
    # 设置是否保存裁剪的对象图像(默认不保存)
    SAVE_CROPPED = True  # 设置为 True 以启用保存裁剪对象

    # 加载检测和分类模型
    model_det = YOLO('runs/device_train/exp9/weights/best.pt')
    model_cls = YOLO('runs/cls_99.4%_exp14/weights/best.pt')

    # 设置输入和输出文件夹路径
    input_folder = 'test1'
    output_folder = 'infer-1216'

    # 获取类别名(用于生成一致的类别颜色映射)
    # 这里使用一张全白的图像来获取类别名
    black_image = 255 * np.ones((224, 224, 3), dtype=np.uint8)
    results = model_cls.predict(source=black_image)
    name_dict = results[0].names
    names = list(name_dict.values())

    # 只在这里生成一次类别颜色映射
    class_colors = generate_class_colors(names)

    # 开始处理图像
    process_images(
        model_det, model_cls, input_folder, output_folder,
        names, class_colors,
        save_cropped=SAVE_CROPPED,
        detection_size=1280,
        classification_size=224
    )

执行完后的结果
在这里插入图片描述

下面贴一下目标检测和图像分类的ultralytics的训练代码

目标检测训练代码

注意把single_cls=False改成True,变成单类训练

# nohup python -m torch.distributed.launch --nproc_per_node=4 --master_port=25643 det_train.py > output-lane-1212.txt 2>&1 &
# nohup python -m torch.distributed.launch --nproc_per_node=5 --master_port=25698 det_train.py > output-lane-1212.txt 2>&1 &
from ultralytics import YOLO

if __name__ == '__main__':
    # 加载模型
    model = YOLO("checkpoints/yolo11l.pt")  # 使用预训练权重训练
    # 训练参数 ----------------------------------------------------------------------------------------------
    model.train(
        data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_1212_yolo_without_vdd/config.yaml',
        epochs=150,  # (int) 训练的周期数
        patience=50,  # (int) 等待无明显改善以进行早期停止的周期数
        batch=16,  # (int) 每批次的图像数量(-1 为自动批处理)
        imgsz=1280,  # (int) 输入图像的大小,整数或w,h
        save=True,  # (bool) 保存训练检查点和预测结果
        save_period=-1,  # (int) 每x周期保存检查点(如果小于1则禁用)
        cache=False,  # (bool) True/ram、磁盘或False。使用缓存加载数据
        device='1,2,3,5',  # (int | str | list, optional) 运行的设备,例如 cuda device=0 或 device=0,1,2,3 或 device=cpu
        workers=8,  # (int) 数据加载的工作线程数(每个DDP进程)
        project='runs/device_train',  # (str, optional) 项目名称
        name='exp',  # (str, optional) 实验名称,结果保存在'project/name'目录下
        exist_ok=False,  # (bool) 是否覆盖现有实验
        pretrained=True,  # (bool | str) 是否使用预训练模型(bool),或从中加载权重的模型(str)
        optimizer='auto',  # (str) 要使用的优化器,选择=[SGD,Adam,Adamax,AdamW,NAdam,RAdam,RMSProp,auto]
        verbose=True,  # (bool) 是否打印详细输出
        seed=0,  # (int) 用于可重复性的随机种子
        deterministic=True,  # (bool) 是否启用确定性模式
        single_cls=False,  # (bool) 将多类数据训练为单类
        rect=False,  # (bool) 如果mode='train',则进行矩形训练,如果mode='val',则进行矩形验证
        cos_lr=True,  # (bool) 使用余弦学习率调度器
        close_mosaic=10,  # (int) 在最后几个周期禁用马赛克增强
        resume=False,  # (bool) 从上一个检查点恢复训练
        amp=True,  # (bool) 自动混合精度(AMP)训练,选择=[True, False],True运行AMP检查
        fraction=1.0,  # (float) 要训练的数据集分数(默认为1.0,训练集中的所有图像)
        profile=False,  # (bool) 在训练期间为记录器启用ONNX和TensorRT速度
        freeze= None,  # (int | list, 可选) 在训练期间冻结前 n 层,或冻结层索引列表。
        # 超参数 ----------------------------------------------------------------------------------------------
        lr0=0.01,  # (float) 初始学习率(例如,SGD=1E-2,Adam=1E-3)
        lrf=0.01,  # (float) 最终学习率(lr0 * lrf)
        momentum=0.937,  # (float) SGD动量/Adam beta1
        weight_decay=0.0005,  # (float) 优化器权重衰减 5e-4
        warmup_epochs=3.0,  # (float) 预热周期(分数可用)
        warmup_momentum=0.8,  # (float) 预热初始动量
        warmup_bias_lr=0.1,  # (float) 预热初始偏置学习率
        box=6,  # (float) 盒损失增益
        cls=1.5,  # (float) 类别损失增益(与像素比例)
        dfl=1.5,  # (float) dfl损失增益
        pose=12.0,  # (float) 姿势损失增益
        kobj=1.0,  # (float) 关键点对象损失增益
        label_smoothing=0.05,  # (float) 标签平滑(分数)
        nbs=64,  # (int) 名义批量大小
        hsv_h=0.015,  # (float) 图像HSV-Hue增强(分数)
        hsv_s=0.7,  # (float) 图像HSV-Saturation增强(分数)
        hsv_v=0.4,  # (float) 图像HSV-Value增强(分数)
        degrees=90.0,  # (float) 图像旋转(+/- deg)
        translate=0.5,  # (float) 图像平移(+/- 分数)
        scale=0.5,  # (float) 图像缩放(+/- 增益)
        shear=0.4,  # (float) 图像剪切(+/- deg)
        perspective=0.0,  # (float) 图像透视(+/- 分数),范围为0-0.001
        flipud=0.5,  # (float) 图像上下翻转(概率)
        fliplr=0.5,  # (float) 图像左右翻转(概率)
        mosaic=1.0,  # (float) 图像马赛克(概率)
        mixup=0.0,  # (float) 图像混合(概率)
        copy_paste=0.0,  # (float) 分割复制-粘贴(概率)
    )



图像分类训练代码

from ultralytics import YOLO

model = YOLO("checkpoints/yolo11l-cls.pt")
model.train(
    data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate_grid_110%', 
    project='runs/cls_train',  # (str, optional) 项目名称
    name='exp',  # (str, optional) 实验名称,结果保存在'project/name'目录下
    epochs=20, 
    batch=1024,
    device='1,2,3,5',
    erasing=0.0,
    crop_fraction=1.0,
    augment=False,
    auto_augment=False,
    hsv_h=0.015,  # (float) 图像HSV-Hue增强(分数)
    hsv_s=0.7,  # (float) 图像HSV-Saturation增强(分数)
    hsv_v=0.4,  # (float) 图像HSV-Value增强(分数)
    degrees=0.0,  # (float) 图像旋转(+/- deg)
    translate=0.0,  # (float) 图像平移(+/- 分数)
    scale=0.0,  # (float) 图像缩放(+/- 增益)
    shear=0.0,  # (float) 图像剪切(+/- deg)
    perspective=0.0,  # (float) 图像透视(+/- 分数),范围为0-0.001
    flipud=0.5,  # (float) 图像上下翻转(概率)
    fliplr=0.5,  # (float) 图像左右翻转(概率)
    mosaic=1.0,  # (float) 图像马赛克(概率)
    mixup=0.0)  # (float) 图像混合(概率))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265471.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HDR视频技术之十:MPEG 及 VCEG 的 HDR 编码优化

与传统标准动态范围( SDR)视频相比,高动态范围( HDR)视频由于比特深度的增加提供了更加丰富的亮区细节和暗区细节。最新的显示技术通过清晰地再现 HDR 视频内容使得为用户提供身临其境的观看体验成为可能。面对目前日益…

LabVIEW声音信号处理系统

开发了一种基于LabVIEW的声音信号处理系统,通过集成的信号采集与分析一体化解决方案,提升电子信息领域教学与研究的质量。系统利用LabVIEW图形化编程环境和硬件如USB数据采集卡及声音传感器,实现了从声音信号的采集到频谱分析的全过程。 项目…

OpenCL(壹):了解OpenCL模型到编写第一个CL内核程序

目录 1.前言 2.简单了解OpenCL 3.为什么要使用OpenCL 4.OpenCL架构 5.OpenCL中的平台模型(Platform Model) 6.OpenCL中的内存模型(Execution Model) 7.OpenCL中的执行模型(Memory Model) 8.OpenCL中的编程模型(Programmin Model) 9.OpenCL中的同步机制 10.编写第一个OpenCL程序…

Flutter组件————Scaffold

Scaffold Scaffold 是一个基础的可视化界面结构组件,它实现了基本的Material Design布局结构。使用 Scaffold 可以快速地搭建起包含应用栏(AppBar)、内容区域(body)、抽屉菜单(Drawer)、底部导…

【数据结构】数据结构整体大纲

数据结构用来干什么的?很简单,存数据用的。 (这篇文章仅介绍数据结构的大纲,详细讲解放在后面的每一个章节中,逐个击破) 那为什么不直接使用数组、集合来存储呢 ——> 如果有成千上亿条数据呢&#xff…

搭建Elastic search群集

一、实验环境 二、实验步骤 Elasticsearch 是一个分布式、高扩展、高实时的搜索与数据分析引擎Elasticsearch目录文件: /etc/elasticsearch/elasticsearch.yml#配置文件 /etc/elasticsearch/jvm.options#java虚拟机 /etc/init.d/elasticsearch#服务启动脚本 /e…

链原生 Web3 AI 网络 Chainbase 推出 AVS 主网, 拓展 EigenLayer AVS 场景

在 12 月 4 日,链原生的 Web3 AI 数据网络 Chainbase 正式启动了 Chainbase AVS 主网,同时发布了首批 20 个 AVS 节点运营商名单。Chainbase AVS 是 EigenLayer AVS 中首个以数据智能为应用导向的主网 AVS,其采用四层网络架构,其中…

玩转OCR | 探索腾讯云智能结构化识别新境界

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀ 玩转OCR 腾讯云智能结构化识别产品介绍服务应用产品特征行业案例总结 腾讯云智能结构化识别 腾讯云智能结构化OCR产品分为基础版与高级版&am…

生信软件开发2 - 使用PyQt5开发一个简易GUI程序

往期文章: 生信软件开发1 - 设计一个简单的Windwos风格的GUI报告软件 1. 使用PyQt5设计一个计算器主程序 要求PyQt5 > 5.6, calculator.py与MainWindow.py处于同一目录,下载mainwindow-weird.ui和mainwindow.ui资源,运行calculator.py即…

“计算几何”简介

计算几何(Computational Geometry)简单来说就是用计算机解决几何问题。 Computational指“using or connected with computers使用计算机的;与计算机有关的”,Geometry指“the branch of mathematics that deals with the measur…

TowardsDataScience 博客中文翻译 2018~2024(一百二十三)

TowardsDataScience 博客中文翻译 2018~2024(一百二十三) 引言 从 2018 年到 2024 年,数据科学的进展超越了许多技术领域的速度。Towards Data Science 博客依然是这个领域的关键平台,记录了从基础工具到前沿技术的多方面发展。…

GitHub 桌面版配置 |可视化界面进行上传到远程仓库 | gitLab 配置【把密码存在本地服务器】

🥇 版权: 本文由【墨理学AI】原创首发、各位读者大大、敬请查阅、感谢三连 🎉 声明: 作为全网 AI 领域 干货最多的博主之一,❤️ 不负光阴不负卿 ❤️ 文章目录 桌面版安装包下载clone 仓库操作如下GitLab 配置不再重复输入账户和密码的两个方…

今天最新早上好问候语精选大全,每天问候,相互牵挂,彼此祝福

1、朋友相伴,友谊真诚永不变!彼此扶持绿树荫,共度快乐雨后天!一同分享的表情,愿我们友情长存,一生相伴永相连! 2、人生几十年,苦累伴酸甜,风华不再茂,雄心非当…

Verdi -- 打开Consol,创建和执行tcl命令举例

1.Verdi打开Console的步骤: For ref: 2创建tcl脚本. tcl脚本路径: 在Makefile下,与.v文件在同一个目录8_demo这个文件夹下。 font.tcl代码内容: verdiSetFont -monoFont "Courier" -monoFontSize "24" 作用…

基于java博网即时通讯软件的设计与实现【源码+文档+部署讲解】

目 录 1. 绪 论 1.1. 开发背景 1.2. 开发意义 2. 系统设计相关技术 2.1 Java语言 2.2 MySQL数据库 2.3 Socket 3. 系统需求分析 3.1 可行性分析 3.2 需求分析 3.3 系统流程图 3.4 非功能性需求 4. 系统设计 4.1 系统功能结构 4.2 数据库设计 5. 系统实现 5.…

视频汇聚融合云平台Liveweb一站式解决视频资源管理痛点

随着5G技术的广泛应用,各领域都在通信技术加持下通过海量终端设备收集了大量视频、图像等物联网数据,并通过人工智能、大数据、视频监控等技术方式来让我们的世界更安全、更高效。然而,随着数字化建设和生产经营管理活动的长期开展&#xff0…

Hadoop集群(HDFS集群、YARN集群、MapReduce​计算框架)

一、 简介 Hadoop主要在分布式环境下集群机器,获取海量数据的处理能力,实现分布式集群下的大数据存储和计算。 其中三大核心组件: HDFS存储分布式文件存储、YARN分布式资源管理、MapReduce分布式计算。 二、工作原理 2.1 HDFS集群 Web访问地址&…

文本的AIGC率检测原理

背景 你可能在学生群里或者视频中看过这样的消息:“我们学校要求论文AI率不能超过30%!”、“你们学校查AI率吗?”之类的,这些消息到底是真是假? 随着人工智能的快速发展和广泛应用,不论是工作中还是学生学…

PODS:2024-12-21由麻省理工学院 和 OpenAI联合创建一个专门为个性化对象识别任务设计的数据集.

2024-12-21,由MIT和OpenAI联合创建的个性化视觉数据集,为细粒度和数据稀缺的个性化视觉任务提供了新的解决方案,推动了个性化模型的发展,具有重要的研究和应用价值。 一、研究背景: 在计算机视觉领域,现代…

OpenFeign快速入门 示例:黑马商城

使用起因 之前我们利用了Nacos实现了服务的治理,利用RestTemplate实现了服务的远程调用。这样一来购物车虽然通过远程调用实现了调用商品服务的方法,但是远程调用的代码太复杂了: 解决方法 并且这种调用方式比较复杂,一会儿远程调用,一会儿本地调用。 因…