如何用mmclassification训练多标签多分类数据

news2024/12/24 7:10:29

这里使用的源码版本是 mmclassification-0.25.0
训练数据标签文件格式如下,每行的空格前面是路径(图像文件所在的绝对路径),后面是标签名,因为特殊要求这里我的每张图像都记录了三个标签每个标签用“,”分开(具体看自己的需求),我的训练标签数量是17个。
在这里插入图片描述
训练参数配置文件,用ResNet作为特征提取主干,多标签分类要使用MultiLabelLinearClsHead作为分类头。数据集的格式使用CustomDataset,并修改该结构的定义文件,后面有详细内容。

# checkpoint saving
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
optimizer = dict(lr=0.1, momentum=0.9, type='SGD', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
runner = dict(max_epochs=100, type='EpochBasedRunner')
lr_config = dict(
    policy='step', step=[
        30,
        60,
        90,
    ])

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet',depth=18,num_stages=4,out_indices=(3, ),style='pytorch'), 
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiLabelLinearClsHead',
        num_classes=17,
        in_channels=512,
    ))

dataset_type = 'CustomDataset'          #'MultiLabelDataset'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='rootpath/images',
        ann_file='rootpath/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='rootpath/images',
        ann_file='rootpath/val.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_prefix='rootpath/images',
        ann_file='rootpath/test.txt',
        pipeline=test_pipeline))

evaluation = dict(interval=1, metric='accuracy')

其他需要修改的地方:
1、修改加载数据的格式,将./mmclassification-0.25.0/mmcls/datasets/custom.py的CustomDataset里面的load_annotations函数替换成下面的函数:

    ###修改成多标签分类数据加载方式###
    def load_annotations(self):
            """Load image paths and gt_labels."""
            if self.ann_file is None:
                samples = self._find_samples()
            elif isinstance(self.ann_file, str):
                lines = mmcv.list_from_file(
                    self.ann_file, file_client_args=self.file_client_args)
                samples = [x.strip().rsplit(' ', 1) for x in lines]
            else:
                raise TypeError('ann_file must be a str or None')

            data_infos = []
            for filename, gt_label in samples:
                info = {'img_prefix': self.data_prefix}
                info['img_info'] = {'filename': filename.strip()}
                temp_label = np.zeros(len(self.CLASSES))
                # if not self.multi_label:
                #     info['gt_label'] = np.array(gt_label, dtype=np.int64)
                # else:
                ### multi-label classify
                if len(gt_label) == 1:
                    temp_label[np.array(gt_label, dtype=np.int64)] = 1
                    info['gt_label'] = temp_label
                else:
                    for label in gt_label.split(','):
                        i = self.CLASSES.index(label)
                        temp_label[np.array(i, dtype=np.int64)] = 1
                    # for i in range(np.array(gt_label.split(','), dtype=np.int64).shape[0]):
                    #     temp_label[np.array(gt_label.split(','), dtype=np.int64)[i]] = 1
                    info['gt_label'] = temp_label
                # print(info)
                data_infos.append(info)
            return data_infos

记得在初始函数__init__里修改成自己要训练的类别:
在这里插入图片描述

2、修改评估数据的函数,将./mmclassification-0.25.0/mmcls/models/losses/accuracy.py里面的accuracy_torch函数替换成如下函数。我这里只是增加了一些度量函数,方便可视化多标签的指标情况,并没有更新其他地方,训练时还是会验证原来的指标,里面调用的Metric类可以参考这篇文章:https://blog.csdn.net/u013250861/article/details/122727704

def accuracy_torch(pred, target, topk=(1,), thrs=0.):
    if isinstance(thrs, Number):
        thrs = (thrs,)
        res_single = True
    elif isinstance(thrs, tuple):
        res_single = False
    else:
        raise TypeError(f'thrs should be a number or tuple, but got {type(thrs)}.')

    res = []
    maxk = max(topk)
    num = pred.size(0)
    pred = pred.float()
    
    #### ysn修改,增加对多标签分类的度量函数 ###
    pred_ = (pred > 0.5).float()        # 将 pred 中大于0.5的元素替换为1,其余替换为0

    # print("pred shape:", pred.shape, "pred:", pred)
    # # print("pred_ shape:", pred_.shape, "pred_:", pred_)
    # # print("target shape", target.shape, "target:", target)
    from mmcls.utils import get_root_logger
    logger = get_root_logger()
    
    from sklearn.metrics import classification_report
    class_report = classification_report(target.numpy(), pred_.numpy(), target_names=[“这里可以写成你的训练类型列表,也可以不使用这个参数”])     #分类报告汇总了精确率、召回率和 F1 分数等指标
    logger.info("\nClassification Report:\n{}".format(class_report))

    myMetic = Metric(pred_.numpy(), target.numpy())
    ham = myMetic.hamming_distance()
    avgPrecision, _ = myMetic.avgPrecision()
    avgRecall, _, _  = myMetic.avgRecall()
    ranking_loss = myMetic.get_ranking_loss()
    accuracy_multiclass = myMetic.accuracy_multiclass()
    logger.info("\nHam:{}\tAvgPrecision:{}\tAvgRecall:{}\tRanking_loss:{}\tAccuracy_Multilabel:{}".format(ham, avgPrecision, avgRecall, ranking_loss, accuracy_multiclass))

    ####原来的代码###
    pred_score, pred_label = pred.topk(maxk, dim=1)
    pred_label = pred_label.t()

    target = target.argmax(dim=1)     ### ysn修改,这里是多标签分类标签列表的格式,单标签分类去掉这一句 ###

    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
    for k in topk:
        res_thr = []
        for thr in thrs:
            # Only prediction values larger than thr are counted as correct
            _correct = correct & (pred_score.t() > thr)
            correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res_thr.append((correct_k.mul_(100. / num)))
        if res_single:
            res.append(res_thr[0])
        else:
            res.append(res_thr)
    return res

3、修改推理部分,将./mmclassification-0.25.0/mmcls/apis/inference.py里面的inference_model函数修改如下,推理多标签时候可以指定输出所有得分阈值大于0.5的所有标签类型。

def inference_model(model, img):
    """Inference image(s) with the classifier.

    Args:
        model (nn.Module): The loaded classifier.
        img (str/ndarray): The image filename or loaded image.

    Returns:
        result (dict): The classification results that contains
            `class_name`, `pred_label` and `pred_score`.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    if isinstance(img, str):
        if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
            cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_info=dict(filename=img), img_prefix=None)
    else:
        if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
            cfg.data.test.pipeline.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    # forward the model
    # with torch.no_grad():
    #     scores = model(return_loss=False, **data)
    #     pred_score = np.max(scores, axis=1)[0]
    #     pred_label = np.argmax(scores, axis=1)[0]
    #     result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
    # result['pred_class'] = model.CLASSES[result['pred_label']]
    # return result

    ## ysn修改 ##
    with torch.no_grad():
        scores = model(return_loss=False, **data)
        # print(scores, type(scores), len(scores), len(model.CLASSES))
    result = {'pred_label':[], 'pred_score': [], 'pred_class':[]}
    for i in range(len(scores[0])):
        if scores[0][i]>0.5:
            result['pred_label'].append(int(i))
            result['pred_score'].append(float(scores[0][i]))
            result['pred_class'].append(model.CLASSES[int(i)])
        else:
            continue
    return result

或者直接使用以下推理脚本:

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import warnings
import os
import mmcv
import torch
import numpy as np
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcls.datasets.pipelines import Compose
from mmcls.models import build_classifier


def init_model(config, checkpoint=None, device='cuda:0', options=None):
    """Initialize a classifier from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        options (dict): Options to override some settings in the used config.

    Returns:
        nn.Module: The constructed classifier.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if options is not None:
        config.merge_from_dict(options)
    config.model.pretrained = None
    model = build_classifier(config.model)
    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            from mmcls.datasets import ImageNet
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = ImageNet.CLASSES
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(model, img, threshold=0.5):
    """Inference image(s) with the classifier.

    Args:
        model (nn.Module): The loaded classifier.
        img (str/ndarray): The image filename or loaded image.

    Returns:
        result (dict): The classification results that contains
            `class_name`, `pred_label` and `pred_score`.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    if isinstance(img, str):
        if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
            cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_info=dict(filename=img), img_prefix=None)
    else:
        if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
            cfg.data.test.pipeline.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    ### 原始代码 ###
    # forward the model
    # with torch.no_grad():
    #     scores = model(return_loss=False, **data)
    #     pred_score = np.max(scores, axis=1)[0]
    #     pred_label = np.argmax(scores, axis=1)[0]
    #     result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
    # result['pred_class'] = model.CLASSES[result['pred_label']]
    # return result

    ### ysn修改 ###
    with torch.no_grad():
        scores = model(return_loss=False, **data)
        # print(scores, type(scores), len(scores), len(model.CLASSES))
    result = {'pred_label':[], 'pred_score': [], 'pred_class':[]}
    for i in range(len(scores[0])):
        if scores[0][i] > threshold:
            result['pred_label'].append(int(i))
            result['pred_score'].append(round(float(scores[0][i]), 4))
            result['pred_class'].append(model.CLASSES[int(i)])
        else:
            continue
    return result


def show_result(img, result, out_file):
    import matplotlib.pyplot as plt
    plt.imshow(img)
    plt.title(f'{result["pred_class"]}: {result["pred_score"]}')
    plt.axis('off')
    if out_file is not None:
        plt.savefig(out_file)
    plt.show()


def save_result(imgpath, result, outfile="result.txt"):
    # print(result['pred_label'], result['pred_class'], result['pred_score'])
    with open(outfile, "a+") as f:
        f.write(imgpath + "\t" + ",".join(result["pred_class"]) + "\n")
    f.close()

def main():
    parser = ArgumentParser()
    parser.add_argument('--imgpath', default="./images", help='Image file')
    parser.add_argument('--img', default=None, help='Image file')
    parser.add_argument('--outpath', default="./res", help='Image file')
    parser.add_argument('--config', default="config.py",  help='Config file')
    parser.add_argument('--checkpoint', default="./epoch_100.pth",  help='Checkpoint file')
    parser.add_argument('--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    if not os.path.exists(args.outpath):
        os.mkdir(args.outpath)

    model = init_model(args.config, args.checkpoint, device=args.device)

    if args.img is None and os.path.exists(args.imgpath):
        for imgname in os.listdir(args.imgpath):
            img_path = os.path.join(args.imgpath, imgname)
            img = mmcv.imread(img_path)
            if img is None:
                continue
            result = inference_model(model, img, threshold=0.5)
            print("img_path: ", img_path, result)
            save_result(img_path, result, outfile=os.path.join(args.outpath, "result.txt"))
            show_result(img, result, out_file=os.path.join(args.outpath, imgname.replace('.jpg', '_res.jpg')))

    elif args.img is not None and os.path.exists(args.img):
        result = inference_model(model, args.img, threshold=0.5)
        # print(result['pred_label'], result['pred_class'], result['pred_score'])
    else:
        raise Exception('No such file or directory: {}'.format(args.img))



if __name__ == '__main__':
    main()

通过以上修改,可以成功训练、评估、推理多标签分类训练了。
由于我没有找到mmcls官方的训练多标签的训练教程,因此做了上述修改。如果有其他更方便有效的多标签多分类方法或者项目,欢迎在该文章下面留言,非常感谢。

参考文章
https://blog.csdn.net/litt1e/article/details/125316552
https://blog.csdn.net/u013250861/article/details/122727704

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2224379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WORFBENCH:一个创新的评估基准,目的是全面测试大型语言模型在生成复杂工作流 方面的性能。

2024-10-10,由浙江大学和阿里巴巴集团联合创建的WORFBENCH,一个用于评估大型语言模型(LLMs)生成工作流能力的基准测试。它包含了一系列的测试和评估协议,用于量化和分析LLMs在处理复杂任务时分解问题和规划执行步骤的能力。WORFBE…

【Super-resolved q-space learning of diffusion MRI】

扩散MRI的超分辨q-空间学习 摘要: 背景:扩散磁共振成像 (dMRI) 提供了一种强大的工具,可以无创地研究活人大脑中的神经结构。然而,它对神经结构的重建性能依赖于 q 空间中扩散梯度的数量。高角度(HA)dMRI…

【北京迅为】《STM32MP157开发板嵌入式开发指南》-第六十七章 Trusted Firmware-A 移植

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

(二十三)Java反射

1.反射概念 反射允许对成员变量,成员方法和构造方法的信息进行编程访问,通俗理解就是允许从类里面拿东西,用途有提示词等,如下所示都是通过反射实现的 所以,学习反射就是学习从字节码class文件中获取成员信息并且对其…

流媒体协议.之(RTP,RTCP,RTSP,RTMP,HTTP)(一)

闲着没事做,记录一下开发项目用过的协议,项目中,大多是是实时显示播放的,通过私有协议,传输到上位机,实时播放,延时小于200ms,仿照这些协议,定义的数据格式。如果用这些协…

C语言实现Go的defer功能

之前笔者写了一篇博文C实现Go的defer功能,介绍了如何在C语言中实现Go的defer功能,那在C语言中是否也可以实现这样的功能呢?本文就将介绍一下如何在C语言中实现Go的defer功能。 我们还是使用C实现Go的defer功能中的示例: void te…

一文彻底理解 JavaScript 解构赋值

一、基本概念 为什么需要解构呢,先来看一个例子: const student {name: ZhangSan,age: 18,scores: {math: 19,english: 85,chinese: 100} };function displayInfo(student) {console.log(name:, student.name);console.log(math:, student.scores.mat…

排序(二)快速排序的多种实现方法

目录 一.快速排序 1.左右指针法 2.挖坑法 3.前后指针法 4.非递归实现 5.快速排序特性总结 二.整体代码 1.Sort.h 2.Sort.c 3.Stack.h 4.Stack.c 5.test.c 一.快速排序 1.左右指针法 我们找到一个key,begin去找比key大的值,end去找比key小的值,找到了就将begin和end…

从头学PHP之运算符

关于运算符的图片均来自网络,主要是自己写太麻烦了,程序是个简化自己工作量的方式,能复制粘贴就不要手写了(建议初期还是多写写,加深下记忆)在这里我就偷个懒,图片涉及到侵权及时,请…

阻塞队列——Java

一、前言 阻塞队列也是队列的一种,但是带有阻塞性质。但是这种阻塞情况是极端情况,在生产、消费者模型中,当生产者与消费者不协调时,就会出现阻塞情况。 二、特性 线程安全 阻塞特性 若队列为空,当尝试出队列时&am…

深度解析跨境支付之跨境支付与国内支付对比

跨境支付和国内支付的不同点主要体现在5个方面: 1.交易币种不同 这一点其实有两层含义 第一层含义是二者的支付行为的交易币种不同,国内支付基本是人民币但是跨境支付可以是人民币也可以是外币,具体交易币种要取决于收款方要求的交易币种。…

数据结构(8.4_1)——简单选择排序

简单选择排序 每一趟在待排序元素中选取关键字最小的元素加入有序子序列 代码实现 //简单选择排序 void SelectSort(int A[], int n) {for (int i 0; i < n - 1; i) {//一共进行n-1趟int min i;//记录最小元素位置for (int j i 1; j < n; j)//在A[i...n-1中选择最…

RabbitMQ延迟消息插件安装(Docker环境)

背景&#xff1a;当我们需要使用RabbitMQ发送延迟消息的时候&#xff0c;为了简化延迟消息发送的实现&#xff0c;一般都会给RabbitMQ安装延迟插件"rabbitmq_delayed_message_exchange" 如下会说明使用Docker启动的RabbitMQ容器如何安装延迟消息插件。 1. Docker启动…

用接地气的例子趣谈 WWDC 24 全新的 Swift Testing 入门(一)

概述 从 WWDC 24 开始&#xff0c;苹果推出了全新的测试机制&#xff1a;Swift Testing。利用它我们可以大幅度简化之前“老态龙钟”的 XCTest 编码范式&#xff0c;并且使得单元测试更加灵动自由&#xff0c;更符合 Swift 语言的优雅品味。 在这里我们会和大家一起初涉并领略…

docker配置mysql8报错 ERROR 2002 (HY000)

通过docker启动的mysql&#xff0c;发现navicat无法连接&#xff0c;后来进入容器内部也是无法连接&#xff0c;产生以下错误 root9f3b90339a14:/var/run/mysqld# mysql -u root -p Enter password: ERROR 2002 (HY000): Cant connect to local MySQL server through socket …

LINUX1.5.1(vim编辑器)

vim: 1. vim 2.vim /PATTERN vi编辑器与三种常见的模式&#xff1a; 复制 粘贴 剪切 删除 编辑 退出 保存 行间跳转 显示行号 查找替换 命令模式&#xff1a;光标的移动&#xff0c;使用快捷键&#xff0c;复制&#xff0c;粘贴&#xff0c;删除等基础操作 编辑模式&…

【mysql进阶】4-7. 通用表空间

通⽤表空间 - General Tablespace 1 通⽤表空间的作⽤和特性&#xff1f; ✅ 解答问题 通⽤表空间是使⽤ CREATE tablespace 语法创建的共享InnoDB表空间 通⽤表空间能够存储多个表的数据&#xff0c;与系统表空间类似也是共享表空间&#xff1b; 服务器运⾏时会把表空间元数…

【C++】智能指针:解决内存泄漏、悬空指针等问题

⭐️个人主页&#xff1a;小羊 ⭐️所属专栏&#xff1a;C 很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~ 目录 前言一、RAII二、智能指针原理三、auto_ptr四、unique_ptr五、shared_ptr第一步&#xff1a;实现出RAII的框架第二步&#xff1a;如何…

信息收集-shodan专题一

shodan介绍 一、shodan简介 1.工作原理解析&#xff1a; 2.优缺点 3.功能 二、安装shodan流程 三、shodan使用方法 1.搜索 1.1.search 搜索 1.2. count 总数 1.3. download 下载与解析 2. 指定查看 2.1 指定IP的详细信息 2.2 hostname: 搜索指定的域名 2.3 port:…

百度智能云推出11.11活动,各大云厂商香港服务器优惠活动汇总

2024年双十一活动就要来了&#xff0c;作为百度集团旗下的云智能服务平台——百度智能云今年率先开始了11.11狂欢购活动&#xff0c;上新促销活动的动作如此之快&#xff0c;难道是百度云要大发力了&#xff1f;感觉今年百度智能云要比阿里云、腾讯云、硅云、华为云等厂商更加卖…