偏标记学习+图像分类

news2025/1/9 16:59:35


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:人工智能、话题分享

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

 

概述

本文复现论文 Progressive Identification of True Labels for Partial-Label Learning[1] 提出的偏标记学习方法。

随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的。

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关。

算法原理

传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:

LCE(x,y;θ)=∑i=1cyilog⁡fi(x;θ)LCE​(x,y;θ)=i=1∑c​yi​logfi​(x;θ)

其中, xx 表示样本特征;y=[y1,y2,…,yc]y=[y1​,y2​,…,yc​] 表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零;fi(x;θ)fi​(x;θ) 表示模型预测样本 xx 标签为 ii 的概率。

该论文提出的方法使用一个软标签 y^=[y^1,y^2,…,y^c]y^​=[y^​1​,y^​2​,…,y^​c​],其对任意 i∈[0,c]i∈[0,c] 满足 ∑iy^i=1∑i​y^​i​=1 且 0≤y^i≤10≤y^​i​≤1。为了使用该软标签,论文根据候选标签集 ss 对软标签进行初始化:

y^i={1∣s∣i∈s0i∉sy^​i​={∣s∣1​0​i∈si​∈s​

为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值:

y^i=fi(x;θ)I(i∈s)∑jfj(x;θ)I(j∈s)y^​i​=∑j​fj​(x;θ)I(j∈s)fi​(x;θ)I(i∈s)​

其中,I(j∈s)=1I(j∈s)=1 当且仅当 j∈sj∈s 为真,否则 I(j∈s)=0I(j∈s)=0。

核心逻辑

具体的核心逻辑如下所示:

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

def CE_loss(probs, targets):
    """交叉熵损失函数"""
    loss = -torch.sum(targets * torch.log(probs), dim = -1)
    loss_avg = torch.sum(loss)/probs.shape[0]
    return loss_avg

class Proden:
    def __init__(self, configs):
        self.configs = configs
    
    def train(self, save = False):
        configs = self.configs
        # 读取数据集
        dataset_path = configs['dataset path']
        if configs['dataset'] == 'CIFAR-10':
            train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 10
        elif configs['dataset'] == 'CIFAR-100':
            train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 100
        # 生成偏标记
        partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 计算数据的均值和方差,用于模型输入的标准化
        mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
        std = [np.std(train_data[:, i, :, :]) for i in range(3)]
        normalize = transforms.Normalize(mean, std)
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        if configs['model'] == 'ResNet18':
            model = models.ResNet18(output_dimension = output_dimension).to(device)
        elif configs['model'] == 'ConvNet':
            model = models.ConvNet(output_dimension = output_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning rate']
        weight_decay = configs['weight decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_step = configs['learning rate decay step']
        lr_decay = configs['learning rate decay rate']
        lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
        for epoch_id in range(configs['epoch count']):
            # 训练模型
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
            model.train()
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                # 标准化输入
                data = normalize(batch['data'].to(device))
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                optimizer.zero_grad()
                # 计算预测概率
                logits = model(data)
                probs = F.softmax(logits, dim=-1)
                # 更新软标签
                with torch.no_grad():
                    new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算交叉熵损失
                loss = CE_loss(probs, targets)
                loss.backward()
                # 更新模型参数
                optimizer.step()
            # 调整学习率
            lr_scheduler.step()

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示

我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下:

由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:


我所使用的数据集(CIFAR-10)共包含十个类,示意图如下:


网站提供了在线演示功能,使用者请输入一张小于1MB、类别为上述十个类别之一、长宽尽可能相等的JPG图像。

使用方式

  • 解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip Proden-implemention.zip
cd Proden-implemention

  • 代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt

  • 运行如下命令以下载并解压数据集
bash download.sh

  • 如果希望在本地训练模型,请运行如下命令:
python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]

  • 如果希望在线部署,请运行如下命令:
python main-flask.py

参考文献

[1] Lv J, Xu M, Feng L, et al. Progressive identification of true labels for partial-label learning[C]//International conference on machine learning. PMLR, 2020: 6500-6510.

[2] Krizhevsky A, Hinton G. Learning multiple layers of features from tiny images[J]. 2009.

[3] Laine S, Aila T. Temporal ensembling for semi-supervised learning[J]. arXiv preprint arXiv:1610.02242, 2016.

​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2255112.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pdf转word/markdown等格式——MinerU的部署:2024最新的智能数据提取工具

一、简介 MinerU是开源、高质量的数据提取工具,支持多源数据、深度挖掘、自定义规则、快速提取等。含数据采集、处理、存储模块及用户界面,适用于学术、商业、金融、法律等多领域,提高数据获取效率。一站式、开源、高质量的数据提取工具&…

fedora下Jetbrains系列IDE窗口中文乱码解决方法

可以看到窗口右部分的中文内容为小方块。 进入 Settings - Appearance & Behavior - Appearance - Use custom font : Note Sans Mono CJK SC ,设置后如下图:

机器学习详解(2):线性回归之理论学习

文章目录 1 监督学习2 线性回归2.1 简单/多元线性回归2.2 最佳拟合线2.3 成本函数和梯度下降2.4 线性回归的假设2.5 线性回归的评估指标函数 3 总结 机器学习是人工智能的一个分支,主要致力于开发能够从数据中学习并进行预测的算法和统计模型。线性回归是机器学习的…

半监督学习与数据增强

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢,在这里我会分享我的知识和经验。&am…

位运算符I^~

&运算:上下相等才是1,有一个不同就是0 |运算:只要有1返回的就是1 ^(亦或)运算:上下不同是1,相同是0 ~运算:非运算,与数据全相反 cpu核心运算原理,四种cpu底层小电路 例&#xf…

蓝桥杯软件赛系列---lesson1

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 我们今天会再开一个系列,那就是蓝桥杯系列,我们会从最基础的开始讲起,大家想要备战明年蓝桥杯的,让我们一起加油。 工具安装 DevC…

【0x01】HCI_Inquiry_Complete事件详解

目录 一、事件概述 二、事件格式及参数 2.1. HCI_Inquiry_Complete事件格式 2.2. 参数 三、HCI_Inquiry_Complete事件触发机制 3.1. 基于查询命令完成的触发 3.2. 受查询环境和设备状态影响的触发 3.3. 与蓝牙协议栈内部逻辑相关的触发 四、事件处理流程 4.1. 事件接…

安防视频监控平台Liveweb视频汇聚管理系统管理方案

智慧安防监控Liveweb视频管理平台能在复杂的网络环境中,将前端设备统一集中接入与汇聚管理。国标GB28181协议视频监控/视频汇聚Liveweb平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、…

shell脚本实战案例

文章目录 实战第一坑功能说明脚本实现 实战第一坑 实战第一坑:在Windows系统写了一个脚本,比如上面,随后上传到服务,执行会报错 原因: 解决方案:在linux系统touch文件,并通过vim添加内容&…

波特图方法

在电路设计中,波特图为最常用的稳定性余量判断方法,波特图的根源是如何来的,却鲜有人知。 本章节串联了奈奎斯特和波特图的渊源,给出了其对应关系和波特图相应的稳定性余量。 理论贯通,不在于精确绘…

在ensp进行IS-IS网络架构配置

一、实验目的 1. 理解IS-IS协议的工作原理 2. 熟练ensp路由连接配置 二、实验要求 需求: 路由器可以互相ping通 实验设备: 路由器router6台 使用ensp搭建实验坏境,结构如图所示 三、实验内容 R1 u t m sys undo info en sys R1 #设…

vxe-table 键盘操作,设置按键编辑方式,支持覆盖方式与追加方式

vxe-table 全键盘操作,按键编辑方式设置,覆盖方式与追加方式; 通过 keyboard-config.editMode 设置按键编辑方式;支持覆盖方式编辑和追加方式编辑 安装 npm install vxe-pc-ui4.3.15 vxe-table4.9.15// ... import VxeUI from v…

MNIST数据集_CNN

前言 提醒: 文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。 其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展…

【Flink】Flink Checkpoint 流程解析

Flink Checkpoint 流程解析 Checkpoint 流程解析 Flink Checkpoint 流程解析Checkpint 流程概括Checkpoint 触发流程解析 (Flink 1.20)任务启动后 JobManager 开始定期对任务执行 CheckpointJobManager 使用 CheckpointCoordinator 触发 CheckpointCheckpointCoordinator 初始化…

MIT工具课第六课任务 Git基础练习题

如果您之前从来没有用过 Git,推荐您阅读 Pro Git 的前几章,或者完成像 Learn Git Branching 这样的教程。重点关注 Git 命令和数据模型相关内容; 相关内容整理链接:Linux Git新手入门 git常用命令 Git全面指南:基础概念…

Sui 主网升级至 V1.38.3

Sui 主网现已升级至 V1.38.3 版本,同时协议升级至 69 版本。请开发者及时关注并调整! 其他升级要点如下所示: 协议 #20199 在共识快速路径投票中设置允许的轮次数量。 节点(验证节点与全节点) #20238 为验证节点…

【AI系统】低比特量化原理

低比特量化原理 计算机里面数值有很多种表示方式,如浮点表示的 FP32、FP16,整数表示的 INT32、INT16、INT8,量化一般是将 FP32、FP16 降低为 INT8 甚至 INT4 等低比特表示。 模型量化则是一种将浮点值映射到低比特离散值的技术,可…

项目文章 | RNA-seq+WES-seq+机器学习,揭示DNAH5是结直肠癌的预后标志物

肿瘤突变负荷(TMB)已成为预测结直肠癌(CRC)患者预后和对免疫治疗反应的关键生物标志物。然而,全外显子测序(WES-seq)作为TMB评估的金标准,成本高且耗时。此外,高TMB患者之…

【NLP修炼系列之Bert】Bert多分类多标签文本分类实战(附源码下载)

引言 今天我们就要用Bert做项目实战,实现文本多分类任务和我在实际公司业务中的多标签文本分类任务。通过本篇文章,可以让想实际入手Bert的NLP学习者迅速上手Bert实战项目。 1 项目介绍 本文是Bert文本多分类和多标签文本分类实战,其中多分…

【CSS in Depth 2 精译_069】11.3 利用 OKLCH 颜色值来处理 CSS 中的颜色问题(上)

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第四部分 视觉增强技术 ✔️【第 11 章 颜色与对比】 ✔️ 11.1 通过对比进行交流 11.1.1 模式的建立11.1.2 还原设计稿 11.2 颜色的定义 11.2.1 色域与色彩空间11.2.2 CSS 颜色表示法 11.2.2.1 RGB…