半监督学习与数据增强

news2025/1/9 17:05:52


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:人工智能、话题分享

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

概述

算法原理

核心逻辑

效果演示

使用方式

参考文献


 本文所有资源均可在该地址处获取。

概述

本文复现论文 FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence[1] 提出的半监督学习方法。

半监督学习(Semi-supervised Learning)是一种机器学习方法,它将少量的标注数据(带有标签的数据)和大量的未标注数据(不带标签的数据)结合起来训练模型。在许多实际应用中,标注数据获取成本高且困难,而未标注数据通常较为丰富和容易获取。因此,半监督学习方法被引入并被用于利用未标注数据来提高模型的性能和泛化能力。

图1:半监督数据集

该论文介绍了一种基于一致性和置信度的半监督学习方法 FixMatch。FixMatch首先使用模型为弱增强后的未标注图像生成伪标签。对于给定图像,只有当模型产生高置信度预测时才保留伪标签。然后,模型在输入同一图像的强增强版本时被训练去预测伪标签。FixMatch 在各种半监督学习数据集上实现了先进的性能。

算法原理

FixMatch 结合了两种半监督学习方法:一致性正则化和伪标签。其主要创新点在于这两种方法的结合以及在执行一致性正则化时分别使用了弱增强和强增强。

FixMatch 的损失函数由两个交叉熵损失项组成:一个用于有标签数据的监督损失 lsls​ 和一个用于无标签数据的无监督损失 lulu​ 。具体来说,lsls​ 只是对弱增强有标签样本应用的标准交叉熵损失:

ls=1B∑b=1BH(pb,pm(y∣α(xb)))ls​=B1​b=1∑B​H(pb​,pm​(y∣α(xb​)))

其中 BB 表示 batch size,HH 表示交叉熵损失,pbpb​ 表示标记,pm(y∣α(xb))pm​(y∣α(xb​)) 表示模型对弱增强样本的预测结果。

FixMatch 对每个无标签样本计算一个伪标签,然后在标准交叉熵损失中使用该标签。为了获得伪标签,我们首先计算模型对给定无标签图像的弱增强版本的预测类别分布:qb=pm(y∣α(ub))qb​=pm​(y∣α(ub​))。然后,我们使用 q^b=arg⁡max⁡qbq^​b​=argmaxqb​ 作为伪标签,但我们在交叉熵损失中对模型对 ubub​ 的强增强版本的输出进行约束:

lu=1μB∑b=1μB1(max(qb)>τ)H(q^b,pm(y∣A(ub)))lu​=μB1​b=1∑μB​1(max(qb​)>τ)H(q^​b​,pm​(y∣A(ub​)))

其中 μμ 表示无标签样本与有标签样本数量之比,1(max(qb)>τ)1(max(qb​)>τ) 当前仅当 max(qb)>τmax(qb​)>τ 成立时为 1 否则为 0,ττ 表示置信度阈值,A(ub)A(ub​) 表示对无标签样本的强增强。

FixMatch的总损失就是 ls+λululs​+λu​lu​,其中 λuλu​ 是表示无标签损失相对权重的标量超参数。

图2:方法原理图

FixMatch 利用两种增强方法:“弱增强”和“强增强”。论文所使用的弱增强是一种标准的翻转和位移增强策略。具体来说,除了SVHN数据集之外,我们在所有数据集上以50%的概率随机水平翻转图像,并随机在垂直和水平方向上平移图像最多12.5%。对于“强增强”,我采用了基于随机幅度采样的 RandAugment,然后进行了 Cutout 处理。

我在CIFAR-10、CIFAR-100 、SVHN 和 FER2013 数据集上对 FixMatch 进行了实验。关于使用的神经网络,我在 CIFAR-10 和 SVHN 上使用了 Wide ResNet-28-2,在 CIFAR-100 上使用了 Wide ResNet-28-8,在 FER2013 上使用了 Wide ResNe-37-2。实验结果如下表所示:

数据集准确率(%)
CIFAR-1086.39
CIFAR-10068.88
SVHN91.25
FER201368.57

为了直观展示 FixMatch 的效果,我在线部署了基于 FER2013 数据集训练的 Wide ResNe-37-2 模型。FER2013[2] 是一个面部表情识别数据集,其包含约 30000 张不同表情的面部 RGB 图像,尺寸限制为 48×48。其主要标签可分为 7 种类型:愤怒(Angry),厌恶(Disgust),恐惧(Fear),快乐(Happy),悲伤(Sad),惊讶(Surprise),中性(Neutral)。厌恶表情的图像数量最少,只有 600 张,而其他标签的样本数量均接近 5,000 张。

核心逻辑

具体的核心逻辑如下所示:

for epoch in range(epochs):
    model.train()
    train_tqdm = zip(labeled_dataloader, unlabeled_dataloader)
    for labeled_batch, unlabeled_batch in train_tqdm:
        optimizer.zero_grad()
        # 利用标记样本计算损失
        data = labeled_batch[0].to(device)
        labels = labeled_batch[1].to(device)
        logits = model(normalize(strong_aug(data)))
        loss = F.cross_entropy(logits, labels)
        # 计算未标记样本伪标签
        with torch.no_grad():
            data = unlabeled_batch[0].to(device)
            logits = model(normalize(weak_aug(data)))
            probs = F.softmax(logits, dim=-1)
            trusted = torch.max(probs, dim=-1).values > threshold
            pseudo_labels = torch.argmax(probs[trusted], dim=-1)
            loss_factor = weight * torch.sum(trusted).item() / data.shape[0]
        # 利用未标记样本计算损失
        logits = model(normalize(strong_aug(data[trusted])))
        loss += loss_factor * F.cross_entropy(logits, pseudo_labels)
        # 反向梯度传播并更新模型参数
        loss.backward()
        optimizer.step()

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示

网站提供了在线体验功能。用户需要输入一张长宽尽可能相等且大小不超过 1MB 的正面脸部 JPG 图像,网站就会返回图片中人物表情所表达的情绪。

图3:在线演示结果

使用方式

  • 解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip FixMatch.zip
cd FixMatch

  • 代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt

  • 如果希望在本地运行程序,请运行如下命令:
python main.py

  • 如果希望在线部署,请运行如下命令:
python main-flask.py

(以上内容皆为原创,请勿转载)

参考文献

[1] Sohn K, Berthelot D, Carlini N, et al. Fixmatch: Simplifying semi-supervised learning with consistency and confidence[J]. Advances in neural information processing systems, 2020, 33: 596-608.

[2] Wang L, Xu S, Wang X, et al. Eavesdrop the composition proportion of training labels in federated learning[J]. arXiv preprint arXiv:1910.06044, 2019.

​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2255107.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

位运算符I^~

&运算:上下相等才是1,有一个不同就是0 |运算:只要有1返回的就是1 ^(亦或)运算:上下不同是1,相同是0 ~运算:非运算,与数据全相反 cpu核心运算原理,四种cpu底层小电路 例&#xf…

蓝桥杯软件赛系列---lesson1

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 我们今天会再开一个系列,那就是蓝桥杯系列,我们会从最基础的开始讲起,大家想要备战明年蓝桥杯的,让我们一起加油。 工具安装 DevC…

【0x01】HCI_Inquiry_Complete事件详解

目录 一、事件概述 二、事件格式及参数 2.1. HCI_Inquiry_Complete事件格式 2.2. 参数 三、HCI_Inquiry_Complete事件触发机制 3.1. 基于查询命令完成的触发 3.2. 受查询环境和设备状态影响的触发 3.3. 与蓝牙协议栈内部逻辑相关的触发 四、事件处理流程 4.1. 事件接…

安防视频监控平台Liveweb视频汇聚管理系统管理方案

智慧安防监控Liveweb视频管理平台能在复杂的网络环境中,将前端设备统一集中接入与汇聚管理。国标GB28181协议视频监控/视频汇聚Liveweb平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、…

shell脚本实战案例

文章目录 实战第一坑功能说明脚本实现 实战第一坑 实战第一坑:在Windows系统写了一个脚本,比如上面,随后上传到服务,执行会报错 原因: 解决方案:在linux系统touch文件,并通过vim添加内容&…

波特图方法

在电路设计中,波特图为最常用的稳定性余量判断方法,波特图的根源是如何来的,却鲜有人知。 本章节串联了奈奎斯特和波特图的渊源,给出了其对应关系和波特图相应的稳定性余量。 理论贯通,不在于精确绘…

在ensp进行IS-IS网络架构配置

一、实验目的 1. 理解IS-IS协议的工作原理 2. 熟练ensp路由连接配置 二、实验要求 需求: 路由器可以互相ping通 实验设备: 路由器router6台 使用ensp搭建实验坏境,结构如图所示 三、实验内容 R1 u t m sys undo info en sys R1 #设…

vxe-table 键盘操作,设置按键编辑方式,支持覆盖方式与追加方式

vxe-table 全键盘操作,按键编辑方式设置,覆盖方式与追加方式; 通过 keyboard-config.editMode 设置按键编辑方式;支持覆盖方式编辑和追加方式编辑 安装 npm install vxe-pc-ui4.3.15 vxe-table4.9.15// ... import VxeUI from v…

MNIST数据集_CNN

前言 提醒: 文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。 其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展…

【Flink】Flink Checkpoint 流程解析

Flink Checkpoint 流程解析 Checkpoint 流程解析 Flink Checkpoint 流程解析Checkpint 流程概括Checkpoint 触发流程解析 (Flink 1.20)任务启动后 JobManager 开始定期对任务执行 CheckpointJobManager 使用 CheckpointCoordinator 触发 CheckpointCheckpointCoordinator 初始化…

MIT工具课第六课任务 Git基础练习题

如果您之前从来没有用过 Git,推荐您阅读 Pro Git 的前几章,或者完成像 Learn Git Branching 这样的教程。重点关注 Git 命令和数据模型相关内容; 相关内容整理链接:Linux Git新手入门 git常用命令 Git全面指南:基础概念…

Sui 主网升级至 V1.38.3

Sui 主网现已升级至 V1.38.3 版本,同时协议升级至 69 版本。请开发者及时关注并调整! 其他升级要点如下所示: 协议 #20199 在共识快速路径投票中设置允许的轮次数量。 节点(验证节点与全节点) #20238 为验证节点…

【AI系统】低比特量化原理

低比特量化原理 计算机里面数值有很多种表示方式,如浮点表示的 FP32、FP16,整数表示的 INT32、INT16、INT8,量化一般是将 FP32、FP16 降低为 INT8 甚至 INT4 等低比特表示。 模型量化则是一种将浮点值映射到低比特离散值的技术,可…

项目文章 | RNA-seq+WES-seq+机器学习,揭示DNAH5是结直肠癌的预后标志物

肿瘤突变负荷(TMB)已成为预测结直肠癌(CRC)患者预后和对免疫治疗反应的关键生物标志物。然而,全外显子测序(WES-seq)作为TMB评估的金标准,成本高且耗时。此外,高TMB患者之…

【NLP修炼系列之Bert】Bert多分类多标签文本分类实战(附源码下载)

引言 今天我们就要用Bert做项目实战,实现文本多分类任务和我在实际公司业务中的多标签文本分类任务。通过本篇文章,可以让想实际入手Bert的NLP学习者迅速上手Bert实战项目。 1 项目介绍 本文是Bert文本多分类和多标签文本分类实战,其中多分…

【CSS in Depth 2 精译_069】11.3 利用 OKLCH 颜色值来处理 CSS 中的颜色问题(上)

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第四部分 视觉增强技术 ✔️【第 11 章 颜色与对比】 ✔️ 11.1 通过对比进行交流 11.1.1 模式的建立11.1.2 还原设计稿 11.2 颜色的定义 11.2.1 色域与色彩空间11.2.2 CSS 颜色表示法 11.2.2.1 RGB…

基础算法——搜索与图论

搜索与图论 图的存储方式2、最短路问题2.1、Dijkstra算法(朴素版)2.2、Dijkstra算法(堆优化版)2.3、Bellman-Ford算法2.4、SPFA求最短路2.5、SPFA判负环2.6、Floyd算法 图的存储方式 2、最短路问题 最短路问题可以分为单源最短路…

IDEA创建Spring Boot项目配置阿里云Spring Initializr Server URL【详细教程-轻松学会】

1.首先打开idea选择新建项目 2.选择Spring Boot框架(就是选择Spring Initializr这个) 3.点击中间界面Server URL后面的三个点更换为阿里云的Server URL Idea中默认的Server URL地址:https://start.spring.io/ 修改为阿里云Server URL地址:https://star…

Git_如何更改默认路径

网上搜了一堆都不好使,其实可以直接使用git bash输入命令来解决 打开 Git Bash:首先打开 Git Bash 终端,这是一个类似于命令提示符的窗口,可在其中执行 Git 命令。设置 Git 默认存储路径:使用 git config 命令来修改 …

计算机毕业设计Python房价预测 房屋推荐 房价可视化 链家爬虫 房源爬虫 房源可视化 卷积神经网络 大数据毕业设计 机器学习 人工智能 AI

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…