yolov8训练自己的数据集遇到的问题

news2024/11/18 23:34:44

训练分类模型

1.如何更改模型的类别数nc

根据本地模型配置文件.yaml可以设置nc
在这里插入图片描述
但是,这里无法用到预训练模型.pt模型文件,预训练模型的权重参数是在大数据集上训练得到的,泛化性能可能比较好,所以,下载了官方的分类模型yolov8n-cls.pt

由于需要魔改yolov8,所以下载了官方源码,在default.yaml配置文件中是各种超参数,包括data与model路径的设置

将data路径设置为本地路径,或者调用数据配置文件coco128.yaml文件进行更改也可,这里有个问题,由于我也利用pip安装了yolov8的源码库,导致在利用coco128.yaml文件时,代码会在官方下载数据,导致报错,可能与这部分代码有关,目前暂未解释,直接定位本地文件夹更方便。

model路径设置为下载的yolov8n-cls.pt本地路径,或者模型配置文件yolov8n.yaml路径,这里选择前者。
问题来了,官方配置文件的nc是1000,由上图模型配置文件也可看出,先说方法,在task.py文件中找到attempt_load_one_weight这个函数,这个函数是用来下载.pt模型文件的,在train.py文件中的下面函数也可以看到

def setup_model(self):
        """
        load/create/download model for any task
        """
        # classification models require special handling

        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return

        model = str(self.model)
        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
        if model.endswith(".pt"):
            self.model, _ = attempt_load_one_weight(model, device='cpu')
            for p in self.model.parameters():
                p.requires_grad = True  # for training
        elif model.endswith(".yaml"):
            self.model = self.get_model(cfg=model)

解决办法:
attempt_load_one_weight这个函数中添加代码如下,

def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    # Loads a single model weights
    ckpt = torch_safe_load(weight)  # load ckpt
    args = {**DEFAULT_CFG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args
    model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

    ##########################change_nc###################################
    #方法一:
    nc = 5
    ch = 256
    m = model.model[-1]  # last layer
    #ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
    c = Classify(ch, nc)  # Classify()
    c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
    model.model[-1] = c  # replace
    print("############################################################")
    print(model)
    # print("model.layers:",len(model.layers))
    # for layer in model.layers[:-10]:
    #     layer.trainable = False
    #######################################################################

nc为自己数据的类别数,ch为模型最后一层的输入通道层数,这里由于模型没有layers参数,模型中卷积层没有in_channel参数,所以无法直接调用,所以咱不能进行冻结层训练,考虑到模型层数也不多,先暂且这样吧,ch可以将原.pt模型先转为onnx通过netron进行查看

这里要借鉴了ClassificationModel类下的_from_detection_model函数,根据onnx对比moudles.py文件中的Classify函数找到ch=256,因为这里的.pt模型文件中的最后一层为
![# YOLOv8.0n head
head:
  - [-1, 1, Classify, [nc]]  # Classify](https://img-blog.csdnimg.cn/9b83d6b87de54c26b044ed4eaac5b1aa.png)

class Classify(nn.Module):
    # YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
        self.drop = nn.Dropout(p=0.0, inplace=True)
        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, 1)
        x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
        return x if self.training else x.softmax(1)

注意:在训练完之后需要注释掉添加的代码,否则在export,val,predict模型的时候,比如onnx会调用模型下载函数,导致模型破坏。

在进行测试的时候,竟然发现原来没有改动类别数训练之后的模型发现虽然类别数是1000.但是在进行测试的时候,以及转换为onnx之后进行测试发现结果没有变化,甚至高一个点,什么鬼,据我猜测,数据没有发生变化,我直接设置的是本地路径,所以以经是设置为5类了,现在就剩模型最后一层的类别数不一样,也就是说在训练的时候,无论你模型设置多少类,只要数据集类数确定了,那么最后训练的结果是一样的,就好像多余的类别数默认输出为0,但是前面的类别数输出结果不变。

2.如何将pt模型的三通道改为单通道

1.输入层的修改,将模型第一个卷积层的输入通道数从3改为1,权重改为单通道,具体代码如下:
原模型第一层的conv的输入通道数为3,权重通道为[:, :3, :, :]
attempt_load_one_weight这个函数中添加代码如下,

	#修改输入通道数,权重通道数
    model.model[0].conv.in_channels = 1
    model.model[0].conv.weight = torch.nn.Parameter(model.model[0].conv.weight[:, :1, :, :])
    #model.model[0].conv.weight.data = model.model[0].conv.weight.data[:, :1, :, :]

2.修改数据集图片的维度,将其由三通道rgb转为单通道灰度图
在dataset.py文件中修改def getitem(self, i)函数

def __getitem__(self, i):
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]#源代码
        else:
            #修改为单通道
            import numpy as np
            image=cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
            sample = self.torch_transforms(image)
            
            
            #sample = self.torch_transforms(im)#原代码

然后,torch_transforms函数索引到augment.py函数中的classify_transforms函数,在这个函数中分别修改CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)函数

def classify_transforms(size=224):
    # Transforms to apply if albumentations not installed
    if not isinstance(size, int):
        raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
    # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

CenterCrop(size), ToTensor()两个函数中分别添加im = np.array(im)[ :, :, np.newaxis]代码添加维度,否则会报错
T.Normalize(IMAGENET_MEAN, IMAGENET_STD)修改IMAGENET_MEAN, IMAGENET_STD参数的数值由三个通道改为单个通道数值

如果要进行验证集测试的话会遇到报错RuntimeError: Given groups=1, weight of size [16, 1, 3, 3], expected input[1, 3, 64, 64] to have 1 channels, but got 3 channels instead
将validator.py文件中__call__函数中

model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))
修改为
model.warmup(imgsz=(1 if pt else self.args.batch, 1, imgsz, imgsz))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/444934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink+Kafka、Pulsar实现端到端的exactly-once语义

End-to-End Exactly-Once Processing in Apache Flink with Apache Kafka 2017年12月Apache Flink社区发布了1.4版本。该版本正式引入了一个里程碑式的功能:两阶段提交Sink,即TwoPhaseCommitSinkFunction。该SinkFunction提取并封装了两阶段提交协议中的…

【离散数学】测试五 图论

1. n层正则m叉树一共有()片树叶。 A. nm B. mn C. mn 正确答案: B 2. 下图是一棵最优二叉树 A. 对 B. 错 正确答案: B 3. 要构造权为1,4,9,16,25,36,49,64,81,100一棵最优二叉树,则必须先构造权为5,9,16,25,36,49,64,81,100一棵最优二叉树. A. 对 B. 错 …

视频剪辑必备,这6个网站承包你一年的音效素材

视频剪辑中需要用到各种声音、音效素材,这些音效不仅能让你的视频更丰富,还能更好的表达视频内容,传递情绪让观者感到共鸣。很多朋友剪辑过程中为了找到好的配乐、音效,往往会花费大量的时间,找到了还有可能受版权限制…

装机必备(二补充)--Win10系统盘,装Win10系统(无法引导启动问题-找不到任务设备驱动程序。请确保安装介质包含正确的驱动程序)

对于联想的thinkpad,开机时候按F1来更改bios设置,F12是选择U盘引导启动 thinkpad如何进入bios界面_thinkpad怎么进入u盘启动-系统城 1 F1界面1.按→方向键移动到Security,将secure boot改成disabled,关闭安全启动&…

【数据结构】简单快速过一遍红黑树

文章目录 红黑树1 红黑树的概念2 红黑树的性质3 红黑树节点的定义4 红黑树的插入操作5 红黑树的验证6 红黑树与AVL树的比较7.C实现红黑树 红黑树 1 红黑树的概念 ​ 红黑树,是一种二叉搜索树,但在每个结点上增加一个存储位表示结点的颜色,…

记一次oracle入库慢,log file switch (checkpoint incomplete)

AWR报告生成:Oracle AWR报告生成步骤_小百菜的博客-CSDN博客 发现log file switch (checkpoint incomplete) 这里出现了大量的log file switch(checkpoint incomplete)等待事件。 查看redo每个组的大小、状态 select group#,thread#,archived,status, bytes/102…

Python数据结构-----非递归实现快速排序

目录 前言: 非递归快排 1.概念原理 2.示例 Python代码实现 非递归快速排序 前言: 上一期我们学习了通过递归来实现快速排序的方法,那这一期我们就来一起学习怎么去通过非递归的方法来去实现快速排序的功能。(上一期连接Pytho…

新来一00后,给我卷崩溃了..

2022年已经结束结束了,最近内卷严重,各种跳槽裁员,相信很多小伙伴也在准备今年的金三银四的面试计划。 在此展示一套学习笔记 / 面试手册,年后跳槽的朋友可以好好刷一刷,还是挺有必要的,它几乎涵盖了所有的…

SLIC超像素分割算法

SLIC超像素分割算法 《SLIC Superpixels》 摘要 超像素在计算机视觉应用中越来越受欢迎。然而,很少有算法能够输出所需数量的规则、紧凑的超级像素,并且计算开销低。我们介绍了一种新的算法,将像素聚类在组合的五维颜色和图像平面空间中&a…

大四的告诫

👂 LOCK OUT - $atori Zoom/KALONO - 单曲 - 网易云音乐 👂 喝了一口星光酒(我只想爱爱爱爱你一万年) - 木小雅 - 单曲 - 网易云音乐 其实不是很希望这篇文章火,不然就更卷了。。 从大一开始,每天10小时…

ccf b类及以上会议(准备)

SoCCACM Symposium on Cloud Computing http://dblp.uni-trier.de/db/conf/cloud/ SimLess: simulate serverless workflows and their twins and siblings in federated FaaS.Pisces: efficient federated learning via guided asynchronous training论文截止时间&#xff1a…

实现自定义dialog样式

1定义弹出的dialog样式 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:orientation"vertical"android:layout_width"match_parent"a…

3. 排序

3. 排序 3.1 总纲 3.2 Comparable与Comparator接口介绍 由于我们这里要讲排序&#xff0c;所以肯定会在元素之间进行比较。规则的。在实际应用中&#xff0c;我们往往有需要比较两个自定义对象大小的地方。而这些自定义对象的比较&#xff0c;就不像简单的整型数据那么简单&a…

Python轻量级Web框架Flask(7)——翻页功能/多表操作

1、使用paginate实现分页&#xff1a; 基础指令&#xff08;新建对象&#xff09; 基础指令&#xff08;使用对象属性&#xff09; 2、几种类型的表操作&#xff1a; 一对一&#xff1a;例如一个人只能有一张身份证。一对多&#xff1a;例如班级和学生&#xff08;一个班级…

苦中作乐 ---竞赛刷题 完结篇(25分)

&#xff08;一&#xff09;目录 L2-014 列车调度 L2-024 部落 L2-033 简单计算器 L2-042 老板的作息表 L2-041 插松枝 &#xff08;二&#xff09;题目 L2-014 列车调度 火车站的列车调度铁轨的结构如下图所示。 两端分别是一条入口&#xff08;Entrance&#xff09;轨道…

Java分布式事务(十二)

文章目录 🔥Hmily实现TCC分布式事务_项目搭建🔥Hmily实现TCC分布式事务实战_公共模块🔥Hmily实现TCC分布式事务_集成Dubbo框架🔥Hmily实现TCC分布式事务_项目搭建 创建父工程tx-tcc 设置逻辑工程 <packaging>pom</packaging>创建公共模块 创建转出银行…

分析 | 通过 NFTScan 率先捕获 NFT 投资趋势

NFT 市场信息高度动态且机会稍纵即逝&#xff0c;了解市场第一信息对于 NFT 的参与者来说都是至关重要的。所以市场主体参与者必须密切关注各种渠道&#xff0c;努力获取最新一手 NFT 信息&#xff0c;这对参与者抓住先机和获益至关关键&#xff0c;若信息滞后&#xff0c;容易…

【流畅的Python学习笔记】2023.4.21

此栏目记录我学习《流畅的Python》一书的学习笔记&#xff0c;这是一个自用笔记&#xff0c;所以写的比较随意 特殊方法&#xff08;魔术方法&#xff09; 不管在哪种框架下写程序&#xff0c;都会花费大量时间去实现那些会被框架本身调用的方法&#xff0c;Python 也不例外。…

【Python】matplotlib设置图片边缘距离和plt.lengend图例放在图像的外侧

一、问题提出 我有这样一串代码&#xff1a; import matplotlib.pyplot as plt plt.figure(figsize (10, 6)) " 此处省略代码 " legend.append("J") plt.legend(legend) plt.xlabel(recall) plt.ylabel(precision) plt.grid() plt.show()我们得到的图像…

KMP算法原理原来这么简单

我觉得这句话说的很好&#xff1a; kmp算法关键在于&#xff1a;在当前对文本串和模式串检索的过程中&#xff0c;若出现了不匹配&#xff0c;如何充分利用已经匹配的部分&#xff0c;来继续接下来的检索。 暴力解决字符串匹配 暴力解法就是两层for循环,每次都一对一的匹配&…