【超详细】基于YOLOv11的PCB缺陷检测

news2024/10/9 14:40:52

主要内容如下:

1、数据集介绍
2、下载PCB数据集
3、不同格式数据集预处理(Json/xml),制作YOLO格式训练集
4、模型训练及可视化
5、Onnxruntime推理

运行环境:Python=3.8(要求>=3.8),torch1.12.0+cu113(要求>=1.8),onnxruntime-gpu=1.12.0
YOLO格式下载链接【可直接跳过步骤123】:https://aistudio.baidu.com/datasetdetail/297149

往期内容:

【超详细】跑通YOLOv8之深度学习环境配置1-Anaconda安装
【超详细】跑通YOLOv8之深度学习环境配置2-CUDA安装
【超详细】跑通YOLOv8之深度学习环境配置3-YOLOv8安装
【超详细】基于YOLOv8的PCB缺陷检测

0 YOLOv11

代码地址:https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/models/11
在这里插入图片描述

1 数据集介绍

1.1 简介

印刷电路板(PCB)瑕疵数据集:是一个公共的合成PCB数据集,由北京大学发布,其中包含1386张图像以及6种缺陷(缺失孔,鼠标咬伤,开路,短路,杂散,伪铜),用于检测,分类和配准任务。本文我们选取了其中适用与检测任务的693张图像,随机选择593张图像作为训练集,100张图像作为验证集。

1.2 示例

在这里插入图片描述

2 下载数据集

官方链接:https://robotics.pkusz.edu.cn/resources/dataset/
注意:百度网盘下载,速度很慢,不推荐!推荐去百度AI stduio数据集下载,速度快!
在这里插入图片描述
百度AI stduio下载链接:https://aistudio.baidu.com/datasetdetail/272346
在这里插入图片描述

3 制作YOLO格式训练集

具体见第3节3.2:【超详细】基于YOLOv8的PCB缺陷检测
YOLO格式下载链接【直接拿来训练Aistudio快速下载链接】:https://aistudio.baidu.com/datasetdetail/297149
在这里插入图片描述
在这里插入图片描述

4 模型训练及可视化

4.1 更新ultralytics

pip install --upgrade ultralytics

4.2 创建数据集yaml文件

注意:路径一定填对,类别与id一定要对应!!!
创建ultralytics\cfg\datasets\PCB.yaml文件,内容如下:

path: E:\\datasets\\PCB\\PCB_DATASET_YOLO # dataset root dir
train: images/train2017 # train images (relative to 'path') 4 images
val: images/val2017 # val images (relative to 'path') 4 images

# Classes for DOTA 1.0
names:
  0: missing_hole
  1: mouse_bite
  2: open_circuit
  3: short
  4: spur
  5: spurious_copper

4.3 创建一个训练脚本

在主目录下创建一个train.py,内容如下:

from ultralytics import YOLO

if __name__ == '__main__':
    # Load a model
    # model = YOLO("yolo11s.yaml")  # build a new model from scratch
    model = YOLO("yolo11s.pt")  # load a pretrained model (recommended for training)


    # Use the model
    model.train(data="PCB.yaml", imgsz=640, batch=16, workers=8, cache=True, epochs=100)  # train the model
    metrics = model.val()  # evaluate model performance on the validation set
    # results = model("ultralytics\\assets\\bus.jpg")  # predict on an image
    path = model.export(format="onnx", opset=13)  # export the model to ONNX format

问题1:若爆显存,降低batch和workers大小!

训练结果如下【比yolov8s略高】
在这里插入图片描述
在这里插入图片描述

4.4 创建一个预测脚本

在主目录下创建一个detect.py,内容如下:

from ultralytics import YOLO

if __name__ == '__main__':
    # Load a model
    model = YOLO(r"runs\detect\train\weights\best.pt")  # load model
    model.predict(source=r"E:\datasets\PCB\PCB_DATASET_YOLO\images\val2017\01_missing_hole_07.jpg", save=True, save_conf=True, save_txt=True, name='output')

预测结果
在这里插入图片描述

5 Onnxruntime推理及可视化

5.1 onnx推理

在主目录下创建一个onnx_infer.py,内容如下:

import argparse
import time 
import cv2
import numpy as np

import onnxruntime as ort  # 使用onnxruntime推理用上,pip install onnxruntime-gpu==1.12.0 -i  https://pypi.tuna.tsinghua.edu.cn/simple,默认安装CPU
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

class YOLOv11:
    """YOLOv11 object detection model class for handling inference and visualization."""

    def __init__(self, onnx_model, imgsz=(640, 640)):
        """
        Initialization.

        Args:
            onnx_model (str): Path to the ONNX model.
        """
        
        # 构建onnxruntime推理引擎
        self.ort_session = ort.InferenceSession(onnx_model,
                                            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
                                            if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])
        print(ort.get_device())
        # Numpy dtype: support both FP32 and FP16 onnx model
        self.ndtype = np.half if self.ort_session.get_inputs()[0].type == 'tensor(float16)' else np.single
       
        self.model_height, self.model_width = imgsz[0], imgsz[1]  # 图像resize大小
     

    def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45):
        """
        The whole pipeline: pre-process -> inference -> post-process.

        Args:
            im0 (Numpy.ndarray): original input image.
            conf_threshold (float): confidence threshold for filtering predictions.
            iou_threshold (float): iou threshold for NMS.

        Returns:
            boxes (List): list of bounding boxes.
        """
        # 前处理Pre-process
        t1 = time.time()
        im, ratio, (pad_w, pad_h) = self.preprocess(im0)
        pre_time = round(time.time() - t1, 3)
        # print('det预处理时间:{:.3f}s'.format(time.time() - t1))
        
        # 推理 inference
        t2 = time.time()
        preds = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: im})[0]
        # print('det推理时间:{:.2f}s'.format(time.time() - t2))
        det_time = round(time.time() - t2, 3)
        
        # 后处理Post-process
        t3 = time.time()
        boxes = self.postprocess(preds,
                                im0=im0,
                                ratio=ratio,
                                pad_w=pad_w,
                                pad_h=pad_h,
                                conf_threshold=conf_threshold,
                                iou_threshold=iou_threshold,
                                )
        # print('det后处理时间:{:.3f}s'.format(time.time() - t3))
        post_time = round(time.time() - t3, 3)

        return boxes, (pre_time, det_time, post_time)
        
    # 前处理,包括:resize, pad, HWC to CHW,BGR to RGB,归一化,增加维度CHW -> BCHW
    def preprocess(self, img):
        """
        Pre-processes the input image.

        Args:
            img (Numpy.ndarray): image about to be processed.

        Returns:
            img_process (Numpy.ndarray): image preprocessed for inference.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
        """
        # Resize and pad input image using letterbox() (Borrowed from Ultralytics)
        shape = img.shape[:2]  # original image shape
        new_shape = (self.model_height, self.model_width)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        ratio = r, r
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
            
        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))  # 填充

        # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
        img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
        img_process = img[None] if len(img.shape) == 3 else img
        return img_process, ratio, (pad_w, pad_h)
    
    # 后处理,包括:阈值过滤与NMS
    def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold):
        """
        Post-process the prediction.

        Args:
            preds (Numpy.ndarray): predictions come from ort.session.run().
            im0 (Numpy.ndarray): [h, w, c] original input image.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
            conf_threshold (float): conf threshold.
            iou_threshold (float): iou threshold.

        Returns:
            boxes (List): list of bounding boxes.
        """
        x = preds  # outputs: predictions (1, 84, 8400)
        # Transpose the first output: (Batch_size, xywh_conf_cls, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls)
        x = np.einsum('bcn->bnc', x)  # (1, 8400, 84)
   
        # Predictions filtering by conf-threshold
        x = x[np.amax(x[..., 4:], axis=-1) > conf_threshold]

        # Create a new matrix which merge these(box, score, cls) into one
        # For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html
        x = np.c_[x[..., :4], np.amax(x[..., 4:], axis=-1), np.argmax(x[..., 4:], axis=-1)]

        # NMS filtering
        # 经过NMS后的值, np.array([[x, y, w, h, conf, cls], ...]), shape=(-1, 4 + 1 + 1)
        x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)]
       
        # 重新缩放边界框,为画图做准备
        if len(x) > 0:
            # Bounding boxes format change: cxcywh -> xyxy
            x[..., [0, 1]] -= x[..., [2, 3]] / 2
            x[..., [2, 3]] += x[..., [0, 1]]

            # Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image
            x[..., :4] -= [pad_w, pad_h, pad_w, pad_h]
            x[..., :4] /= min(ratio)

            # Bounding boxes boundary clamp
            x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1])
            x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0])

            return x[..., :6]  # boxes
        else:
            return []


if __name__ == '__main__':
    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--det_model', type=str, default=r"E:\Code\yolov11\ultralytics-main\runs\detect\train\weights\best.onnx", help='Path to ONNX model')
    parser.add_argument('--source', type=str, default=str(r'E:\datasets\PCB\PCB_DATASET_YOLO\images\val2017'), help='Path to input image')
    parser.add_argument('--out_path', type=str, default=str(r'E:\Code\yolov11\ultralytics-main\runs\detect\res'), help='结果保存文件夹')
    parser.add_argument('--imgsz_det', type=tuple, default=(640, 640), help='Image input size')
    parser.add_argument('--classes', type=list, default=['missing_hole', 'mouse_bite', 'open_circuit', 'short', 'spur', 'spurious_copper'], help='类别')

    parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
    parser.add_argument('--iou', type=float, default=0.6, help='NMS IoU threshold')
    args = parser.parse_args()

    if not os.path.exists(args.out_path):
        os.mkdir(args.out_path)
    print('开始运行:')
    # Build model
    det_model = YOLOv11(args.det_model, args.imgsz_det)
    color_palette = np.random.uniform(0, 255, size=(len(args.classes), 3))  # 为每个类别生成调色板
    
    for i, img_name in enumerate(os.listdir(args.source)):
        try:
            start_time = time.time()
            # Read image by OpenCV
            img = cv2.imread(os.path.join(args.source, img_name))

            # 检测Inference
            boxes, (pre_time, det_time, post_time) = det_model(img, conf_threshold=args.conf, iou_threshold=args.iou)
            print('{}/{} ==>总耗时间: {:.3f}s, 其中, 预处理: {:.3f}s, 推理: {:.3f}s, 后处理: {:.3f}s, 识别{}个目标'.format(i+1, len(os.listdir(args.source)), time.time() - start_time, pre_time, det_time, post_time, len(boxes)))
            
            for (*box, conf, cls_) in boxes:
                cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                                color_palette[int(cls_)], 2, cv2.LINE_AA)
                cv2.putText(img, f'{args.classes[int(cls_)]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
            cv2.imwrite(os.path.join(args.out_path, img_name), img)
        
        except Exception as e:
            print(e)   

资源消耗:显存不到1.5G,推理速度5ms.

5.2 结果可视化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2195427.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu ssh远程执行k8s命令报错the connection to the server localhost:8080 was refused

修改前: ssh root192.168.31.167 kubectl apply -f /root/jenkinsexcute/saas.demo.api.k8s.yml --recordecho "export KUBECONFIG/etc/kubernetes/admin.conf" >> /root/.bashrc 修改后 添加一段:export KUBECONFIG/etc/kubernetes/a…

【专题】操作系统概述

1. 操作系统的目标和作用 操作系统的目标与应用环境有关。 在查询系统中所用的OS,希望能提供良好的人—机交互性; 对于应用于工 业控制、武器控制以及多媒体环境下的OS,要求其具有实时性; 对于微机上配置的OS,则更看…

什么是强基计划?

“强基计划”是中国教育部于2020年推出的一项全新的高等教育招生改革计划,旨在通过更加科学、公正的选拔机制,选拔出有志于基础学科并具备扎实学科功底、创新潜质的优秀学生,从而推动国家基础学科的发展,提升自主创新能力。与传统…

【自动驾驶】UniAD代码解析

1.参考 论文:https://arxiv.org/pdf/2212.10156 代码:https://github.com/OpenDriveLab/UniAD 2.环境配置 docs/INSTALL.md (1)虚拟conda环境 conda create -n uniad python3.8 -y conda activate uniad (2&#…

微信小程序和抖音小程序的分享和广告接入代码

开发完成小程序或者小游戏之后,我们为什么要接入分享和广告视频功能,主要原因有以下几个方面。 微信小程序和抖音小程序接入分享和广告功能主要基于以下几个原因: 用户获取与增长:分享功能可以帮助用户将小程序内容传播给更多人&…

C语言刷题--有关闰年

满足以下一种即是闰年 能被4整除,但不能被100整除能被400整除 //输入年,月,输出该月的天数 //1月 31,28,31,30,31,30,31,31,30,31,30,31int is_leap_year(int y) {if (((y % 4 0) && (y % 100 ! 0)) || y % 400 0)return 1;return…

步进电机和步进电机驱动器详解

一、步进电机的概念 步进电机是通过步进电机配套的驱动器,将控制器传来的脉冲信号转换成角位移的开环电机(没有反馈)。 步进电机工作时的位置和速度信号不反馈给控制系统,如果电机工作时的位置和速度信号反馈给控制系统&#xff…

《从零开始大模型开发与微调》真的把大模型说透了!零基础入门一定要看!

2022年底,ChatGPT震撼上线,大语言模型技术迅速“席卷”了整个社会,人工智能技术因此迎来了一次重要进展。与大语言模型相关的研发岗薪资更是水涨船高,基本都是5w月薪起。很多程序员也想跟上ChatGPT脚步,今天给大家带来…

apache.poi读取.xls文件时The content of an excel record cannot exceed 8224 bytes

目录 问题描述版本定位:打印size最大的Record定位:RefSubRecord解决代码 问题描述 使用apache.poi读取.xls文件时有The content of an excel record cannot exceed 8224 bytes的报错。待读取的文件的内容也是通过apache.poi写入的,我的文件修…

【sqlmap】sqli-labs速通攻略

sqli-labs工具速通 Less-1 sqlmap -u http://127.0.0.1:8081/Less-1/?id1 --batch --dbs sqlmap -u http://127.0.0.1:8081/Less-1/?id1 --batch -D security --tables sqlmap -u http://127.0.0.1:8081/Less-1/?id1 --batch -D security -T users --columns sqlmap -u ht…

购物清单 | 双十一加购率最高好物合集,数码购物车必备!

​双十一来临,小伙伴们肯定已经被种草了很多很多清单,开始买买买了!但是,作为一个数码博主,怎么能少了数码产品!今天我给大家准备了一份数码人专属的购物清单,快来看看吧! 运动耳机…

[水墨:创作周年纪念] 特别篇!

本篇是特别篇!! 个人主页水墨不写bug // _ooOoo_ // // o8888888o // // 88" . "88 …

如何方便地打出「」和『』

比起英文中的引号 ‘’和 “”,我更喜欢使用中文直角引号:「」和 『』。 此外,在港澳台、日本这几个地区中,就经常使用『』和「」: ​ ‍ 注意:不同地区的习惯可能有所不同。在汉语中『』、「」分别为双…

数学公式编辑器免费版下载,mathtype和latex哪个好用

选择适合自己的公式编辑器需要考虑多个因素。首先,您需要确定编辑器支持的功能和格式是否符合您的需求,例如是否可以插入图片、导出各种文件格式等。其次,您可以考虑编辑器的易用性和界面设计是否符合您的个人喜好。另外,您还可以…

面向对象特性中 继承详解

目录 概念: 定义: 定义格式 继承关系和访问限定符 基类和派生类对象赋值转换: 继承中的作用域: 派生类的默认成员函数 继承与友元: 继承与静态成员: 复杂的菱形继承及菱形虚拟继承: 虚…

手机号归属地查询-手机号归属地-手机号归属地-运营商归属地查询-手机号码归属地查询手机号归属地-运营商归属地

手机号归属地查询API接口是一种网络服务接口,允许开发者通过编程方式查询手机号码的注册地信息。关于快证签API接口提供的手机号归属地查询服务,以下是一些关键信息: 一、快证签API接口简介 快证签API接口可能是一个提供多种验证和查询服务…

「自动化测试」Selenium 的使用

使用 Selenium 需要先导入相关依赖 <dependency> <groupId>org.seleniumhq.selenium</groupId> <artifactId>selenium-java</artifactId> <version>4.0.0</version> </dependency><dependency><groupId>io.gith…

免费录屏神器!这四款软件让你快捷录屏~

随着技术的进步&#xff0c;免费的录屏软件如雨后春笋般涌现&#xff0c;为我们的工作、学习和娱乐提供了极大的便利。今天&#xff0c;就让我来为大家推荐几款备受好评的免费录屏软件&#xff0c;并分享一下使用感受吧&#xff01; 一、福昕录屏 直通车&#xff08;粘贴到浏览…

OJ在线评测系统 微服务高级 Gateway网关接口路由和聚合文档 引入knife4j库集中查看管理并且调试网关项目

Gateway微服务网关接口路由 各个服务之间已经能相互调用了 为什么需要网关 因为我们的不同服务是放在不同的端口上面的 如果前端调用服务 需要不同的端口 8101 8102 8103 8104 我们最好提供一个唯一的 给前端去调用的路径 我们学习技术的时候必须要去思考 1.为什么要用&am…

回溯算法:一个模板解决排列组合问题

回溯算法 在初遇排列组合题目时&#xff0c;总让人摸不着头脑&#xff0c;但是当我做了很多题目后&#xff0c;发现几乎能用同一个模板做完所有这种类型的题目&#xff0c;大大提高了解题效率。回溯是递归的副产品&#xff0c;只要有递归就会有回溯。 回溯法很难&#xff0c;…