【图像分类】基于yolov5的钢板表面缺陷分类(附代码和数据集)

news2024/11/18 23:36:46

写在前面:
首先感谢兄弟们的订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌。

Hello,大家好,我是augustqi。

今天给大家分享一个图像分类项目:基于yolov5的钢板表面缺陷分类(附代码和数据集),在之前的文章中,我们使用yolov7算法对钢板表面缺陷进行检测,属于目标检测范畴,鉴于yolov5-6.2和yolov5-7.0版本中增加了图像分类功能,我们尝试使用yolov5做一个图像分类任务,使用的是钢板表面缺陷数据集。

钢板表面缺陷检测文章回顾:
【表面缺陷检测】基于yolov7的钢板表面缺陷检测(Ubuntu系统)
【表面缺陷检测】基于yolov7的钢板表面缺陷检测(附代码)

本项目也是从零开始制作自己的数据集,学会如何使用yolov5训练图像分类模型的保姆级教程,多的不说,少的不唠,下面我们一起来学习一下吧。

以下内容,完全是我根据参考资料和个人理解撰写出来的,不存在滥用原创的问题。

1. 引言

从事深度学习行业的都知道,yolov5是一个很优秀的目标检测框架,代码已开源至github,截止到2023年1月14日,获得了34.5k Star,12.5k Fork,妥妥的人类高质量项目。项目也在持续更新,并不断完善和增加了新的功能,例如在yolov5-6.2和yolov5-7.0版本中增加了图像分类功能,yolov5-7.0版本中增加了图像分割功能。

资料显示,官方使用4张A100显卡,在ImageNet数据集上训练了90个epoch,得到了YOLOv5-cls分类模型,同时训练了ResNet和EfficientNet模型进行比较,YOLOv5-cls分类模型取得了不错的结果。

2. 背景

图像分类任务是计算机视觉领域的核心任务之一,其目标就是根据图像信息中所反映的不同特征,将不同类别的图像区分开来。钢厂在生产钢板的时候,由于工艺或者现场因素原因,有的钢板表面会产生缺陷,通过对钢板表面缺陷类别进行分类和统计,从而分析缺陷产生的原因,对进一步改善工艺,降低次品率有很大的帮助。

使用人工对每天生产的钢板进行缺陷检测和分类,不仅费时费力,而且很容易漏检和错检。基于计算机视觉的方法对图像进行分类现在已经很成熟了,目前比较主流的图像分类网络有ResNet、DenseNet、EfficientNet等。YOLOv5是目标检测方向的一个主流框架,但yolov5-6.2版本增加了图像分类功能,从官方公布的实验结果来看,取得了不错的效果,本项目首次基于yolov5训练钢板表面缺陷分类模型。

3. 数据

国内外的工业界和学术界目前开源了几个钢板表面缺陷数据集,说实话,部分数据集的中图片的数量和质量还是有待提高的,但是想一想整理并开源一个数据集耗费大量的人力和物力,成本是巨大的,而且这些数据都是商业数据,也是保密的,能够免费使用目前开源的数据已经很幸运了,且用且珍惜吧。其实本人也参与过钢板表面缺陷检测项目,深入钢厂在钢板生产现场收集了某个钢种的一些钢板表面缺陷图片,也和现场业务员进行沟通,对缺陷图片进行了标注,整理了一份包含1600张缺陷图片的数据集,总共4类缺陷,每类缺陷400张图片,但是这是商业数据,也是保密数据,无法公开。其实,我很想继续把这个数据集进行扩充,收集更多的缺陷图片,增加更多的缺陷类别,当有一天这个数据成果公开后,可以给工业界和学术界带来更多的图像分类、目标检测和图像分割研究成果,希望这一天早日到来吧。

本项目中使用的图像分类数据集,是东北大学带钢表面缺陷检测数据集,这个数据集在之前的“【表面缺陷检测】表面缺陷检测数据集汇总”介绍过,数据集中包含夹杂、划痕、压入氧化皮、裂纹、麻点和斑块6种缺陷,每种缺陷300张,图像尺寸为200×200。

英文名称中文名称图片数图片尺寸
crazing裂纹300200×200
inclusion夹杂300200×200
patches斑块300200×200
pitted_surface麻点300200×200
rolled-in_scale压入氧化皮300200×200
scratches划痕300200×200

4. 代码

4.1 项目搭建

下载代码,将代码上传到服务器,也可以是使用本地的Windows系统进行训练:

我们只要用到的是classify文件夹下的train.py、val.py和predict.py脚本代码:

train.py:训练脚本
val.py:评估脚本
predict.py:推理脚本

4.2 核心代码

主干网络使用yolov5s、yolov5x、yolov5m、yolov5n、yolov5l:

class ClassificationModel(BaseModel):
    # YOLOv5 classification model
    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):  # yaml, model, number of classes, cutoff index
        super().__init__()
        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        # Create a YOLOv5 classification model from a YOLOv5 detection model
        if isinstance(model, DetectMultiBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:cutoff]  # backbone
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
        model.model[-1] = c  # replace
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg):
        # Create a YOLOv5 classification model from a *.yaml file
        self.model = None

class Classify(nn.Module):
    # YOLOv5 classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
        self.drop = nn.Dropout(p=0.0, inplace=True)
        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, 1)
        return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))

损失函数:

def smartCrossEntropyLoss(label_smoothing=0.0):
    # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
    if check_version(torch.__version__, '1.10.0'):
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if label_smoothing > 0:
        LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
    return nn.CrossEntropyLoss()

优化器:

def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    for v in model.modules():
        for p_name, p in v.named_parameters(recurse=0):
            if p_name == 'bias':  # bias (no decay)
                g[2].append(p)
            elif p_name == 'weight' and isinstance(v, bn):  # weight (no decay)
                g[1].append(p)
            else:
                g[0].append(p)  # weight (with decay)

    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
    elif name == 'AdamW':
        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == 'RMSProp':
        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
    return optimizer

5. 训练

训练命令:

python classify/train.py --model weights/yolov5s-cls.pt --data data_custom --epochs 100  --batch-size 32 --imgsz 224 

开始训练:

结束训练:

生成的文件:

我使用了一张Tesla P100显卡,训练了100 epoch,仅用了不到6分钟的时间,真的是相当的快,而且top1和top5精度都达到了100%。

6. 评估和推理

6.1 评估

评估代码:

python classify/val.py --weights runs/train-cls/exp/weights/best.pt --data data_custom

评估结果:

6.2 推理

推理代码:

# 测试im1.jpg
python classify/predict.py --weights runs/train-cls/exp/weights/best.pt --source im1.jpg

# 测试im2.jpg
python classify/predict.py --weights runs/train-cls/exp/weights/best.pt --source im2.jpg

推理结果:

7. 导出

7.1 ONNX

执行命令导出onnx:

python export.py --weights runs/train-cls/exp/weights/best.pt --include onnx

输出:

Detect:          python classify/predict.py --weights runs/train-cls/exp/weights/best.onnx
Validate:        python classify/val.py --weights runs/train-cls/exp/weights/best.onnx
PyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', 'runs/train-cls/exp/weights/best.onnx')  # WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference
Visualize:       https://netron.app

我们使用onnx进行模型部署,也可以把它当作中间件进行模型转换,根据项目需求进行选择。

7.2 TensorRT

执行命令导出engine:

python export.py --weights runs/train-cls/exp/weights/best.pt --include engine --device 0

输出:

Detect:          python classify/predict.py --weights runs/train-cls/exp/weights/best.engine
Validate:        python classify/val.py --weights runs/train-cls/exp/weights/best.engine
PyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', 'runs/train-cls/exp/weights/best.engine')  # WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference
Visualize:       https://netron.app

我们导出的engine模型,可以使用英伟达的TensorRT框架进行部署,加速模型推理速度。

8. 结论

本次项目基于yolov5对钢板表面缺陷进行分类,从评估指标来看,验证集上准确率很高,取得了很好的结果。模型训练速度也很快,模型导出也很方便和友好。目前,还可以做的工作是使用多GPU对模型进行训练,对导出的模型进行部署。但是,虽然分类精度很高,但是我们不知道缺陷的具体位置,没办法对缺陷进行定位,目前的想法是使用卷积热力图,用于突出图像的类的特定区域,不知道是否可行,还需进一步验证。总后絮叨,yolov5你真强!截止发文前,ultralytics公司目前已将yolov8开源了,yolov8将在江湖上引起腥风血雨。

如果您觉得这篇文章对您有一点点的帮助和启发,希望您关注公众号,并点赞、转发。您可以联系我获取项目中的数据集和代码,数据整理和代码调试不易,公众号运营困难,为了公众号的正常运营,提供有偿指导,感谢您的理解和支持,祝好。

联系方式:公众号底部菜单栏–关于我–与我联系【订阅CSDN专栏的朋友,请加我v,发您数据和代码,不贴出代码和数据集的链接是为了防止爬虫,望理解】

参考资料

[1]https://github.com/ultralytics/yolov5
[2]https://blog.csdn.net/AugustMe/article/details/128111977

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/163625.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ArcGIS基础实验操作100例--实验100三维可视性分析

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台:ArcGIS 10.6 实验数据:请访问实验1(传送门) 空间分析篇--实验100 三维可视性分析 目录 一、实验背景 二、实验数据 三、实验步骤 (1&a…

JavaScript---DOM---高级事件---1.8

注册事件&#xff08;绑定事件&#xff09; 给元素添加事件称为注册事件或绑定事件。注册事件有两种方式&#xff1a;传统方式、方法监听注册方式。 传统注册方式&#xff1a; 利用on开头的事件onclick&#xff1a; <button onclick"alert(hi~)"></butt…

测试用例具体的设计方法

等价类法由于输入的集合是无穷的&#xff0c;不能全部覆盖到&#xff0c;所以通过划分若干个等价类&#xff0c;选出有代表性的达到尽量多的功能覆盖有效等价类&#xff1a;根据规格说明书是合理的、有意义的输入数据构成的集合无效等价类&#xff1a;根据需求说明书是不合理&a…

246页10万字省级政务专用云项目技术方案

【版权声明】本资料来源网络&#xff0c;知识分享&#xff0c;仅供个人学习&#xff0c;请勿商用。【侵删致歉】如有侵权请联系小编&#xff0c;将在收到信息后第一时间删除&#xff01;完整资料领取见文末&#xff0c;部分资料内容&#xff1a; 目录 对本项目的技术服务类总体…

【Java AWT 图形界面编程】LayoutManager 布局管理器 ⑦ ( Box 容器 | Box 容器中添加分割 )

文章目录一、Box 容器二、Box 容器 API三、Box 容器代码示例四、Box 容器中添加分割一、Box 容器 为了 方便使用 BoxLayout 布局 , Swing 中提供了 Box 容器 ; Box 容器 默认的 布局管理器 就是 BoxLayout ; 通过在 Box 容器构造函数中传入不同的参数 , 可以直接创建 水平排列…

Java基础语法(一)

注释1.1注释概述注释是在程序指定位置添加的说明性信息注释不参与程序运行&#xff0c;仅起到说明作用1.2注释分类单行注释格式&#xff1a;//注释信息多行注释格式&#xff1a;/*注释信息*/文档注释格式&#xff1a;/**注释信息*/文档注释目前用不上&#xff0c;暂不讲解/* Ja…

C++入门

目录 1. 命名空间 1.1 命名空间的定义 1.2 命名空间的使用 2. C的输入输出 3. 缺省参数 3.1 缺省参数概念 3.2 缺省参数分类 4.函数重载 4.1 函数重载概念 4.2 C支持函数重载的原理——名字修饰 5. 引用 5.1 引用概念 5.2 引用特性 5.3 常引用 5.4 使用场景 5.5 引用…

什么是测试金字塔?如何使用测试金字塔来构建自动化测试体系?

测试金字塔 &#xff08;Test Pyramid&#xff09;是一套使用单元测试&#xff0c;集成测试和端到端测试来构建自动化测试体系的方法。 如下图所示&#xff0c;在金字塔的最下方是单元测试&#xff0c;中段是集成测试&#xff0c;最上方是端到端测试。单元测试实现的成本最低&…

Android 深入系统完全讲解(17)

这个就是我们在初始化的时候给对应的属性设置上下文。chcon 这个可以修改上下文。 我们在遇见类似的属性读取不到的时候&#xff0c;一般操作是&#xff1a; getprop -z 看下属性的上下文&#xff0c;然后 ps -z 看下进程的上下文&#xff0c;然后判断出来是否有对应的 权限&am…

1. PyTorch是什么?

这篇博客将介绍PyTorch深度学习库&#xff0c;包括&#xff1a; PyTorch是什么如何安装PyTorch重要的PyTorch功能&#xff0c;包括张量和自动标记PyTorch如何支持GPU为什么PyTorch在研究人员中如此受欢迎PyTorch是否优于Keras/TensorFlow是否应该在项目中使用PyTorch或Keras/T…

ArcGIS10.2保姆式安装教程,超详细;附安装包

安装前请关闭杀毒软件&#xff0c;系统防火墙&#xff0c;断开网络连接 参考链接&#xff1a;请点击 下载链接&#xff1a; 通过百度网盘分享的文件&#xff1a;ArcGIS10.2zip 链接:https://pan.baidu.com/s/1s_xc1HvmMdo4fnnUo97ldA 提取码:v74k 复制这段内容打开「百度网盘A…

2022年11月下午案例分析真题及答案解析

试题一&#xff08;共15分&#xff09;&#xff08;202211&#xff09; 阅读下列说明和图&#xff0c;回答问题1至问题4&#xff0c;将解答填入答题纸的对应栏内。 【说明】 随着新能源车数量的迅猛增长&#xff0c;全国各地电动汽车配套充电桩急速增长&#xff0c;同时也带…

2023年网络安全比赛--Linux系统渗透提权中职组(超详细)

一、竞赛时间 180分钟 共计3小时 二、竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 1.使用渗透机对服务器信息收集,并将服务器中SSH服务端口号作为flag提交; 2.使用渗透机对服务器信息收集,并将服务器中主机名称作为flag提交; 3.使用渗透机对服务器信息收集,并将服…

数据科学家必备的 3 个 Jupyter Notebook 扩展

如果您是数据科学家、机器学习工程师或任何其他类型的数据专业人员&#xff0c;您可能已经花了很多时间使用 Jupyter 笔记本。虽然 Jupyter notebooks 已经是一个强大的工具&#xff0c;但还有许多扩展可以进一步增强您的体验。 在本文中&#xff0c;我们将向您介绍三个最有用…

活动星投票创心服务网络评选微信的投票方式线上免费投票

“创心服务”网络评选投票_视频投票评选小程序_线实时投票小程序_微信投票链接创建现来说&#xff0c;公司、企业、学校更多地想借助短视频推广自己。通过微信投票小程序&#xff0c;网友们就可以通过手机拍视频上传视频参加活动&#xff0c;而短视频微信投票评选活动既可以给用…

表单验证的简单实现

表单验证一. 作用二. 需求三. 实现需求一&#xff1a;HTML&#xff1a;JavaScript&#xff1a;需求二&#xff1a;JavaScript&#xff1a;一. 作用 如果没有表单验证&#xff0c;错误的数据就会发往服务端&#xff0c;会造成服务端压力过大&#xff1b; 所以在前端对数据进行过…

ArcGIS基础实验操作100例--实验98计算上游集水区污染值

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台&#xff1a;ArcGIS 10.6 实验数据&#xff1a;请访问实验1&#xff08;传送门&#xff09; 空间分析篇--实验98 计算上游集水区污染值 目录 一、实验背景 二、实验数据 三、实验步骤 &#xff0…

【阶段三】Python机器学习28篇:机器学习项目实战:KMeans算法的基本原理与KMeans聚类分群模型

本篇的思维导图: KMeans模型 KMeans算法的基本原理 KMeans算法名称中的K代表类别数量,Means代表每个类别内样本的均值,所以KMeans算法又称为K-均值算法。KMeans算法以距离作为样本间相似度的度量标准,将距离相近的样本分配至同一个类别。样本间距离的计算方式可以是…

QListWidget 自定义 item的图标和文字的位置

目录前言思路一思路二思路二缺陷思路三思路四前言 楼主并没有完整的解决这个问题&#xff0c;如果你是着急寻找解决方案的就可以划走了&#xff0c;如果你对楼主的解决思路有兴趣&#xff0c;那么可以继续向下阅读。首先需求是可以控制QListWidgetItem的icon和text x轴的位置&…

【树】树、二叉树的基础知识

树定义&#xff1a;树是n&#xff08;n≥0&#xff09;个结点的有限集合T。当n0时&#xff0c;称为空树&#xff1b;当n>0时&#xff0c;该集合满足如下条件&#xff1a; (1) 其中必有一个称为根&#xff08;root&#xff09;的特定结点&#xff0c;它没有直接前驱&#xff…