yolov5测试代码

news2024/9/20 14:31:46

一般源码的测试代码涉及很多文件,因项目需要写一个独立测试的代码。传入的是字典

import time
import cv2
import os
import numpy as np
import torch
from modules.detec.models.common import DetectMultiBackend
from modules.detec.utils.dataloaders import LoadImages
from modules.detec.utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from modules.detec.utils.augmentations import letterbox
from modules.detec.utils.plots import Annotator, colors


class DetectionEstimation:
    def __init__(self, model_path, conf_threshold=0.9, iou_threshold=0.45, img_size=(384,640)):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = DetectMultiBackend(model_path).to(self.device)
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.img_size = img_size

    def _preprocess_image(self, img_dict):
        img_tensor_list = []
        original_sizes = {}

        for serial, img in img_dict.items():
            original_size = img.shape[:2]
            img_resized = letterbox(img, self.img_size, stride=32, auto=True)[0]
            img_resized = img_resized.transpose((2, 0, 1))[::-1]
            img_resized = np.ascontiguousarray(img_resized)
            img_tensor = torch.from_numpy(img_resized).float().to(self.device)
            img_normalized = img_tensor / 255

            if len(img_normalized.shape) == 3:
                img_normalized = img_normalized[None]

            img_tensor_list.append(img_normalized)
            original_sizes[serial] = original_size

        image_input = torch.cat(img_tensor_list)
        return image_input, original_sizes

    def _postprocess_predictions(self, predictions, original_sizes):
        results = {}
        for i, (det, (serial, img)) in enumerate(zip(predictions, original_sizes.items())):
            if det is not None and len(det):
                det[:, :4] = scale_boxes(self.img_size, det[:, :4], img).round()
                labels = []
                coordinates = []
                for *xyxy, conf, cls in reversed(det):
                    label = self.model.names[int(cls)]
                    labels.append((label, conf.item()))
                    coordinates.append([xyxy[0].item(), xyxy[1].item(), xyxy[2].item(), xyxy[3].item()])
                results[serial] = {
                    'labels': labels,
                    'coordinates': coordinates
                }
        return results

    def predict(self, img_dict):
        start_total = time.time()

        start_preprocess = time.time()
        img_tensor, original_sizes = self._preprocess_image(img_dict)
        preprocess_time = time.time() - start_preprocess
        print(f"Preprocess Time: {preprocess_time * 1000:.2f}ms")

        start_inference = time.time()
        with torch.no_grad():
            predictions = self.model(img_tensor)
            inference_time = time.time() - start_inference
            print(f"Inference Time:{inference_time * 1000:.2f}ms")

        start_non_max_suppression = time.time()
        predictions = non_max_suppression(predictions, self.conf_threshold, self.iou_threshold)
        non_max_suppression_time = time.time() - start_non_max_suppression
        print(f"Non-Max Suppression Time: {non_max_suppression_time * 1000:.2f}ms")

        start_postprocess = time.time()
        results = self._postprocess_predictions(predictions, original_sizes)
        postprocess_time = time.time() - start_postprocess
        print(f"Postprocess Time: {postprocess_time * 1000:.2f}ms")

        total_time = time.time() - start_total
        print(f"Total Processing Time: {total_time * 1000:.2f}ms")

        print("res:",results)
        return results

    def draw_results(self, img_dict, results):
        annotated_images = {}
        for serial, img in img_dict.items():
            if serial in results:
                det = results[serial]['coordinates']  # 从 results 中提取处理后的坐标
                labels = results[serial]['labels']  # 提取标签和置信度
                annotator = Annotator(img, line_width=3, example=self.model.names)
                for i, (xyxy, (label, conf)) in enumerate(zip(det, labels)):
                    # 生成标签信息
                    label_str = f'{label} {conf:.2f}'
                    # 绘制检测框和标签
                    annotator.box_label(xyxy, label_str, color=colors(i, True))
                annotated_images[serial] = annotator.result()
        return annotated_images

    def _save_labels(self, results, output_folder, batch_size=3):
        os.makedirs(output_folder, exist_ok=True)
        img_serials = list(results.keys())

        for i in range(0, len(img_serials), batch_size):
            batch = img_serials[i:i + batch_size]
            combined_filename = '_'.join(batch) + '_labels.txt'
            labels_path = os.path.join(output_folder, combined_filename)

            with open(labels_path, 'w') as file:
                for serial in batch:
                    if serial in results:
                        result = results[serial]
                        file.write("{\n")
                        file.write(f"  'serial': '{result['serial']}',\n")
                        file.write(f"  'labels': {result['labels']},\n")
                        file.write(f"  'coordinates': {result['coordinates']},\n")
                        file.write("}\n\n")

if __name__ == "__main__":
    model_path = 'data/pt/best.pt'
    detector = DetectionEstimation(model_path)

    img_folder = './data/images/'
    img_dict = {}
    img_filenames = []

    for img_filename in os.listdir(img_folder):
        img_path = os.path.join(img_folder, img_filename)
        if img_path.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_data = cv2.imread(img_path)
            serial = os.path.splitext(img_filename)[0]
            img_dict[serial] = img_data
            img_filenames.append(img_filename)

    batch_size = 2
    img_keys = list(img_dict.keys())
    for i in range(0, len(img_keys), batch_size):
        batch_dict = {k: img_dict[k] for k in img_keys[i:i + batch_size]}
        results = detector.predict(batch_dict)
        annotated_images = detector.draw_results(batch_dict, results)

        os.makedirs('results', exist_ok=True)
        for serial, img in annotated_images.items():
            output_path = f'results/{serial}.jpg'
            success = cv2.imwrite(output_path, img)
            if not success:
                print(f'Error saving image {output_path}')
            else:
                print(f'Successfully saved image {output_path}')

        detector._save_labels(results, 'results/labels', batch_size=batch_size)

在该代码同级目录下放models、results、utils文件夹和export.py

运行该代码得到的txt文件是字典:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149139.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

京东商品属性的详细api数据解析:颜色、尺寸与材质

京东(JD.com)作为一个大型电商平台,其商品信息通过API接口提供给开发者或第三方服务使用,以便进行商品搜索、展示、分析等操作。然而,直接访问京东的详细商品属性(如颜色、尺寸、材质等)API通常…

数据的表示和存储 第2讲 定点数的编码表示

​ 互联网行业 算法研发工程师 ​ 全文概括: 本讲介绍了定点数的编码表示,主要包括原码、补码和移码。 原码表示方式简单,正数用0表示,负数用1表示,但存在表示不唯一和加减运算不统一的问题。 补码表示方式解决了…

zabbix“专家坐诊”第256期问答

原作者:乐维社区 原文链接:https://forum.lwops.cn/questions 问题一 Q:zabbix 6.4.18版本的,使用zabbix_agentd2监控mysql数据库,只能在界面配置mysql的相关信息吗?这个在zabbix表里面是明文存储的&#x…

集采良药:从“天价神药”到低价良药,伊马替尼的真实世界研究!

在医疗科技日新月异的今天,有一种药物以其卓越的疗效和深远的影响力,成为了众多患者心中的“精准武器”——伊马替尼。这款药物不仅在慢性髓细胞白血病(CML)的治疗上屡创佳绩,更是胃肠道间质瘤(GIST&#x…

微信小程序自定义navigationBar顶部导航栏(背景图片)适配所有机型,使用tdesign-miniprogram t-navbar设置背景图片

设置导航栏样式自定义 一定要设置,不然页面会出现一个原生的导航栏,一个自定义的 // app.json文件 "window": {"navigationStyle": "custom" }设置导航栏样式 我这里使用tdesign-miniprogram t-navbar,t-na…

大模型的热度正在下降,大模型的未来在哪里?

“ 技术是一个需要沉淀和厚积薄发的过程 ” 任何事物都会经过起步,发展,顶峰,平稳,下降,灭亡的过程,大模型技术也不例外。 而从现今的趋势来看,大模型的热度正在不断下降,这到底意…

虫情测报灯的工作原理

型号:TH-CQ1】虫情测报灯是一种专门用于监测农田、林区等环境中昆虫数量和种类的设备,也称为智能虫情测报灯或物联网虫情测报灯。它通过特定的光源和颜色吸引昆虫,并利用高压电网或远红外自动处理技术等手段将昆虫击杀或处理,从而…

《黑龙江水产》是什么级别的期刊?是正规期刊吗?能评职称吗?

问题解答 问:《黑龙江水产》是不是核心期刊? 答:不是,是知网收录的第一批认定 学术期刊。 问:《黑龙江水产》级别? 答:省级。主管单位:黑龙江省农业农村厅 …

【QT】系统-下

欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:QT 目录 👉🏻QTheadrun() 👉🏻QMutex👉🏻QWaitCondition👉🏻Q…

视频存储EasyCVR视频监控汇聚管理平台设备录像下载报错404是什么原因?

EasyCVR视频监控汇聚管理平台是一款针对大中型项目设计的跨区域网络化视频监控集中管理平台。该平台不仅具备视频资源管理、设备管理、用户管理、运维管理和安全管理等功能,还支持多种主流标准协议,如GB28181、RTSP/Onvif、RTMP、部标JT808、GA/T 1400协…

基于SpringBoot的智能排课系统设计与实现

文未可获取一份本项目的java源码和数据库参考。 (一)选题来源与背景 高校的每学期伊始,排课是教务处工作中的重中之重。安排合理无资源冲突(教师、教室和设备等教学资源)的课表是教务工作必须面临的问题。传统的人工…

规模化电动汽车接入配电网调度方法

规模日益增长的电动汽车和可再生能源带来的不确定性给配电网的安全运营带来了严峻挑战。为综合考虑多重不确定性、平衡运营成本与系统可靠性,首先,提出一种基于分布鲁棒联合机会约束的电动汽车-配电网充放电调度模型。该模型将节点电压、支路功率、备用需求等通过联合机会约束建…

由一个 SwiftData “诡异”运行时崩溃而引发的钩深索隐(六)

概述 在 WWDC 24 中,苹果推出了数据库框架 SwiftData 2.0 版本。听说里面新增了能让数据记录“借尸还魂”的绝妙法器,到底是真是假呢? 我们在上篇博文中介绍了 History Trace 是如何稳妥的处理数据删除操作的。而在这里,我们将继续介绍 SwiftData 2.0 中另一个新特性:“墓…

Prometheus - nVisual插件让运维更轻松

Prometheus 是一个开源的服务监控系统和时间序列数据库,常用于对基础设施的监控,监控范围涵盖了硬件层、操作系统层、中间件层、应用层等运维所需的所有监控指标类型,同时可利用第三方可视化工具Grafana实现时序数据的展示。然而,…

Redis基础(数据结构和内部编码)

目录 前言 Redis的数据结构和内部编码 string结构和内部编码 string数据机构的特点 string数据结构的内部编码 list结构和内部编码 List 数据结构的特点 List 的内部编码 1. ziplist(压缩列表) 2. quicklist hash结构和内部编码 hash数据结构…

OpenCV特征检测(3)计算图像中每个像素处的特征值和特征向量函数cornerEigenValsAndVecs()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 计算图像块的特征值和特征向量用于角点检测。 对于每一个像素 p ,函数 cornerEigenValsAndVecs 考虑一个 blockSize blockSize 的邻…

Java 在 GIS 领域的学习路线?

Java是一门广泛应用于企业级开发的编程语言,而GIS则是一种常用于地理信息处理和分析的技术。将Java与GIS结合起来,可以在企业级应用中实现更多的功能和业务需求,且在实际领域越来越广泛。 Java在GIS中重要的作用 1、跨平台性 Java具有跨平台…

基于C语言+SQL Server2008实现(控制台)图书管理系统

第1章 概述 1.1项目背景 随着科技的发展,尤其是计算机技术的迅猛发展,图书馆管理的问题从以往的人工管理,到现在的电脑化,系统化,是对图书馆管理方法的质的飞跃,这些技术不仅让图书馆管理变得更加方便、快…

美国联邦基金有效利率及目标利率历史数据集(1990.1-2024.9)

美联储在2024年9月18日宣布将其调50个基点,降至4.75%至5.00%之间的水平。这是美联储自2020年3月以来首次降息,也是自2023年7月将利率水平调升至历史高位后的首次下调,标志着货币政策由紧缩周期向宽松周期的转向。一、数据介绍 数据名称&…

web基础—dvwa靶场(八)XSS

XSS(DOM) 跨站点脚本(XSS)攻击是一种注入攻击,恶意脚本会被注入到可信的网站中。当攻击者使用 web 应用程序将恶意代码(通常以浏览器端脚本的形式)发送给其他最终用户时,就会发生 XSS 攻击。允许这些攻击成…