【一文讲通】样本不均衡问题解决--下

news2024/9/29 17:29:05

1欠采样、过采样

  • 欠采样:减少多数类的数量(如随机欠采样、NearMiss、ENN)。

  • 过采样:尽量多地增加少数类的的样本数量(如随机过采样、以及2.1.2数据增强方法),以达到类别间数目均衡。

  • 还可结合两者做混合采样(如Smote+ENN)。

具体还可以参见【scikit-learn的imbalanced-learn.org/stable/user_guide.html以及github的awesome-imbalanced-learning】

2数据增强

数据增强(Data Augmentation)是在不实质性的增加数据的情况下,从原始数据加工出更多数据的表示,提高原数据的数量及质量,以接近于更多数据量产生的价值,从而提高模型的学习效果(其实也是过采样的方法的一种。

其原理是,通过对原始数据融入先验知识,加工出更多数据的表示,有助于模型判别数据中统计噪声,加强本体特征的学习,减少模型过拟合,提升泛化能力。

如经典的机器学习例子--哈士奇误分类为狼:通过可解释性方法,可发现错误分类是由于图像上的雪造成的。通常狗对比狼的图像里面雪地背景比较少,分类器学会使用雪作为一个特征来将图像分类为狼还是狗,而忽略了动物本体的特征。此时,可以通过数据增强的方法,增加变换后的数据(如背景换色、加入噪声等方式)来训练模型,帮助模型学习到本体的特征,提高泛化能力。
需要关注的是,数据增强样本也有可能是引入片面噪声,导致过拟合。此时需要考虑的是调整数据增强方法,或者通过算法(可借鉴Pu-Learning思路)选择增强数据的最佳子集,以提高模型的泛化能力。

常用数据增强方法可分为:基于样本变换的数据增强及基于深度学习的数据增强。

数据样本层面解决不均衡的方法,需要关注的是:

  • 随机欠采样可能会导致丢弃含有重要信息的样本。在计算性能足够下,可以考虑数据的分布信息(通常是基于距离的邻域关系)的采样方法,如ENN、NearMiss等。

  • 随机过采样或数据增强样本也有可能是强调(或引入)片面噪声,导致过拟合。也可能是引入信息量不大的样本。此时需要考虑的是调整采样方法,或者通过半监督算法(可借鉴Pu-Learning思路)选择增强数据的较优子集,以提高模型的泛化能力。

2.1 基于样本变换的数据增强

样本变换数据增强即采用预设的数据变换规则进行已有数据的扩增,包含单样本数据增强多样本数据增强。

单样本数据增强

单(图像)样本增强主要有几何操作、颜色变换、随机擦除、添加噪声等方法,可参见imgaug开源库。

多样本数据增强方法

多样本增强是通过先验知识组合及转换多个样本,主要有Smote、SamplePairing、Mixup等方法在特征空间内构造已知样本的邻域值。

  • Smote

Smote(Synthetic Minority Over-sampling Technique)方法较常用于样本均衡学习,核心思想是从训练集随机同类的两近邻样本合成一个新的样本,其方法可以分为三步:

  1. 对于各样本X_i,计算与同类样本的欧式距离,确定其同类的K个(如图3个)近邻样本;

  1. 从该样本k近邻中随机选择一个样本如近邻X_ik,生成新的样本:

Xsmote_ik =  Xi  +  rand(0,1) ∗ ∣X_i − X_ik∣
  1. 重复2步骤迭代N次,可以合成N个新的样本。

python代码:

# SMOTE
from imblearn.over_sampling import SMOTE

print("Before OverSampling, counts of label\n{}".format(y_train.value_counts()))
smote = SMOTE()
x_train_res, y_train_res = smote.fit_resample(x_train, y_train)
print("After OverSampling, counts of label\n{}".format(y_train_res.value_counts()))
  • SamplePairing

SamplePairing算法的核心思想是从训练集随机抽取的两幅图像叠加合成一个新的样本(像素取平均值),使用第一幅图像的label作为合成图像的正确label。

  • Mixup

Mixup算法的核心思想是按一定的比例随机混合两个训练样本及其标签,这种混合方式不仅能够增加样本的多样性,且能够使决策边界更加平滑,增强了难例样本的识别,模型的鲁棒性得到提升。其方法可以分为两步:

  1. 从原始训练数据中随机选取的两个样本(xi, yi) and (xj, yj)。其中y(原始label)用one-hot 编码。

  1. 对两个样本按比例组合,形成新的样本和带权重的标签

x˜ = λxi + (1 − λ)xj
y˜ = λyi + (1 − λ)yj

最终的loss为各标签上分别计算cross-entropy loss,加权求和。其中 λ ∈ [0, 1], λ是mixup的超参数,控制两个样本插值的强度。

python代码:

# Mixup
def mixup_batch(x, y, step, batch_size, alpha=0.2):
    """
    get batch data
    :param x: training data
    :param y: one-hot label
    :param step: step
    :param batch_size: batch size
    :param alpha: hyper-parameter α, default as 0.2
    :return:  x y 
    """
    candidates_data, candidates_label = x, y
    offset = (step * batch_size) % (candidates_data.shape[0] - batch_size)
 
    # get batch data
    train_features_batch = candidates_data[offset:(offset + batch_size)]
    train_labels_batch = candidates_label[offset:(offset + batch_size)]

    if alpha == 0:
        return train_features_batch, train_labels_batch

    if alpha > 0:
        weight = np.random.beta(alpha, alpha, batch_size)
        x_weight = weight.reshape(batch_size, 1)
        y_weight = weight.reshape(batch_size, 1)
        index = np.random.permutation(batch_size)
        x1, x2 = train_features_batch, train_features_batch[index]
        x = x1 * x_weight + x2 * (1 - x_weight)
        y1, y2 = train_labels_batch, train_labels_batch[index]
        y = y1 * y_weight + y2 * (1 - y_weight)
        return x, y

2.2基于深度学习的数据增强

特征空间的数据增强

不同于传统在输入空间变换的数据增强方法,神经网络可将输入样本映射为网络层的低维向量(表征学习),从而直接在学习的特征空间进行组合变换等进行数据增强,如MoEx方法等。

基于生成模型的数据增强

生成模型如变分自编码网络(Variational Auto-Encoding network, VAE)和生成对抗网络(Generative Adversarial Network, GAN),其生成样本的方法也可以用于数据增强。这种基于网络合成的方法相比于传统的数据增强技术虽然过程更加复杂, 但是生成的样本更加多样。

  • 变分自编码器VAE

变分自编码器(Variational Autoencoder,VAE)其基本思路是:将真实样本通过编码器网络变换成一个理想的数据分布,然后把数据分布再传递给解码器网络,构造出生成样本,模型训练学习的过程是使生成样本与真实样本足够接近。

python代码

# VAE模型
class VAE(keras.Model):
    ...
    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
  • 生成对抗网络GAN

生成对抗网络-GAN(Generative Adversarial Network) 由生成网络(Generator, G)和判别网络(Discriminator, D)两部分组成, 生成网络构成一个映射函数G: ZX(输入噪声z, 输出生成的图像数据x), 判别网络判别输入是来自真实数据还是生成网络生成的数据。

python代码:

# DCGAN模型

class GAN(keras.Model):
    ...
    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # G: Z→X(输入噪声z, 输出生成的图像数据x)
        generated_images = self.generator(random_latent_vectors)
        # 合并生成及真实的样本并赋判定的标签
        combined_images = tf.concat([generated_images, real_images], axis=0)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # 标签加入随机噪声
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        # 训练判定网络
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )
        
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # 赋生成网络样本的标签(都赋为真实样本)
        misleading_labels = tf.zeros((batch_size, 1))
        # 训练生成网络
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        # 更新损失
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

2.3 基于神经风格迁移的数据增强

神经风格迁移(Neural Style Transfer)可以在保留原始内容的同时,将一个图像的样式转移到另一个图像上。除了实现类似色彩空间照明转换,还可以生成不同的纹理和艺术风格。

神经风格迁移是通过优化三类的损失来实现的:

style_loss:使生成的图像接近样式参考图像的局部纹理;

content_loss:使生成的图像的内容表示接近于基本图像的表示;

total_variation_loss:是一个正则化损失,它使生成的图像保持局部一致。

python代码:

# 样式损失
def style_loss(style, combination):
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))

# 内容损失
def content_loss(base, combination):
    return tf.reduce_sum(tf.square(combination - base))

# 正则损失
def total_variation_loss(x):
    a = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :]
    )
    b = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :]
    )
    return tf.reduce_sum(tf.pow(a + b, 1.25))

2.4 基于元学习的数据增强

深度学习研究中的元学习(Meta learning)通常是指使用神经网络优化神经网络,元学习的数据增强有神经增强(Neural augmentation)等方法。

  • 神经增强

神经增强(Neural augmentation)是通过神经网络组的学习以获得较优的数据增强并改善分类效果的一种方法。其方法步骤如下:

  1. 获取与target图像同一类别的一对随机图像,前置的增强网络通过CNN将它们映射为合成图像,合成图像与target图像对比计算损失;

  1. 将合成图像与target图像神经风格转换后输入到分类网络中,并输出该图像分类损失;

  1. 将增强与分类的loss加权平均后,反向传播以更新分类网络及增强网络权重。使得其输出图像的同类内差距减小且分类准确。

3 代价敏感学习cost-sensitive

最常用的是class weight

scikit模型的’class weight‘方法,If ‘balanced’, class weights will be given by n_samples / (n_classes * np.bincount(y)). If a dictionary is given, keys are classes and values are corresponding class weights. If None is given, the class weights will be uniform.

class weight可以为不同类别的样本提供不同的权重(少数类有更高的权重),从而模型可以平衡各类别的学习。

如下图通过为少数类做更高的权重,以避免决策偏重多数类的现象(类别权重除了设定为balanced,还可以作为一个超参搜索。示例代码请见github.com/aialgorithm):

clf2 = LogisticRegression(class_weight={0:1,1:10})  # 代价敏感学习

4 OHEM 和 Focal Loss

In this work, we first point out that the class imbalance can be summarized to the imbalance in difficulty and the imbalance in difficulty can be summarized to the imbalance in gradient norm distribution.
--原文可见《Gradient Harmonized Single-stage Detector》

大意是,类别的不平衡可以归结为难易样本的不平衡,而难易样本的不平衡可以归结为梯度的不平衡。按照这个思路,OHEM和Focal loss都做了两件事:难样本挖掘以及类别的平衡。

此外还有GHM、 PISA等方法,后续会更新

  • OHEM(Online Hard Example Mining)算法的核心是选择一些hard examples(多样性和高损失的样本)作为训练的样本,针对性地改善模型学习效果。对于数据的类别不平衡问题,OHEM的针对性更强。

  • Focal loss的核心思想是在交叉熵损失函数(CE)的基础上增加了类别的不同权重以及困难(高损失)样本的权重(如下公式),以改善模型学习效果。

5 采样+集成学习

通过重复组合少数类样本与抽样的同样数量的多数类样本,训练若干的分类器进行集成学习。

  • BalanceCascade BalanceCascade基于Adaboost作为基分类器,核心思路是在每一轮训练时都使用多数类与少数类数量上相等的训练集,然后使用该分类器对全体多数类进行预测,通过控制分类阈值来控制FP(False Positive)率,将所有判断正确的类删除,然后进入下一轮迭代继续降低多数类数量。

  • EasyEnsemble EasyEnsemble也是基于Adaboost作为基分类器,就是将多数类样本集随机分成 N 个子集,且每一个子集样本与少数类样本相同,然后分别将各个多数类样本子集与少数类样本进行组合,使用AdaBoost基分类模型进行训练,最后bagging集成各基分类器,得到最终模型。示例代码可见:www.kaggle.com/orange90/ensemble-test-credit-score-model-example

通常,在数据集噪声较小的情况下,可以用BalanceCascade,可以用较少的基分类器数量得到较好的表现(基于串行的集成学习方法,对噪声敏感容易过拟合)。噪声大的情况下,可以用EasyEnsemble,基于串行+并行的集成学习方法,bagging多个Adaboost过程可以抵消一些噪声影响。此外还有RUSB、SmoteBoost、balanced RF等其他集成方法可以自行了解。

6 异常检测

相关介绍:

【异常检测】14种异常检测算法_allein_STR的博客-CSDN博客_异常行为监测算法有哪些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

地址解析协议ARP

目录地址解析协议ARP1、流程2、动态与静态的区别3、ARP协议适用范围地址解析协议ARP 如何从IP地址找出其对应的MAC地址? 1、流程 ARP高速缓存表 当主机B要给主机C发送数据包时,会首先在自己的ARP高速缓存表中查找主机C的IP地址所对应的MAC地址&#xf…

Linux常用命令——lsblk命令

在线Linux命令查询工具 lsblk 列出块设备信息 补充说明 lsblk命令用于列出所有可用块设备的信息,而且还能显示他们之间的依赖关系,但是它不会列出RAM盘的信息。块设备有硬盘,闪存盘,cd-ROM等等。lsblk命令包含在util-linux-ng…

ES报文辅助生成工具-JavaFX

此程序为基于 Java8 开发的 JavaFX Maven 工程&#xff0c;是 Java 组装ElasticSearch请求报文工具的辅助 Java 代码生成工具&#xff0c;方便开发者快速编写代码。现学现用&#xff0c;写得不好。 工具界面 代码 pom.xml <project xmlns"http://maven.apache.org/P…

Android:URLEncoder空格被转码为“+”号

Android前段和后端接口交互时&#xff0c;经常会遇到特殊字符&#xff0c;比如表情、特殊标点等&#xff0c;这样在Url中是无法识别的&#xff0c;需要进行转码&#xff0c;后端进行解码交互。 但当使用URLEncoder时&#xff0c;会发现字符串中的空格被转换成“”号&#xff0…

客服系统即时通讯IM开发(四)网站实现实时在线访客列表【唯一客服】网站在线客服系统...

在使用我的客服系统时&#xff0c;如果引入了我的js &#xff0c;就可以实时看到网站上的所有访客了 使用 WebSocket 技术来实现实时通信。 在访客登录或退出时&#xff0c;向指定客服的 WebSocket 客户端发送消息。例如&#xff0c;你可以在访客登录时&#xff0c;向指定客服…

测试用例的设计? 万能公式

万能公式(必背)&#xff1a;功能测试性能测试界面测试兼容性测试易用性测试安全测试功能测试 &#xff1a;可能来自于需求文档&#xff0c;也可能来自生活经验性能测试 &#xff1a;功能没有问题不代表性能是ok的&#xff0c;性能往往体现在一些极端情况界面测试 &#xff1a;颜…

Prometheus-基于Consul的自动注册

一、背景介绍 如果我们的物理机有很多&#xff0c;不管是基于"file_sd_config"还是"kubernetes_sd_config"&#xff0c;我们都需要手动写入目标target及创建目标service&#xff0c;这样才能被prometheus自动发现&#xff0c;为了避免重复性工作过多&#…

【182】Java8利用二叉查找树实现Map

本文利用二叉查找树写了一个Map&#xff0c;用来保存键值对。 二叉查找树的定义 二叉查找树又名二叉搜索树&#xff0c;英文名称是 Binary Search Tree&#xff0c;缩写BST。 二叉排序树&#xff0c;英文名称是 Binary Sorted Tree&#xff0c;缩写BST。 二叉查找树、二叉搜…

excel实用技巧:如何构建多级下拉菜单

使用数据有效性制作下拉菜单对大多数小伙伴来说都不陌生&#xff0c;但说到二级和三级下拉菜单大家可能就不是那么熟悉了。什么是二级和三级下拉菜单呢&#xff1f;举个例子&#xff0c;在一个单元格选择某个省后&#xff0c;第二个单元格选项只能出现该省份所属的市&#xff0…

vue-router原理简单实现

vue-router简单实现 初步预习 动态路由 获取id方式 第一种强依赖路由 第二种父传子方式&#xff08;推荐&#xff09; 嵌套路由 相同的头和尾&#xff0c;默认index&#xff0c;替换为detail 编程时导航 this.$router.push() this.$router.repleace() this.$router.g…

吊炸天,springboot的多环境配置一下搞明白了!

1、 使用springboot的profile命名规则profile用于多环境的激活和配置&#xff0c;用来切换生产&#xff0c;测试&#xff0c;本地等多套不通环境的配置。如果每次去更改配置就非常麻烦&#xff0c;profile就是用来切换多环境配置的。在Spring Boot框架中&#xff0c;使用Profil…

漏洞优先级排序的六大关键因素

当我们谈及开源漏洞时&#xff0c;我们会发现其数量永远处于增长状态。根据安全公司 Mend 研究发现&#xff0c;在 2022 年前九个月发现并添加到其漏洞数据库中的开源漏洞数量比 2021 年增加了 33%。该报告从 2022 年 1 月到 2022 年 9 月对大约 1,000 家北美公司进行了代表性抽…

一篇文章解决C语言操作符

我的主页&#xff1a;一只认真写代码的程序猿本文章是关于C语言操作符的讲解收录于专栏【C语言的学习】 目录 1、算术操作符 2、赋值操作符 3、关系操作符 4、条件操作符&#xff08;三目&#xff09; 5、逻辑操作符 6、单目操作符 7、移位操作符 8、位操作符 9、逗号…

使用Docker+Nignx部署vue项目

文章目录一、前言二、vue项目打包三、nginx基本介绍①nginx常用的功能&#xff1a;②nginx默认的主题配置文件解读③nginx目录解读三、docker内部署nginx①拉取nginx镜像②创建数据持久化目录☆☆☆③创建需要映射进去的文件④运行nginx四、大工告成最近&#xff08;之前&#…

2023年DAMA-CDGA/CDGP数据治理工程师认证(线上班)报名

DAMA认证为数据管理专业人士提供职业目标晋升规划&#xff0c;彰显了职业发展里程碑及发展阶梯定义&#xff0c;帮助数据管理从业人士获得企业数字化转型战略下的必备职业能力&#xff0c;促进开展工作实践应用及实际问题解决&#xff0c;形成企业所需的新数字经济下的核心职业…

gcc、g++,linux升级gcc、g++

安装cv-cuda库&#xff0c;要求gcc11&#xff0c;cmake>3.22版本。 Linux distro:Ubuntu x86_64 > 18.04WSL2 with Ubuntu > 20.04 (tested with 20.04) CUDA Driver > 11.7 (Not tested on 12.0) GCC > 11.0 Python > 3.7 cmake > 3.22gcc、g介绍 参考&…

手把手安装GNN必备库 —— pytorch_geometric

0 BackGround GNN&#xff1a;图神经网络&#xff0c;由于传统的CNN网络无法表示顶点和边这种关系型数据&#xff0c;便出现了图神经网络解决这种图数据的表示问题&#xff0c;这属于CNN往图方向的应用扩展。 GCN&#xff1a;图卷积神经网络&#xff0c;GNN在训练过程中&#…

【ONE·R || 两次作业(二):GEO数据处理下载分析】

总言 两次作业汇报其二&#xff1a;GEO数据处理学习汇报。    文章目录总言2、作业二&#xff1a;GEO数据处理下载分析2.1、GEO数据库下载前准备2.2、GEO数据库下载及数据初步处理2.2.1、分阶段解析演示2.2.1.1、编号下载流程2.2.1.2、对gset[ 1 ]初步分析2.2.1.3、对gset[ 2…

基于requests框架实现接口自动化测试项目实战

requests库是一个常用的用于http请求的模块&#xff0c;它使用python语言编写&#xff0c;在当下python系列的接口自动化中应用广泛&#xff0c;本文将带领大家深入学习这个库&#xff0c;Python环境的安装就不在这里赘述了&#xff0c;我们直接开干。 01 requests的安装 win…

销售结束语话术

销售要记住&#xff0c;结束语不代表结束&#xff0c;而是下一次沟通的开始&#xff0c;所以销售要学会通过结束语来为自己争取下次沟通的机会。 前言 不论是哪一行业&#xff0c;对于销售而言&#xff0c;大多数成交的客户都是经过持续有效的跟踪的&#xff0c;还会出现有很多…