⌈ 传知代码 ⌋ 将一致性正则化用于弱监督学习

news2024/10/4 23:11:24

💛前情提要💛

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


📌导航小助手📌

  • 💡本章重点
  • 🍞一. 论文概述
  • 🍞二. 算法原理
  • 🍞三.核心逻辑
  • 🍞四.效果演示
  • 🫓总结


💡本章重点

  • 一致性正则化用于弱监督学习

🍞一. 论文概述

本文复现论文 Revisiting Consistency Regularization for Deep Partial Label Learning[1] 提出的偏标记学习方法。程序基于Pytorch,会保存完整的训练日志,并生成损失变化图和准确度变化图。

偏标记学习(Partial Label Learning)是一个经典的弱监督问题。在偏标记学习中,每个样例的监督信息为一个包含多个标签的候选标签集合。目前的偏标记方法大多基于自监督或者对比学习范式,或多或少地会遇到低性能或低效率的问题。该论文基于一致性正则化的思想,改进基于自监督的偏标记学习方法。具体地,该论文所提出的方法设计了两个训练目标。其中第一个训练目标为最小化非候选标签的预测输出,第二个目标最大化不同视图的预测输出之间的一致性。

在这里插入图片描述
总的来说,该论文所提出的方法着眼于将模型对同一图像不同增强视图的预测输出对齐,以提升模型输出的可靠性和对标签的消歧能力,这一方法同样能给其他弱监督学习任务带来提升。


🍞二. 算法原理

首先,论文所提出方法的第一项损失(监督损失)如下:

在这里插入图片描述

其中,当事件 A 为真时,I(A)= 1 否则 I(A)= 0,f(.)表示模型的输出概率。

然后,论文所提出方法的第二项损失(一致性损失)如下:

在这里插入图片描述
其在训练过程中通过所有增强视图预测结果的几何平均来更新标签分布:

在这里插入图片描述
由于数据增强的不稳定性,该论文通过叠加 K 个不同的增强视图的一致性损失来提升方法性能。

最后,考虑到训练初期模型的预测准确率较低,一致性损失的权重被设置为从零开始随着训练轮数的增加逐渐提高:

在这里插入图片描述

综上所述,模型的总损失函数如下:

在这里插入图片描述


🍞三.核心逻辑

具体的核心逻辑如下所示:

def dpll_sup_loss(probs, partial_labels):
    loss = -torch.sum(torch.log(1 + 1e-6 - probs) * (1 - partial_labels), dim=-1)
    loss_avg = torch.mean(loss)
    return loss_avg


def dpll_cont_loss(logits, targets):
    logits_log = torch.log_softmax(logits, dim=-1)
    loss = F.kl_div(logits_log, targets, reduction='batchmean')
    return loss

def train():
    # main loops
    for epoch_id in range(total_epochs):
        # train
        model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            ids = batch['ids']
            data1 = batch['data1'].to(device)
            data2 = batch['data2'].to(device)
            data3 = batch['data3'].to(device)
            partial_labels = batch['partial_labels'].to(device)
            targets = train_targets[ids].to(device)
            logits1 = model(data1)
            logits2 = model(data2)
            logits3 = model(data3)
            probs1 = F.softmax(logits1, dim=-1)
            # update targets
            with torch.no_grad():
                probs2 = F.softmax(logits2.detach(), dim=-1)
                probs3 = F.softmax(logits3.detach(), dim=-1)
                new_targets = torch.pow(probs1.detach() * probs2 * probs3, 1 / 3)
                new_targets = F.normalize(new_targets * partial_labels, p=1, dim=-1)
                train_targets[ids] = new_targets.cpu()
            # dynamic weight
            balancing_weight = max_weight * (epoch_id + 1) / max_weight_epoch
            balancing_weight = min(max_weight, balancing_weight)
			# supervised loss
            loss_sup = dpll_sup_loss(probs1, partial_labels)
            # consistency regularization loss
            loss_cont1 = dpll_cont_loss(logits1, targets)
            loss_cont2 = dpll_cont_loss(logits2, targets)
            loss_cont3 = dpll_cont_loss(logits3, targets)
            # all loss
            loss = loss_sup + balancing_weight * (loss_cont1 + loss_cont2 + loss_cont3)
            loss.backward()
            optimizer.step()
        if epoch_id in lr_decay_epochs:
            lr_scheduler.step()

🍞四.效果演示

本文基于网络 Wide-ResNet[2] 和数据集 CIFAR-10[3] 进行实验,偏标记的随机翻转概率为0.1。当然,本文所提供的程序不仅仅提供了上述的实验设置,同时也可以直接基于CIFAR-100(100类图像分类数据集),SVHN(数字号牌识别数据集),Fashion-MNIST(时装识别数据集),Kuzushiji-MNIST(日本古草体识别数据集)进行实验。仅仅需要替换运行命令的对应部分即可(使用说明见下文)

  • 损失曲线:

在这里插入图片描述

  • 准确率曲线:

在这里插入图片描述


🫓总结

综上,我们基本了解了“一项全新的技术啦” 🍭 ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读😆

后续还会继续更新💓,欢迎持续关注📌哟~

💫如果有错误❌,欢迎指正呀💫

✨如果觉得收获满满,可以点点赞👍支持一下哟~✨

【传知科技 – 了解更多新知识】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2188877.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是 NVIDIA 机密计算?( 上篇 )

什么是机密计算? 文章目录 前言1. 机密计算定义2. 机密计算有何独特之处?3. 机密计算是如何得名的4. 机密计算的工作原理是什么?5. 缩小安全边界6. 机密计算的使用案例7. 机密计算如何发展8. 加速机密计算9. 机密计算的下一步前言 机密计算是一种在计算机处理器的受保护区域…

全网最详细kubernetes中的资源

1、资源管理介绍 在kubernetes中,所有的内容都抽象为资源,用户需要通过操作资源来管理kubernetes。 kubernetes的本质上就是一个集群系统,用户可以在集群中部署各种服务。 所谓的部署服务,其实就是在kubernetes集群中运行一个个的…

csp-j模拟三补题报告

前言 今天题难&#xff0c;排名没进前十 &#xff08;“关于二进制中一的个数的研究与规律”这篇文章正在写&#xff09; 第一题 三个&#xff08;three&#xff09; 我的代码&#xff08;AC&#xff09; #include<bits/stdc.h> #define ll long long using namespac…

快停止这种使用U盘的行为!

前言 现在各行各业的小伙伴基本上都需要用电脑来办公了&#xff0c;你敢说你不需要用电脑办公&#xff1f; 啊哈哈哈&#xff0c;用iPad或者手机办公的也算。 有些小伙伴可能经常996&#xff0c;甚至有时候都是007。有时候到了下班时间&#xff0c;工作还没做完&#xff0c;…

Python技巧:如何处理未完成的函数

一、问题的提出 写代码的时候&#xff0c;我们有时候会给某些未完成的函数预留一个空位&#xff0c;等以后有时间再写具体内容。通常&#xff0c;大家会用 pass 或者 ... &#xff08;省略号&#xff09;来占位。这种方法虽然能让代码暂时不报错&#xff0c;但可能在调试的时候…

精准翻译神器:英汉互译软件的卓越表现

英文作为目前世界上使用最广的一种语言&#xff0c;是的很多先进的科学文献或者一些大厂产品的说明书都有英文的版本。为了方便我们的阅读和学习&#xff0c;现在有不少支持翻译英汉互译的工具&#xff0c;今天我们就一起来讨论一下吧。 1.福昕中英在线翻译 链接直达>>…

二叉树的前序遍历——非递归版本

1.题目解析 题目来源&#xff1a;144.二叉树的前序遍历——力扣 测试用例 2.算法原理 前序遍历&#xff1a; 按照根节点->左子树->右子树的顺序遍历二叉树 二叉树的前序遍历递归版本十分简单&#xff0c;但是如果树的深度很深会有栈溢出的风险&#xff0c;这里的非递归…

【论文笔记】DKTNet: Dual-Key Transformer Network for small object detection

【引用格式】&#xff1a;Xu S, Gu J, Hua Y, et al. Dktnet: dual-key transformer network for small object detection[J]. Neurocomputing, 2023, 525: 29-41. 【网址】&#xff1a;https://cczuyiliu.github.io/pdf/DKTNet%20Dual-Key%20Transformer%20Network%20for%20s…

vue3实现打字机的效果,可以换行

之前看了很多文章,效果是实现了,就是没有自动换行的效果,参考了文章写了一个,先上个效果图,卡顿是因为模仿了卡顿的效果,还是很丝滑的 目录 效果图:代码如下 效果图: ![请添加图片描述](https://i-blog.csdnimg.cn/direct/d8ef33d83dd3441a87d6d033d9e7cafa.gif 代码如下 原…

jmeter学习(8)结果查看

1&#xff09;查看结果树 查看结果树&#xff0c;显示取样器请求和响应的细节以及请求结果&#xff0c;包括消息头&#xff0c;请求的数据&#xff0c;响应的数据。 2&#xff09;汇总报告 汇总报告&#xff0c;为测试中的每个不同命名的请求创建一个表行。这与聚合报告类似&…

[数据结构] 树

n个结点的有限集合 除了根节点以外&#xff0c;每一个结点有且只有一条与父节点的连线&#xff1b; 总共有N-1条连线。 子树之间不相交。 术语 树的表示 每个结点的结构不知道 可以统一设置结构&#xff0c;优点&#xff1a;处理方便 缺点&#xff1a;会造成空间浪费&…

Chromium 硬件加速开关c++

选项页控制硬件加速开关 1、前端代码 <settings-toggle-button id"hardwareAcceleration"pref"{{prefs.hardware_acceleration_mode.enabled}}"label"$i18n{hardwareAccelerationLabel}"><template is"dom-if" if"[…

6.5 监控和日志 架构模式和应用实践

6.5 监控和日志 架构模式和应用实践 目录概述需求&#xff1a; 设计思路实现思路分析1.集中式监控2.分布式监控3.边缘监控4.集中式日志管理5.分布式日志管理6.实时日志流处理 监控工具最佳实践 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz ,…

基于元神操作系统实现NTFS文件操作(六)

1. 背景 本文主要介绍$Root元文件属性的解析。先介绍元文件各属性的属性体构成&#xff0c;然后结合读取到的元文件内容&#xff0c;对测试磁盘中目标分区的根目录进行展示。 2. $Root元文件属性的解析 使用每个属性头偏移0x04-0x07处的值可以从第一个属性开始依次定位下一个…

Jupyter | jupyter notebook 使用 conda 环境

博客使用更佳 点我进入博客 创建虚拟环境 在 Anaconda Prompt 里面输入&#xff1a; conda create -n env-name并且输入 y 确认。例如我们创建环境名为 jupyter 激活环境 conda activate env-name激活之后发现环境从 base 变为 jupyter(笔者用的 env-name 为 jupyter) …

python-求一个整数的质因数/字符串的镜像/加数

一:求一个整数的质因数 题目描述 编写一个程序&#xff0c;返回给定整数的质因数。 定义函数get_prime_factors()&#xff0c;该函数接受一个参数num&#xff08;正整数&#xff09;。 该函数应返回传入参数的质因数列表&#xff0c;且从小到大排序。 比如150的质因数分解如…

Spring MVC__HttpMessageConverter、拦截器、异常处理器、注解配置SpringMVC、SpringMVC执行流程

目录 一、HttpMessageConverter1、RequestBody2、RequestEntity3、ResponseBody4、SpringMVC处理json5、SpringMVC处理ajax6、RestController注解7、ResponseEntity7.1、文件下载7.2、文件上传 二、拦截器1、拦截器的配置2、拦截器的三个抽象方法3、多个拦截器的执行顺序 三、异…

数据结构——计数、桶、基数排序

目录 引言 计数排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 桶排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 基数排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 排序算法的稳定性 1.稳定性的概念 2.各个排序算法的稳定性 结束语 引…

初识Linux · 自主Shell编写

目录 前言&#xff1a; 1 命令行解释器部分 2 获取用户命令行参数 3 命令行参数进行分割 4 执行命令 5 判断命令是否为内建命令 前言&#xff1a; 本文介绍是自主Shell编写&#xff0c;对于shell&#xff0c;即外壳解释程序&#xff0c;我们目前接触到的命令行解释器&am…

基于vue框架的大学生四六级学习网站设计与实现i8o8z(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;学生,训练听力,学习单词,单词分类,阅读文章,文章类型,学习课程 开题报告内容 基于Vue框架的大学生四六级学习网站设计与实现开题报告 一、研究背景与意义 随着全球化进程的加速和国际交流的日益频繁&#xff0c;英语作为国际通用语言…