传知代码-半监督学习与数据增强(论文复现)

news2024/10/2 1:22:17

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

本文复现论文 FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence[1] 提出的半监督学习方法。

半监督学习(Semi-supervised Learning)是一种机器学习方法,它将少量的标注数据(带有标签的数据)和大量的未标注数据(不带标签的数据)结合起来训练模型。在许多实际应用中,标注数据获取成本高且困难,而未标注数据通常较为丰富和容易获取。因此,半监督学习方法被引入并被用于利用未标注数据来提高模型的性能和泛化能力。

在这里插入图片描述

该论文介绍了一种基于一致性和置信度的半监督学习方法 FixMatch。FixMatch首先使用模型为弱增强后的未标注图像生成伪标签。对于给定图像,只有当模型产生高置信度预测时才保留伪标签。然后,模型在输入同一图像的强增强版本时被训练去预测伪标签。FixMatch 在各种半监督学习数据集上实现了先进的性能。

算法原理

FixMatch 结合了两种半监督学习方法:一致性正则化和伪标签。其主要创新点在于这两种方法的结合以及在执行一致性正则化时分别使用了弱增强和强增强。

FixMatch 的损失函数由两个交叉熵损失项组成:一个用于有标签数据的监督损失 l s l_s ls 和一个用于无标签数据的无监督损失 l u l_u lu 。具体来说, l s l_s ls 只是对弱增强有标签样本应用的标准交叉熵损失:
l s = 1 B ∑ b = 1 B H ( p b , p m ( y ∣ α ( x b ) ) ) l_s=\frac{1}{B}\sum_{b=1}^B H(p_b,p_m(y|\alpha(x_b))) ls=B1b=1BH(pb,pm(yα(xb)))
其中 B B B 表示 batch size, H H H 表示交叉熵损失, p b p_b pb 表示标记, p m ( y ∣ α ( x b ) ) p_m(y|\alpha(x_b)) pm(yα(xb)) 表示模型对弱增强样本的预测结果。

FixMatch 对每个无标签样本计算一个伪标签,然后在标准交叉熵损失中使用该标签。为了获得伪标签,我们首先计算模型对给定无标签图像的弱增强版本的预测类别分布: q b = p m ( y ∣ α ( u b ) ) q_b=p_m(y|\alpha(u_b)) qb=pm(yα(ub))。然后,我们使用 q ^ b = arg ⁡ max ⁡ q b \hat q_b=\arg\max q_b q^b=argmaxqb 作为伪标签,但我们在交叉熵损失中对模型对 u b u_b ub 的强增强版本的输出进行约束:
l u = 1 μ B ∑ b = 1 μ B 1 ( m a x ( q b ) > τ ) H ( q ^ b , p m ( y ∣ A ( u b ) ) ) l_u=\frac{1}{\mu B}\sum_{b=1}^{\mu B} 1(max(q_b)>\tau)H(\hat q_b,p_m(y|A(u_b))) lu=μB1b=1μB1(max(qb)>τ)H(q^b,pm(yA(ub)))
其中 μ \mu μ 表示无标签样本与有标签样本数量之比, 1 ( m a x ( q b ) > τ ) 1(max(q_b)>\tau) 1(max(qb)>τ) 当前仅当 m a x ( q b ) > τ max(q_b)>\tau max(qb)>τ 成立时为 1 否则为 0, τ \tau τ 表示置信度阈值, A ( u b ) A(u_b) A(ub) 表示对无标签样本的强增强。

FixMatch的总损失就是 l s + λ u l u l_s + \lambda_u l_u ls+λulu,其中 λ u \lambda_u λu 是表示无标签损失相对权重的标量超参数。

在这里插入图片描述

FixMatch 利用两种增强方法:“弱增强”和“强增强”。论文所使用的弱增强是一种标准的翻转和位移增强策略。具体来说,除了SVHN数据集之外,我们在所有数据集上以50%的概率随机水平翻转图像,并随机在垂直和水平方向上平移图像最多12.5%。对于“强增强”,我采用了基于随机幅度采样的 RandAugment,然后进行了 Cutout 处理。

我在CIFAR-10、CIFAR-100 、SVHN 和 FER2013 数据集上对 FixMatch 进行了实验。关于使用的神经网络,我在 CIFAR-10 和 SVHN 上使用了 Wide ResNet-28-2,在 CIFAR-100 上使用了 Wide ResNet-28-8,在 FER2013 上使用了 Wide ResNe-37-2。实验结果如下表所示:

数据集准确率(%)
CIFAR-1086.39
CIFAR-10068.88
SVHN91.25
FER201368.57

为了直观展示 FixMatch 的效果,我在线部署了基于 FER2013 数据集训练的 Wide ResNe-37-2 模型。FER2013[2] 是一个面部表情识别数据集,其包含约 30000 张不同表情的面部 RGB 图像,尺寸限制为 48×48。其主要标签可分为 7 种类型:愤怒(Angry),厌恶(Disgust),恐惧(Fear),快乐(Happy),悲伤(Sad),惊讶(Surprise),中性(Neutral)。厌恶表情的图像数量最少,只有 600 张,而其他标签的样本数量均接近 5,000 张。

核心逻辑

具体的核心逻辑如下所示:

for epoch in range(epochs):
    model.train()
    train_tqdm = zip(labeled_dataloader, unlabeled_dataloader)
    for labeled_batch, unlabeled_batch in train_tqdm:
        optimizer.zero_grad()
        # 利用标记样本计算损失
        data = labeled_batch[0].to(device)
        labels = labeled_batch[1].to(device)
        logits = model(normalize(strong_aug(data)))
        loss = F.cross_entropy(logits, labels)
        # 计算未标记样本伪标签
        with torch.no_grad():
            data = unlabeled_batch[0].to(device)
            logits = model(normalize(weak_aug(data)))
            probs = F.softmax(logits, dim=-1)
            trusted = torch.max(probs, dim=-1).values > threshold
            pseudo_labels = torch.argmax(probs[trusted], dim=-1)
            loss_factor = weight * torch.sum(trusted).item() / data.shape[0]
        # 利用未标记样本计算损失
        logits = model(normalize(strong_aug(data[trusted])))
        loss += loss_factor * F.cross_entropy(logits, pseudo_labels)
        # 反向梯度传播并更新模型参数
        loss.backward()
        optimizer.step()

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示

网站提供了在线体验功能。用户需要输入一张长宽尽可能相等且大小不超过 1MB 的正面脸部 JPG 图像,网站就会返回图片中人物表情所表达的情绪。

在这里插入图片描述

使用方式

  • 解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip FixMatch.zip
cd FixMatch
  • 代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt
  • 如果希望在本地运行程序,请运行如下命令:
python main.py
  • 如果希望在线部署,请运行如下命令:
python main-flask.py

(以上内容皆为原创,请勿转载)

参考文献

[1] Sohn K, Berthelot D, Carlini N, et al. Fixmatch: Simplifying semi-supervised learning with consistency and confidence[J]. Advances in neural information processing systems, 2020, 33: 596-608.

[2] Wang L, Xu S, Wang X, et al. Eavesdrop the composition proportion of training labels in federated learning[J]. arXiv preprint arXiv:1910.06044, 2019.

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1979425.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Python库开发Markdown编辑器并将内容导出为图片

简介 在本文中,我们将探索如何使用Python的wxPython库开发一个Markdown编辑器应用程序。这个应用程序不仅能浏览和编辑Markdown文件,还可以将编辑的内容导出为PNG图片。 C:\pythoncode\new\markdowneditor.py 完整代码 import wx import markdown2 im…

前端实现下载word(多个word下载)-- docxtemplater

文章目录 🔎什么是docxtemplater?👻docxtemplater语法👻普通插值for循环选择图片(🌰代码) 👻实现下载👻安装依赖🌰完整代码关键逻辑解释 👻实现多…

CSP-J 复赛 模拟题 解析版

根据解析写代码1&#xff1a; #include <bits/stdc.h> using namespace std; long long a[101010]; long long b[101010]; int main(){bool flag0;long long t;cin>>t;while(t--){long long n,k;cin>>n>>k;for(int i1;i<n;i){cin>>a[i]>…

kickstart自动安装脚本

1、准备阶段 #开启图形 init 5 ​ #安装带GUI的服务器包组 yum -y groupinstall "Server with GUI" ​ #在xshell做需要加X ssh -Xl root 172.25.254.128 ​ #开启图形 gedit ​ 2、kickstart [rootpxe ~]# cat /root/anaconda-ks.cfg #此文件是在系统安装好后…

大数据Flink(一百零九):阿里云Flink的基本名称概念

文章目录 阿里云Flink的基本名称概念 一、层次结构 二、​​​​​​​​​​​​​​概念说明 1、工作空间&#xff08;Workspace&#xff09; 2、项目空间&#xff08;Namespace&#xff09; 3、资源&#xff08;Resource&#xff09; 4、草稿&#xff08;Draft&#…

将本地微服务发布到docker镜像二:

上一篇文章我们介绍了如何将一个简单的springboot服务发布到docker镜像中&#xff0c;这一篇我们将介绍如何将一个复杂的微服务&#xff08;关联mysql、redis&#xff09;发布到docker镜像。 我们将使用以下两种不同的方式来实现此功能。 redis、mysql、springboot微服务分开…

linux的自动检测的脚本:用于检测应用程序状态的linux脚本

目录 一、要求 1、需求内容 2、分析 二、脚本介绍 1、脚本代码 2、脚本解释 &#xff08;1&#xff09;脚本结构 &#xff08;2&#xff09;应用程序和服务列表 &#xff08;3&#xff09;日志文件路径 &#xff08;4&#xff09;测试 URL 列表 &#xff08;5&#…

智能小程序 Ray 开发面板 SDK —— 无线开关一键执行模板教程(一)

1. 准备工作 前提条件 已阅读 Ray 新手村任务&#xff0c;了解 Ray 框架的基础知识已阅读 使用 Ray 开发万能面板&#xff0c;了解 Ray 面板开发的基础知识 构建内容 在此 Codelab 中&#xff0c;您将利用面板小程序开发构建出一个支持一键执行及自动化的无线开关面板&…

HCIP----BGP综合实验

一、实验拓扑 二、实验要求 三、实验思路 1.基于172.16.0.0/16根据实验拓扑图进行IP地址规划&#xff0c;规划过程如下&#xff1a; 2.根据上述的IP地址的规划进行配置&#xff0c;配置完后在AS2中配置OSPF使其内部实现全网通&#xff08;互相建邻的条件&#xff09;。 3.在A…

keil编程中#pragma NOAREGS的作用和优点

参考 功能 不直接操作内存地址 #pragma NOAREGS在Keil中的使用含义是禁用自动分配寄存器&#xff0c;开发人员指定控制的寄存器。‌例如中断的执行使用的寄存器需要人为的指定&#xff0c;避免分配同样的寄存器导致数据错误。对寄存器R0到R7不直接操作寄存器地址&#xff0c…

C# 设计模式六大原则之依赖倒置原则

总目录 前言 1 基本介绍 1. 定义 依赖倒置原则 Dependence Inversion Principle&#xff0c;简称&#xff1a;DIP。 依赖倒置原则&#xff1a;高层模块不应该依赖低层模块&#xff0c;二者都应该依赖其抽象&#xff1b;抽象不应该依赖细节&#xff0c;细节应该依赖抽象。 2…

GO之基本语法

一、Golang变量 一&#xff09;变量的声明&#xff1a;使用var关键字 Go语言是静态类型语言 Go语言的基本类型有&#xff1a; boolstringint、int8、int16、int32、int64uint、uint8、uint16、uint32、uint64、uintptrbyte // uint8 的别名rune // int32 的别名 代表一个 Unic…

CTF-web 基础 网络协议

网络协议 OSI七层参考模型&#xff1a;一个标准的参考模型 物理层 网线&#xff0c;网线接口等。 数据链路层 可以处理物理层传入的信息。 网络层 比如IP地址 传输层 控制传输的内容的传输&#xff0c;在传输的过程中将要传输的信息分块传输完成之后再进行合并。 应用…

sql语句精讲

目录 一、MySql的细节知识 SQL语句的结束 SQL语句的大小写 SQL语句中的空格 SQL查询结果并排序 WHERE和ORDER BY的位置 常用数据类型 MYSQL比较运算符 MYSQL算术运算符​编辑 MYSQL逻辑运算符 运算符的优先级 二、MySql数据库的基本操作 1.创建数据库 2.选择数据…

kubenetes证书续签

转载&#xff1a;k8s证书续签 1 检查证书年限 kubeadm certs check-expiration 2 对现有证书进行备份 # 备份kubeadm cp -ra /usr/bin/kubeadm /usr/bin/kubeadm.bak # 备份证书 cp -ra /etc/kubernetes /etc/kubernetes.bak 3 重新编译kubeadm 拉取k8s仓库代码 git clone…

【数据分析--Pandas实战指南在真实世界数据中的应用】

前言&#xff1a; &#x1f49e;&#x1f49e;大家好&#xff0c;我是书生♡&#xff0c;本阶段和大家一起分享和探索数据分析—基础介绍&#xff0c;本篇文章主要讲述了&#xff1a;数据分析的介绍&#xff0c;Python开源库&#xff0c;配置Jupyter&#xff0c;Pandas读取数据…

echarts 漏斗图 渐变金字塔

使用echarts实现金字塔效果&#xff0c;颜色渐变&#xff0c;左右显示其对应的值 效果&#xff1a; 如果要实现一个正三角的形状&#xff0c;需要在data数组中&#xff0c;将value赋值成有序递增&#xff0c;bl代表他的分值&#xff0c;显示在左侧。 var data [{name: "…

NSS [SWPUCTF 2022 新生赛]file_master

NSS [SWPUCTF 2022 新生赛]file_master 开题&#xff0c;一眼文件上传。 network看看返回包。后端语言是PHP。 除了文件上传还有个查看文件功能。 起手式查询/etc/passwd&#xff0c;发现查询方法是GET提交参数&#xff0c;后端使用file_get_contents()函数包含文件。同时有op…

MySQL基础练习题21-按日期分组销售产品

目录 题目 准备数据 分析数据 总结 题目 找出每个日期、销售的不同产品的数量及其名称。每个日期的销售产品名称应按词典序排列。 返回按 sell_date 排序的结果表。 准备数据 ## 创建库 create database db; use db;## 创建表 Create table If Not Exists Activities (s…

初阶数据结构4 二叉树

1. 树 1.1 树的概念与结构 树是⼀种⾮线性的数据结构&#xff0c;它是由 n&#xff08;n>0&#xff09; 个有限结点组成⼀个具有层次关系的集合。把它叫做树是因为它看起来像⼀棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;⽽叶朝下的。 有⼀个特殊的结点&#…