半监督学习与数据增强(论文复现)

news2024/11/25 2:26:29

半监督学习与数据增强(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • 半监督学习与数据增强(论文复现)
        • 概述
        • 算法原理
        • 核心逻辑
        • 效果演示
        • 使用方式

概述

本文复现论文提出的半监督学习方法,半监督学习(Semi-supervised Learning)是一种机器学习方法,它将少量的标注数据(带有标签的数据)和大量的未标注数据(不带标签的数据)结合起来训练模型。在许多实际应用中,标注数据获取成本高且困难,而未标注数据通常较为丰富和容易获取。因此,半监督学习方法被引入并被用于利用未标注数据来提高模型的性能和泛化能力

在这里插入图片描述

该论文介绍了一种基于一致性和置信度的半监督学习方法 FixMatch。FixMatch首先使用模型为弱增强后的未标注图像生成伪标签。对于给定图像,只有当模型产生高置信度预测时才保留伪标签。然后,模型在输入同一图像的强增强版本时被训练去预测伪标签。FixMatch 在各种半监督学习数据集上实现了先进的性能

算法原理

FixMatch 结合了两种半监督学习方法:一致性正则化和伪标签。其主要创新点在于这两种方法的结合以及在执行一致性正则化时分别使用了弱增强和强增强。

FixMatch 的损失函数由两个交叉熵损失项组成:一个用于有标签数据的监督损失 lsl**s 和一个用于无标签数据的无监督损失 lul**u 。具体来说,lsl**s 只是对弱增强有标签样本应用的标准交叉熵损失

在这里插入图片描述

其中 BB 表示 batch size,HH 表示交叉熵损失,pbp**b 表示标记,pm(y∣α(xb))p**m(yα(x**b)) 表示模型对弱增强样本的预测结果。

FixMatch 对每个无标签样本计算一个伪标签,然后在标准交叉熵损失中使用该标签。为了获得伪标签,我们首先计算模型对给定无标签图像的弱增强版本的预测类别分布:qb=pm(y∣α(ub))q**b=p**m(yα(u**b))。然后,我们使用 qb=arg⁡max⁡qb*q*b=argmaxq**b 作为伪标签,但我们在交叉熵损失中对模型对 ubu**b 的强增强版本的输出进行约束

在这里插入图片描述

其中 μμ 表示无标签样本与有标签样本数量之比,1(max(qb)>τ)1(max(q**b)>τ) 当前仅当 max(qb)>τmax(q**b)>τ 成立时为 1 否则为 0,ττ 表示置信度阈值,A(ub)A(u**b) 表示对无标签样本的强增强。

FixMatch的总损失就是 ls+λulul**s+λul**u,其中 λuλ**u 是表示无标签损失相对权重的标量超参数

在这里插入图片描述

FixMatch 利用两种增强方法:“弱增强”和“强增强”。论文所使用的弱增强是一种标准的翻转和位移增强策略。具体来说,除了SVHN数据集之外,我们在所有数据集上以50%的概率随机水平翻转图像,并随机在垂直和水平方向上平移图像最多12.5%。对于“强增强”,我采用了基于随机幅度采样的 RandAugment,然后进行了 Cutout 处理。

我在CIFAR-10、CIFAR-100 、SVHN 和 FER2013 数据集上对 FixMatch 进行了实验。关于使用的神经网络,我在 CIFAR-10 和 SVHN 上使用了 Wide ResNet-28-2,在 CIFAR-100 上使用了 Wide ResNet-28-8,在 FER2013 上使用了 Wide ResNe-37-2。实验结果如下表所示

在这里插入图片描述

为了直观展示 FixMatch 的效果,我在线部署了基于 FER2013 数据集训练的 Wide ResNe-37-2 模型。FER2013[2] 是一个面部表情识别数据集,其包含约 30000 张不同表情的面部 RGB 图像,尺寸限制为 48×48。其主要标签可分为 7 种类型:愤怒(Angry),厌恶(Disgust),恐惧(Fear),快乐(Happy),悲伤(Sad),惊讶(Surprise),中性(Neutral)。厌恶表情的图像数量最少,只有 600 张,而其他标签的样本数量均接近 5,000 张

核心逻辑

具体的核心逻辑如下所示:

for epoch in range(epochs):
    model.train()
    train_tqdm = zip(labeled_dataloader, unlabeled_dataloader)
    for labeled_batch, unlabeled_batch in train_tqdm:
        optimizer.zero_grad()
        # 利用标记样本计算损失
        data = labeled_batch[0].to(device)
        labels = labeled_batch[1].to(device)
        logits = model(normalize(strong_aug(data)))
        loss = F.cross_entropy(logits, labels)
        # 计算未标记样本伪标签
        with torch.no_grad():
            data = unlabeled_batch[0].to(device)
            logits = model(normalize(weak_aug(data)))
            probs = F.softmax(logits, dim=-1)
            trusted = torch.max(probs, dim=-1).values > threshold
            pseudo_labels = torch.argmax(probs[trusted], dim=-1)
            loss_factor = weight * torch.sum(trusted).item() / data.shape[0]
        # 利用未标记样本计算损失
        logits = model(normalize(strong_aug(data[trusted])))
        loss += loss_factor * F.cross_entropy(logits, pseudo_labels)
        # 反向梯度传播并更新模型参数
        loss.backward()
        optimizer.step()
效果演示

网站提供了在线体验功能。用户需要输入一张长宽尽可能相等且大小不超过 1MB 的正面脸部 JPG 图像,网站就会返回图片中人物表情所表达的情绪

在这里插入图片描述

使用方式

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令

unzip FixMatch.zip
cd FixMatch

代码的运行环境可通过如下命令进行配置

pip install -r requirements.txt

如果希望在本地运行程序,请运行如下命令

python main.py

如果希望在线部署,请运行如下命令

python main-flask.py

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186350.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C题(二)字符串转数字 --- atoi

———————————————————**目录**—————————————————— 一、 atoi函数介绍 功能函数原型使用示例 二、题解之一 三、留言 问题引入👉 输入样例👉 5 01234 00123 00012 00001 00000 输出样例👉 1234 123 …

‌文件名称与扩展名:批量重命名的技巧与指南

在日常的文件管理中,我们经常需要处理大量的文件,这些文件可能有着各种各样的名称和扩展名。为了更好地管理和识别这些文件,批量重命名成为了一项非常实用的技能。能够帮助我们快速整理文件,提高工作效率。本文将深入探讨文件名称…

vue2圆形标记(Marker)添加点击事件不弹出信息窗体(InfoWindow)的BUG解决

目录 一、问题详情 二、问题排查 三、解决方案 一、问题详情 地图上面的轨迹点希望能通过点击看到详细的经纬度信息,但是点击的时候就是显示不出来。 二、问题排查 代码都是参考高德的官方文档,初步看没有问题啊,但是点击事件就感觉失效…

10.3今日错题解析(软考)

目录 前言计算机网络——路由配置数据库系统——封锁协议 前言 这是用来记录我备考软考设计师的错题的,今天知识点为路由配置、封锁协议,大部分错题摘自希赛中的题目,但相关解析是原创,有自己的思考,为了复习&#xf…

Pix2Pix实现图像转换

tutorials/application/source_zh_cn/generative/pix2pix.ipynb MindSpore/docs - Gitee.com Pix2Pix概述 Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Ph…

Comparable接口和Comparator接口

前言 Java中基本数据类型可以直接比较大小,但引用类型呢?同时引用对象中可能存在多个可比较的字段,那么我们该怎么比较呢? Java中引用类型不能直接进行大小的比较,这种行为在编译器看来是危险的,所以会编译…

程序员在AI时代的生存指南:打造不可替代的核心竞争力

在这个AI大行其道的时代,似乎每天都有新的语言模型像变魔术一样涌现出来,比如ChatGPT、midjourney、claude等等。这些家伙不仅会聊天,还能帮忙写代码,让程序员们感受到了前所未有的“压力”。我身边的一些程序员朋友开始焦虑&…

SpringCloud入门(十)统一网关Gateway

一、网关的作用 Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关,它旨在为微服务架构提供一种简单有效的统一的 API 路由管理方式。 …

E. Tree Pruning Codeforces Round 975 (Div. 2)

原题 E. Tree Pruning 解析 本题题意很简单, 思路也很好想到, 假设我们保留第 x 层的树叶, 那么对于深度大于 x 的所有节点都要被剪掉, 而深度小于 x 的节点, 如果没有子节点深度大于等于 x, 那么也要被删掉 在做这道题的时候, 有关于如何找到一个节点它的子节点能通到哪里,…

关于鸿蒙next 调用系统权限麦克风

使用app的时候都清楚,想使用麦克风、摄像头,存储照片等,都需要调用系统的权限,没有手机操作系统权限你也使用不了app所提供的功能,虽然app可以正常打开,但是你需要的功能是没办法使用的。今天把自己在鸿蒙学…

想怎样书写HTML5自结束标签,您随意就好(✪▽✪)

书写后接斜杠还是不接,看过ai给的详细解析就不再迷茫了。 (笔记模板由python脚本于2024年10月03日 10:42:41创建,本篇笔记适合HTML5标签的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖…

【数据库差异研究】update与delete使用表别名的研究

目录 ⚛️总结 ☪️1 Update ♋1.1 测试用例UPDATE users as a SET a.age 111 WHERE a.name Alice; ♏1.2 测试用例UPDATE users as a SET a.age 111 WHERE name Alice; ♐1.3 测试用例UPDATE users as a SET age 111 WHERE a.name Alice; ♑1.4 测试用例UPDATE us…

TIM“PWM”输出比较原理解析

PWM最重要的就是占空比,所有都是在为占空比服务,通过设置不同的占空比,产生不同的电压,产生不同的效果 定时器的输出通道 基本定时器: 基本定时器没有通道 通用定时器: 4个通道(CH1, CH2, C…

Python性能优化:实战技巧与最佳实践

Python性能优化:实战技巧与最佳实践 Python 作为一种动态解释型语言,虽然以其简洁和易用性闻名,但在性能方面可能不如静态编译型语言如 C 和 Java 高效。为了在高性能要求的应用场景下更好地利用 Python,我们需要掌握一些常见的优…

STM32GPIO输入和输出

一、先看IO端口位的结构 上面部分是输入,下面是输出。 1、I/O输入: 首先,从I/O引脚开始,有两个保护二极管,主要作用是对输入电压限幅,保护内部电路。上面二极管接VDD为3.3V,下面二极管接VSS为0V。当输入电…

认知杂谈71《创业抉择:定制化与标准化的权衡之路》

内容摘要: *嘿,彦祖们!今天来聊聊创业的事,创业选产品类型很关键。定制化产品如魔法,贴合客户需求但成本高且有边际递减风险。要掌握物联网技术,用 3D 建模软件,参考特定书籍,参加展…

在线JSON可视化工具--支持缩放

先前文章提到的超好用的JSON可视化工具,收到反馈,觉得工具好用,唯一不足就是不能缩放视图,其实是支持的,因为滚轮有可能是往下滚动,会与缩放冲突,所以这个工具设计为需要双击视图来触发打开缩放…

C++ 线性表、内存操作、 迭代器,数据与算法分离。

线性表: 线性表是最基本、最简单、也是最常用的一种数据结构。线性表(linear list)是数据结构的 一种,一个线性表是n个具有相同特性的数据元素的有限序列。 线性表中数据元素之间的关系是一对一的关系,即除了第一个和…

Ubuntu2404安装

Ubuntu是一款非常优秀的发行版本,起初她的优势主要在于桌面版,但是随着Centos 从服务版的支持的退出,Ubuntu server也在迅猛的成长,并且不断收获了用户,拥有了一大批忠实的粉丝。好了,废话不多说&#xff0…

基于SSM的出租车租赁管理系统的设计与实现

文未可获取一份本项目的java源码和数据库参考。 1 选题的背景 现代社会,许多个人、家庭,因为生活、工作方式的改变,对汽车不再希望长期拥有,取而代之的是希望汽车能“召之即…