【深度学习炼丹】不平衡样本的处理

news2025/1/16 8:05:34

目录:不平衡样本的处理

  • 一、前言
  • 二、数据层面处理方法
    • 2.1 数据扩充
    • 2.2 数据(重)采样
    • 2.3 类别平衡采样
  • 三、算法(损失函数)层面处理方法
    • 3.1 Focal Loss
    • 3.2 损失函数加权
  • 四、参考资料

一、前言

在机器学习的经典假设中往往假设训练样本各类别数目是均衡的,但在实际场景中,训练样本数据往往都是不均衡(不平衡)的。比如在图像二分类问题中,一个极端的例子是,训练集中有 95 个正样本,但是负样本只有 5 个。这种类别数据不均衡的情况下,如果不做不平衡样本的处理,会导致模型在数目较少的类别上出现“欠学习”现象,即可能在测试集上完全丧失对负样本的预测能力。

除了常见的分类、回归任务,类似图像语义分割、深度估计等像素级别任务中也是存在不平衡样本问题的。

解决不平衡样本问题的处理方法一般有两种:

  • 从“数据层面”入手:分为数据采样法和类别平衡采样法。
  • 从“算法层面”入手:代价敏感方法。

注意本文只介绍不平衡样本的处理思想和策略,不涉及具体代码,在实际项目中,需要针对具体任务,结合不平衡样本的处理策略来设计具体的数据集处理或损失函数代码,从而解决对应问题。

二、数据层面处理方法

数据层面的处理方法总的来说分为数据扩充和采样法,数据扩充会直接改变数据样本的数量和丰富度,采样法的本质是使得输入到模型的训练集样本趋向于平衡,即各类样本的数目趋向于一致。

数据层面的采样处理方法主要有两种策略:

  • 数据重采样方法,发生在数据预处理阶段,会改变整体训练集的数目和分布。
  • 类别平衡采样方法,发生在数据加载阶段(这里的加载是指加载到模型中,不是指从硬盘中读取文件),通过设置采样策略来使得不同类别样本送入模型训练总的次数是近似的。

2.1 数据扩充

所谓数据不平衡,其实就是某些类别的数据量太少,那就直接增加一些呗,简单直接。如果有的选,那肯定是优先选择重新采取数据的办法了,当然大部分时候我们都没得选,这个时候最有效的办法自然是通过数据增强来扩充数据了。

数据增强的手段有多种,常见的如下:

  • 水平 / 竖直翻转
  • 90°,180°,270° 旋转
  • 翻转 + 旋转(旋转和翻转其实是保证了数据特征的旋转不变性能被模型学习到,卷积层面的方法可以参考论文 ACNet)
  • 亮度,饱和度,对比度的随机变化
  • 随机裁剪(Random Crop)
  • 随机缩放(Random Resize)
  • 加模糊(Blurring)
  • 加高斯噪声(Gaussian Noise)

值得注意的是数据增强手段的使用必须结合具体任务而来,除了前三种以外,其他的要慎重考虑。因为不同的任务场景下数据特征依赖不同,比如高斯噪声,在天池铝材缺陷检测竞赛中,如果高斯噪声增加不当,有些图片原本在采集的时候相机就对焦不准,导致工件难以看清,倘若再增加高斯模糊属性,部分图片样本基本就废了。

虽然目前深度学习框架中都自带了一些数据增强函数,但更多更强的数据增强手段可以使用一些图像增强库,比如 imgaug 这个 python 库。

2.2 数据(重)采样

简单的数据重采样方法分为数据上采样(over-sampling、up-sampling,也叫数据过采样) 或 也叫数据欠采样数据下采样(under-sampling 、down-sampling )。

1,对于样本数目较少的类别,可用数据过采样方法(over-sampling),即通过复制方法使得该类图像数目增至与样本最多类的样本数一致。

2,而对于样本数较多的类别,可使用数据欠采样(Under-sampling,也叫数据欠采样)方法。对于深度学习和计算机视觉领域的任务来说,下采样并不是直接随机丢弃一部分图像,正确的下采样策略是: 在批处理训练时(数据加载阶段 dataloader),对于样本较多的类别,严格控制每批(batch)随机抽取的图像数目,使得每批读取的数据中正负样本是均衡的(类别均衡)。以二分类任务为例,假设原始数据分布情况下每批处理训练正负样本平均数量比例为 9:1,如仅使用下采样策略,则可在每批随机挑选训练样本时每 9 个正样本只取 1 个作为该批训练集的正样本,负样本选择策略不变,这样可使得每批读取的训练数据中正负样本时平衡的。

数据过采样和欠采样示意图如下所示。

在这里插入图片描述
数据采样方法总结:

数据过采样和欠采样本质的简单理解就是“增加图片”和“删图片”:

  • 过采样:重复正比例数据,实际上没有为模型引入更多形式数据,过分强调正比例数据,会放大正比例噪音对模型的影响。
  • 欠采样:丢弃大类别的部分数据,和过采样一样会存在过拟合的问题。

同时两种数据重采样方法都是会改变数据原始分布的,比如数据过采样增加较小类别的样本数,数据欠采样减少较大类别的样本数,有可能产生模型过拟合等问题。

这里的较小类别的意思是样本数目较少的类别,较大类别即样本数目较多的类别。

以上内容都是对解决类别不平衡问题中数据采样方法的策略描述,但想要在实际任务中解决问题,还要求我们加深对任务(task)的分析、对数据的理解分析,以及要求我们有更多的数据处理、数据采样的代码经验,即良好的策略 + 熟练的工具。

需要注意的是,因为仅仅使用数据上采样策略有可能会引起模型过拟合问题,所以在实际任务中,更为保险的数据采样策略哇往往是将上采样和下采样结合起来使用。

2.3 类别平衡采样

前面的数据重采样策略是着重于类别样本数量,而另一类采样策略则是直接着重于类别本身,不改变数据总体样本数,即类别平衡采样方法。其简单策略是把样本按类别分组,每个类别生成一个样本列表,训练过程中随机选择 1 个或几个类别,然后从每个类别所对应的样本列表中随机选择样本,这样可保证每个类别参与训练的机会比较均衡。

上述类别平衡方法过于简单,实际应用中有很多限制,比如在类别数很多的多分类任务中(如 ImageNet 数据集)。由此,在类别平衡采样的基础上,国内海康威视研究院提出了一种“类别重组采样”的平衡方法。

类别重组法是在《解析卷积神经网络》这本书中看到的,可惜没在网上找到原论文和代码,但这个方法感觉还是很有用的,且也比较好复现。

如下图所示,类别重组方法步骤如下:
在这里插入图片描述

  1. 对原始样本的每个类别的样本分别排序好,计算每个类别的样本数目,并记录样本数最多的那个类别的样本数量 max_num。
  2. 基于最大样本数 max_num 产生一个随机数列表,然后用此列表中的随机数对各自类别的样本数求余,得到对应索引值列表
    index_list。
  3. 根据该索引值列表 index_list,从该类的图像数据中提取图像,生成该类的图像随机列表。
  4. 最后吧所有类别的随机列表连接在一起后一起随机打乱次序,即可得到最终的图像列表,可以发现最终的这个图像随机列表中每个类别的样本数目是一致的(样本数较少的类别,图像会存在多次采样)。然后每轮(epoch)都对此列表进行遍历数据用于模型训练,如此重复。

以上方法整体还是比较复现的,结合具体任务来设计代码就行,这里给出一个简单的生成一段范围为 [1, 10] 的随机整数列示例代码。

import random
# 生成一段范围为[0, 9]的随机整数列表
# sample(L, n) 函数: 从序列L中随机抽取n个元素,并将n个元素以list形式返回。
# 也可用 random.shuffle(L) 函数原地打乱列表
random_list = random.sample(range(0, 10), 10)
print(random_list)

类别重组法对有点很明显,在设计好重组代码函数后,只需要原始图像列表即可,所有操作都在内存中在线完成,易于实现且更通用。其实仔细深究可以发现,海康提出的这个类别重组法和前面的数据采样方法是很类似的,其本质都是通过采样(sampler)策略让类别不均衡的各类数据在每轮训练中出现的次数是一致的。

三、算法(损失函数)层面处理方法

类别不平衡问题的本质是导致样本数目较少的类别出现“欠学习”这一机器学习现象,直观表现是较小样本的损失函数权重占比也较少。一个很自然的解决办法是增加小样本错分的惩罚代价,并将此代价直接体现在目标函数(损失函数)里,这就是“代价敏感”的方法。“代价敏感”方法的本质可以理解为调整模型在小类别上的注意力。

3.1 Focal Loss

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.2 损失函数加权

除了 Focal Loss 这种高明的损失函数策略外,针对图像分类问题,还有一种简单直接的损失函数加权方法,即在计算损失函数过程中,对每个类别的损失做加权处理,具体的 PyTorch 实现方式如下:

weights = torch.FloatTensor([1, 1, 8, 8, 4]) # 类别权重分别是 1:1:8:8:4
# pos_weight_weight(tensor): 1-D tensor,n 个元素,分别代表 n 类的权重,
# 为每个批次元素的损失指定的手动重新缩放权重,
# 如果你的训练样本很不均衡的话,是非常有用的。默认值为 None。
criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()

四、参考资料

  • 如何针对数据不平衡做处理: http://spytensor.com/index.php/archives/45/
  • 数据增强-imgaug: https://datawhalechina.github.io/thorough-pytorch/%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html
  • focal loss 论文: https://arxiv.org/pdf/1708.02002.pdf
  • 如何针对数据不平衡做处理: http://spytensor.com/index.php/archives/45/
  • 10 Techniques to deal with Imbalanced Classes in Machine Learning:
    https://www.analyticsvidhya.com/blog/2020/07/10-techniques-to-deal-with-class-imbalance-in-machine-learning/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/178476.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VMWare 移动Linux CentOS 7虚拟机后连不上网怎么办

研究hadoop的时候发现虚拟机太大了,于是把3台节点的虚拟机剪切粘贴到移动硬盘上,但是出现了上不了网的问题 VMWare 移动Linux CentOS 7虚拟机后连不上网,ifconfig命令只出现lo不出现有IP地址的ens33,jps命令也出现了jps command …

机器学习模型搭建与评估

模型搭建和评估第三章 模型搭建和评估--建模模型搭建任务一:切割训练集和测试集任务二:模型创建任务三:输出模型预测结果第三章 模型搭建和评估-评估模型评估任务一:交叉验证任务二:混淆矩阵任务三:ROC曲线…

python爬虫学习笔记-mysql数据库介绍下载安装

数据库概述 为什么要使用数据库? 那我们在没有学习数据库的时候,数据存放在json或者磁盘文件中不也挺好的嘛,为啥还要学习数据库? 文件中存储数据,无法基于文件直接对数据进行操作或者运算,必须借助python将…

IDEA搭建Finchley.SR2版本的SpringCloud父子基础项目-------Ribbon负载均衡

1.概念 Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具。简单的说,Ribbon是Netflix发布的开源项目,主要功能是提供客户端的软件负载均衡算法,将Netflix的中间层服务连接在一起。Ribbon客户端组件提供一系列完善的配…

Python闭包与闭包陷阱

1 什么是闭包 在 Python 中,闭包是一种特殊的函数,它能够记住它所在的环境(也称作上下文)。这意味着闭包能够访问定义它的作用域中的变量。闭包通常用于封装数据和提供对外部访问的接口。 在 Python 中使用闭包有以下几点好处&a…

数据库和SQL概述

数据库和SQL概述 数据库的好处 实现数据的持久化使用完整的管理系统统一管理,易于查询 常用的一些名称缩写 DB:数据库(Database):存储数据的“仓库”。它保存了一系列有组织的数据DBMS:数据库管理系统(Database Management Sy…

离线用户召回定时更新

3.6 离线用户召回定时更新 学习目标 目标 知道离线内容召回的概念知道如何进行内容召回计算存储规则应用 应用spark完成离线用户基于内容的协同过滤推荐 3.6.1 定时更新代码 完整代码 import os import sys # 如果当前代码文件运行测试需要加入修改路径,否则后面…

游戏启动器:LaunchBox Premium with Big Box v13.1

LaunchBox知道您会喜欢的功能,具有风格的游戏启动器,我们最初将 Launchbox 构建为 DOSBox 的一个有吸引力的前端,但它现在拥有对现代游戏和复古游戏模拟的支持。我们让您的所有游戏看起来都很漂亮。 整理您的游戏收藏 我们不仅漂亮&#xff…

基于微信小程序奶茶店在线点餐下单系统

奶茶在线下单系统用户端是基于微信小程序端,管理员端是基于web端,基于java编程语言,mysql数据库,idea工具开发,用户微信端可以注册登陆小程序,查看奶茶详情,搜索下单奶茶,在线奶茶评…

CSS @property(CSS 自定义属性)

CSS property(CSS 自定义属性)参考描述propertyHoudiniproperty兼容性描述符规则syntax扩展initial-valueinherits示例描述符的注意事项使用 JavaScript 来创建自定义属性CSS 变量与自定义属性重复赋值过渡简单的背景过渡动画更复杂的背景过渡动画错误示…

【ARM体系结构】之数据类型约定与工作模式

1、RISC和CISC的区别 1.1 RISC : 精简指令集 使用精简指令集的架构:ARM架构 RISC-V架构 PowerPC架构 MIPS架构ARM架构 :目前使用最广泛的架构,ARM面向的低端消费类市场RISC-V架构 :第五代,精简指令集的架构&#xff…

这样定义通用人工智能

🍿*★,*:.☆欢迎您/$:*.★* 🍿 正文 人类解决问题的途径,大体可以分为两种。一种是事实推理,另一种是事实验证。 为什么只是两种分类,因为根据和环境的交互与否。 事实推理解释为当遇到事件发生的时候,思考的过程。可以使用概率模型,或者更复杂的模型(目前没…

Out of Vocabulary处理方法

Out of Vocabulary 我们在NLP任务中一般都会有一个词表,这个词表一般可以使用一些大牛论文中的词表或者一些大公司的词表,或者是从自己的数据集中提取的词。但是无论当后续的训练还是预测,总有可能会出现并不包含在词表中的词,这…

(教程)如何在BERT模型中添加自己的词汇(pytorch版)

来源:投稿 作者:皮皮雷 编辑:学姐 参考文章: NLP | How to add a domain-specific vocabulary (new tokens) to a subword tokenizer already trained like BERT WordPiece | by Pierre Guillou | Medium https://medium.com/pi…

ROS2机器人编程简述humble-第三章-BUMP AND GO IN C++ .3

简述本章项目,参考如下:ROS2机器人编程简述humble-第三章-PERCEPTION AND ACTUATION MODELS .1流程图绘制,参考如下:ROS2机器人编程简述humble-第三章-COMPUTATION GRAPH .2然后,在3.3和3.4分别用C和Python编程实现&am…

Bus Hound 工具抓取串口数据(PC端抓取USB转串口数据)

测试环境: PC端 USB转串口 链接终端板卡串口 目标:抓取通信过程中的通信数据 工具介绍:Bus Hound是是由美国perisoft公司研制的一款超级软件总线协议分析器,它是一种专用于PC机各种总线数据包监视和控制的开发工具软件&#xff0c…

通信原理简明教程 | 数字调制传输

文章目录1 二进制数字调制和解调1.1 二进制数字调制的基本原理1.2 二进制数字调制信号的特性1.3 解调方法2 二进制差分相移键控2.1 2PSK的倒π现象2.2 2DPSK调制和解调3 二进制调制系统的抗噪性能3.1 2ASK系统的抗噪声性能3.2 2FSK系统的抗噪声性能4 二进制数字调制系统性能比较…

服务器配置定时脚本 crontab + Python;centos6或centos 7或centos8 实现定时执行 python 脚本

一、crontab的安装 默认情况下,CentOS 7中已经安装有crontab,如果没有安装,可以通过yum进行安装。 yum install crontabs 二、crontab的定时语法说明 corntab中,一行代码就是一个定时任务,其语法结构可以通过这个图来理解。 字符含义如下: * 代表取值范围内的数字 /…

Linux内核驱动初探(三) 以太网卡

目录 0. 前言 1. menuconfig 2. 设备树 0. 前言 这次的网卡驱动就比较顺利,基本就是参考 4.19.x 内核以及 imx6qdl-sabrelite.dtsi、imx6qdl-sabreauto.dtsi 中的设备树,来设置以太网各项参数。 1. menuconfig 其实笔者接手的时候,网口这…

本质安全设备标准(IEC60079-11)的理解(三)

本质安全设备标准(IEC60079-11)的理解(三) 对于标准中“fault”的理解 第一,标准中对fault的定义是这样的: 3.7.2 fault any defect of any component, separation, insulation or connection between c…