Early Stopping中基于测试集(而非验证集)上的表现选取模型的讨论

news2025/1/9 2:02:25

论文中一般都是用在验证集上效果最好的模型去预测测试集,多次预测的结果取平均计算准确率或者mAP值而不是单纯的取一次验证集最好的结果作为论文的结果。如果你在写论文的过程中,把测试集当做验证集去验证的话,这其实是作假的,建议不要这样,一旦有人举报或者复现出来你的结果和你论文中的结果相差很大的话,是会受到很大处分的。

我之前曾遇到过这种情况,我在图像分类的过程中曾经用过CutMix增强方式,CutMix其实就是将两张图片放在一起,如下图所示,这种结果会造成验证集上准确率很大的波动,可能一会儿变成99%,一会儿变成88%,那我总不能拿99%作为我论文中的结果啊,所以还是要以最终的测试集的准确率为主,因为这个才是我们需要关注的。

如果只是单纯的取提高准确率的话可以看看文中下面的一些方式,这些方式的提升一定会比单纯取最好的模型的效果要好的。

首先我们需要理解一下概念,什么是训练集?什么是验证集?什么是测试集?大家很容易将“验证集”与“测试集”,“交叉验证”混淆。

首先我们来了解一下基本的概念哈,然后在分析如何解决分类问题,提高模型的准确率和泛化能力。

训练集、验证集、测试集:

训练集(train set) —— 用于模型拟合的数据样本。
验证集(development set)—— 是模型训练过程中单独留出的样本集,它可以用于调整模型的超参数(包括EarlyStopping的epoch)和用于对模型的能力进行初步评估。
测试集 —— 用来评估模最终模型的泛化能力,也就是计算论文里最后各个评价指标的值。但不能作为调参、选择特征等算法相关的选择的依据。

一个形象的比喻:
训练集:学生的课本;学生 根据课本里的内容来掌握知识。
验证集:作业,通过作业可以知道 不同学生学习情况、进步的速度快慢。
测试集:考试,考的题是平常都没有见过,考察学生举一反三的能力。

为什么验证数据集测试数据集两者都需要?

因为验证数据集(Validation Set)用来调整模型参数从而选择最优模型,模型本身已经同时知道了输入和输出,所以从验证数据集上得出的误差(Error)会有偏差(Bias)。
但是我们只用测试数据集(Test Set) 去评估模型的表现,并不会去调整优化模型
传统上,一般三者切分的比例是:6:2:2,验证集并不是必须的,即验证集可有可无(可以选择不用验证集调整超参数)
所以,很多人口中的测试集应该是验证集的意思,测试集不能作为调参、选择特征等算法相关的选择的依据。

K-折交叉验证(K-fold Cross Validation,记为K-CV)

就按照作者说的10折交叉来说,算法步骤是(图如1):

  1. 将数据集分成十份,轮流将其中9份作为训练数据,1份作为测试数据,进行试验。每次试验都会得出相应的正确率。
  2. 10次的结果的正确率的平均值作为对算法精度的估计,一般还需要进行多次10折交叉验证(例如10次10折交叉验证),再求其均值,作为对算法准确性的估计。
    在这里插入图片描述在数据缺乏的情况下使用,如果设原始数据有N个样本,那么LOO-CV就是N-CV,即每个样本单独作为验证集,其余的N-1个样本作为训练集,故LOO-CV会得到N个模型,用这N个模型最终的验证集的分类准确率的平均数,作为此下LOO-CV分类器的性能指标。
    优点:(1)每一回合中几乎所有的样本皆用于训练模型,因此最接近原始样本的分布,这样评估所得的结果比较可靠。(2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。
    缺点:计算成本高,需要建立的模型数量与原始数据样本数量相同。当数据集较大时几乎不能使用。

关于EarlyStopping保存最优的模型

Pytroch保存最优的训练模型

		min_loss = 100000#随便设置一个比较大的数
    for epoch in range(epochs):
	    train()
	    val_loss = val()
	    if val_loss < min_loss:
	       min_loss = val_loss
	       print("save model")
	       torch.save(net.state_dict(),'model.pth')

图像分类

当我们拿到数据集的第一步就是对数据进行分析

  • 各个种类的样本是否平衡,如果数据不平衡的话对模型准确性的影响是非常大的。举个栗子:如果有100个样本,其中99个是正样本,1个是负样本。那么就算我们把所有的样本都预测成正样本,准确率也有99%,显然这是不对的,因为这个模型只能识别正样本了,泛化能力非常差。这个时候,可以考虑改动损失函数,增加难样本(负样本)的权重,减少简单样本(正样本)的权重。
  • 数据集中是否含有大量噪声。
  • 数据量是否充足等等.

接下来我们讲解一下图像分类中的一些trick:
1.数据清洗
借助弱监督方式引入外部数据集中的高质量数据——解决了自行扩展数据集带来的测试偏移。步骤如下:

  • 使用训练数据建立模型
  • 预测爬取的数据的标签,对外部数据进行伪标签标注。
  • 结合样本分布和混淆矩阵的结果,设置了多级阈值,选择可信度高的数据,组合成新的数据集
  • 重复1,2,3。

2.数据增强
①几何变换——只进行水平翻转和平移0.05

transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
transforms.RandomHorizontalFlip()

②CutMix

3.数据不均衡
在数据层面和算法层面同时测试选取—— 上采样和class_wight相结合
①上采样——通过混淆矩阵和验证集的随机化设置,提取模型预测错误的数据,然后按照一定的权重进行数据复制扩充,为了减少上采样可能带来的过拟合问题,我们对扩充的数据进行了区域裁剪,使数据更倾向于需要关注的部分。
②class_weight——将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。

4.数据标签平滑
目的:减小过拟合问题
平滑过后的样本交叉熵损失就不仅考虑到了训练样本中正确的标签位置的损失,也稍微考虑到其他错误标签位置的损失,导致最后的损失增大,导致模型的学习能力提高,即要下降到原来的损失,就得学习的更好,也就是迫使模型往增大正确分类概率并且同时减小错误分类概率的方向前进。

#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn


class LabelSmoothSoftmaxCE(nn.Module):
    def __init__(self,
                 lb_pos=0.9,
                 lb_neg=0.005,
                 reduction='mean',
                 lb_ignore=255,
                 ):
        super(LabelSmoothSoftmaxCE, self).__init__()
        self.lb_pos = lb_pos
        self.lb_neg = lb_neg
        self.reduction = reduction
        self.lb_ignore = lb_ignore
        self.log_softmax = nn.LogSoftmax(1)

    def forward(self, logits, label):
        logs = self.log_softmax(logits)
        ignore = label.data.cpu() == self.lb_ignore
        n_valid = (ignore == 0).sum()
        label = label.clone()
        label[ignore] = 0
        lb_one_hot = logits.data.clone().zero_().scatter_(1, label.unsqueeze(1), 1)
        label = self.lb_pos * lb_one_hot + self.lb_neg * (1-lb_one_hot)
        ignore = ignore.nonzero()
        _, M = ignore.size()
        a, *b = ignore.chunk(M, dim=1)
        label[[a, torch.arange(label.size(1)), *b]] = 0

        if self.reduction == 'mean':
            loss = -torch.sum(torch.sum(logs*label, dim=1)) / n_valid
        elif self.reduction == 'none':
            loss = -torch.sum(logs*label, dim=1)
        return loss


if __name__ == '__main__':
    torch.manual_seed(15)
    criteria = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=5e-3)
    net1 = nn.Sequential(
        nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1),
    )
    net1.cuda()
    net1.train()
    net2 = nn.Sequential(
        nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1),
    )
    net2.cuda()
    net2.train()

    with torch.no_grad():
        inten = torch.randn(2, 3, 5, 5).cuda()
        lbs = torch.randint(0, 3, [2, 5, 5]).cuda()
        lbs[1, 3, 4] = 255
        lbs[1, 2, 3] = 255
        print(lbs)

    import torch.nn.functional as F
    logits1 = net1(inten)
    logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
    logits2 = net2(inten)
    logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')

    #  loss1 = criteria1(logits1, lbs)
    loss = criteria(logits1, lbs)
    #  print(loss.detach().cpu())
    loss.backward()

5.双损失函数
categorical_crossentropy 和 Label Smoothing Regularization :在对原始标签进行平滑的过程中,可能存在某些数据对标签变化特别敏感,导致损失函数的异常增大,使模型变得不稳定,为了增加模型的稳定性所以使用双损失函数——categorical_crossentropy 和 Label Smoothing Regularization,即保证了模型的泛化能力,又保证了数据不会对标签过于敏感,增加了模型的稳定性。

criterion = L.JointLoss(first=nn.crossentropyloss(), second=LabelSmoothSoftmaxCE(),
                              first_weight=0.5, second_weight=0.5)

6.优化器
RAdam LookAhead:兼具Adam和SGD两者的优化器RAdam,收敛速度快,鲁棒性好LookAhead对SGD进行改进,在各种深度学习任务上实现了更快的收敛 。
将RAdam和LookAhead结合在了一起,形成名为Ranger的新优化器。在ImageNet上进行了测试,在128px,20epoch的测试中,Ranger的训练精度达到了93%,比目前FastAI排行榜榜首提高了1%。
在这里插入图片描述

7.选择的模型

ResNeXt101系列,EfficientNet系列。
resnext101_32x16d_wsl
resnext101_32x8d_wsl
efficientnet-b4
efficientnet-b5

8.注意力机制

在模型中加入SE&CBAM注意力机制——提升网络模型的特征提取能力。
在这里插入图片描述

  • channel attention的过程。比SE多了一个 global max pooling(池化本身是提取高层次特征,不同的池化意味着提取的高层次特征更加丰富)。其第2个池化之后的处理过程和SE一样,都是先降维再升维,不同的是将2个池化后相加再sigmod和原 feature map。
    在这里插入图片描述
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

② spatial attention 的过程。将做完 channel attention 的feature map 作为输入,之后作2个大小为列通道的维度池化,每一次池化得到的 feature map 大小就为 h * w * 1 ,再将两次池化的 feature map 作基于通道的连接变成了大小为 h * w * 2 的 feature map ,再对这个feature map 进行核大小为 7*7 ,卷积核个数为1的卷积操作(通道压缩)再sigmod,最后就是熟悉的矩阵全乘。
在这里插入图片描述

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

9.SVM替换softmax层
抽取出模型的最后一层,将其接入SVM,用训练数据动态训练SVM分类器,再使用训练好的SVM分类器进行预测。
深度学习模型有支持向量机无法比拟的非线性函数逼近能力,能够很好地提取并表达数据的特征,深度学习模型的本质是特征学习器。然而,深度模型往往在独立处理分类、回归等问题上难以取得理想的效果。对于 SVM 来说,可以利用核函数的思想将非线性样本映射到高维空间,使其线性可分,再通过使数据集之间的分类间隔最大化来寻找最优分割超平面,在分类问题上表现出许多特有优势。但实质上,SVM 只含有一个隐层,数据表征能力并不理想。因此将深度学习方法与 SVM 相结合,构造用于分类的深层模型。利用深度学习的无监督方式分层提取样本高级特征,然后将这些高级特征输入 SVM 模型进行分类,从而达到最优分类精度。
在这里插入图片描述
10.模型融合
多模型融合的策略,Stacking,VotingClassifier—— 提升分类准确率Stacking方法: Stacking 先从初始数据集训练出初级学习器,然后”生成”一个新数据集用于训练次级学习器。在这个新数据集中,初级学习器的输出被当作样例输入特征,而初始样本的标记仍被当作样例标记。stacking使用交叉验证的方式,初始训练集 D 被随机划分为 k 个大小相似的集合 D1 , D2 , … , Dk,每次用k-1 个部分训练 T 个模型,对另个一个部分产生 T 个预测值作为特征,遍历每一折后,也就得到了新的特征集合,标记还是源数据的标记,用新的特征集合训练一个集合模型。

11.TTA:测试时数据增强(保证数据增强的几何变换和tta一致)
可将准确率提高若干个百分点,它就是测试时增强(test time augmentation, TTA)。这里会为原始图像造出多个不同版本,包括不同区域裁剪和更改缩放程度等,并将它们输入到模型中;然后对多个版本进行计算得到平均输出,作为图像的最终输出分数。

tta_model = tta.TTAWrapper(model,tta.fliplr_image2label)

作者:初识CV
链接:https://www.zhihu.com/question/408153243/answer/1490901503
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/333217.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

shiro、springboot、vue、elementUI CDN模式前后端分离的权限管理demo 附源码

shiro、springboot、vue、elementUI CDN模式前后端分离的权限管理demo 附源码 源码下载地址 https://github.com/Aizhuxueliang/springboot_shiro.git 前提你电脑的安装好这些工具&#xff1a;jdk8、idea、maven、git、mysql&#xff1b; shiro的主要概念 Shiro是一个强大…

麻省理工出版 | 2023年最新深度学习综述手册

UCL Simon Prince的新书&#xff1a;《Understanding Deep Learning》 &#xff0c;在2023年2月6日由MIT Press出版。他之前写过很受欢迎的《Computer Vision: Models, Learning, and Inference》。 关于这本最新的深度学习手册&#xff0c;作者这样介绍它&#xff1a; 正如书…

数据结构与算法基础-学习-09-线性表之栈的理解、初始化顺序栈、判断顺序栈空、获取顺序栈长度的实现

一、个人理解栈是线性表的一种衍生&#xff0c;和之前的顺序表和链表在插入和删除元素上有较大差异&#xff0c;其他基本相同&#xff0c;栈是数据只能插入表的尾部&#xff08;栈顶&#xff09;&#xff0c;删除数据时只能删除表的尾部&#xff08;栈顶&#xff09;数据&#…

2022Q4手机银行新版本聚焦提升客群专属、财富开放平台、智能化能力,活跃用户规模6.91亿人

易观&#xff1a;2022年第4季度&#xff0c;手机银行APP迭代升级加快&#xff0c;手机银行作为零售银行服务及经营的主阵地&#xff0c;与零售银行业务发展的联系日益紧密。迭代升级一方面可以顺应零售银行发展战略及方向&#xff0c;对手机银行业务布局进行针对性调整优化&…

前端面经详解

目录 css 盒子充满屏幕 A.给div设置定位 B.设置html,body的宽高 C.相对当前屏幕高度&#xff08;强烈推荐&#xff09; 三列布局&#xff1a;左右固定&#xff0c;中间自适应 flex布局&#xff08;强烈推荐&#xff09; grid布局 magin负值法 自身浮动 绝对定位 圣…

基于PHP的招聘网站

摘要在Internet高速发展的今天&#xff0c;我们生活的各个领域都涉及到计算机的应用&#xff0c;其中包括在线招聘的网络应用&#xff0c;在外国在线招聘已经是很普遍的方式&#xff0c;不过国内的在线招聘可能还处于起步阶段。招聘网站具有招聘信息功能的双向选择&#xff0c;…

UMI入门(创建react项目)

1、umI的环境要求确保 node 版本是 10.13 或以上React 16.8.0 及以上的 React2、什么时候不用 umi需要支持 IE 8 或更低版本的浏览器需要支持 React 16.8.0 以下的 React需要跑在 Node 10 以下的环境中有很强的 webpack 自定义需求和主观意愿需要选择不同的路由方案3、与其他框…

领域搜索算法之经典The Lin-Kernighan algorithm

领域搜索算法之经典The Lin-Kernighan algorithmThe Lin-Kernighan algorithm关于算法性能提升的约束参考文献领域搜索算法是TSP问题中的三大经典搜索算法之一&#xff0c;另外两种分别是回路构造算法和组合算法。 而这篇文章要介绍的The Lin-Kernighan algorithm属于领域搜索算…

精华文稿|迈向统一的点云三维物体检测框架

分享嘉宾 | 杨泽同 文稿整理 | William 嘉宾介绍 Introduction 3D检测是在三维世界中去定位和分类不同的物体&#xff0c;与传统2D检测的区别在于它有一个深度信息。目前&#xff0c;大部分的工作是倾向于用点云去做三维检测&#xff0c;点云实际上是通过传感器去扫描出来的一…

Redis 的安装 + SpringBoot 集成 Redis

1.安装 Redis此处的 Redis 安装是针对 Linux 版本的安装, 因为 Redis 官方没有提供 Windows 版本, 只提供了 Linux 版本. 但是我们可以通过Windows 去远程连接 Redis.1.1 使用 yum 安装 Redis使用如下命令, 将 Redis 安装到 Linux 服务器:yum -y install redis1.2 启动 Redis使…

科技云报道:开源真的香,风险知多少?

科技云报道原创。 过去几年&#xff0c;开源界一片火热&#xff0c;开源软件技术已全面进军操作系统、云原生、人工智能、大数据、半导体、物联网等行业领域。 数据显示&#xff0c;我国超九成企业在使用或正计划使用开源技术。 与此同时&#xff0c;全球各大开源组织相继兴…

苹果手机专用蓝牙耳机有哪些?与iphone兼容性好的蓝牙耳机

蓝牙耳机摆脱了线缆的束缚&#xff0c;在地以各种方式轻松通话。自从蓝牙耳机问世以来&#xff0c;一直是行动商务族提升效率的好工具&#xff0c;苹果产品一直都是受欢迎的数码产品&#xff0c;下面推荐几款与iphone兼容性好的蓝牙耳机。 第一款&#xff1a;南卡小音舱蓝牙耳…

Springboot部署阿里云短信服务

一、阿里云短信模板配置&#xff08;获取SignName[签名名称]、TemplateCode[模板CODE]&#xff09; 1. 进入阿里云首页直接搜索短信服务&#xff0c;并点击国内消息进入国内文本短信管理页面 2. 选择签名管理点击添加签名 填写签名信息并提交 注&#xff1a;下面的签名来源如…

全栈自动化测试技术笔记(一):前期调研怎么做

昨天下午在家整理书架&#xff0c;把很多看完的书清理打包好&#xff0c;预约了公益捐赠机构上门回收。 整理的过程中无意翻出了几年前的工作记事本&#xff0c;里面记录了很多我刚开始做自动化和性能测试时的笔记。 虽然站在现在的角度来看&#xff0c;那个时候无论是技术细…

【Java 面试合集】描述下Objec类中常用的方法(未完待续中...)

描述下Objec类中常用的方法 1. 概述 首先我们要知道Object 类是所有的对象的基类&#xff0c;也就是所有的方法都是可以被重写的。 那么到底哪些方法是我们常用的方法呢&#xff1f;&#xff1f;&#xff1f; cloneequalsfinalizegetClasshashCodenotifynotifyAlltoStringw…

你知道 GO 中什么情况会变量逃逸吗?

你知道 GO 中什么情况会变量逃逸吗&#xff1f;首先我们先来看看什么是变量逃逸 Go 语言将这个以前我们写 C/C 时候需要做的内存规划和分配&#xff0c;全部整合到了 GO 的编译器中&#xff0c;GO 中将这个称为 变量逃逸 GO 通过编译器分析代码的特征和代码的生命周期&#x…

在RT-Thread STM32F407平台下配置SPI flash为U盘

记录下SPI Flash U盘实现过程中踩过的坑&#xff0c;与您分享。前提条件是&#xff0c;需要先将SPI Flash 配置到elm fal文件系统&#xff0c;并挂载成功。如下图然后开始配置USB1&#xff0c;在CubeMX&#xff0c;选择SUB_OTG_FS2 选择USB Device3&#xff0c;确认USB时钟为48…

流程控制之循环

文章目录五、流程控制之循环5.1 步进循环语句for5.1.1 带列表的for循环语句5.1.2 不带列表的for循环语句5.1.3 类C风格的for循环语句5.2 while循环语句5.2.1 while循环读取文件5.2.2 while循环语句示例5.3 until循环语句5.4 select循环语句5.5 嵌套循环5.4 利用break和continue…

【八大数据排序法】堆积树排序法的图形理解和案例实现 | C++

第二十一章 堆积树排序法 目录 第二十一章 堆积树排序法 ●前言 ●认识排序 1.简要介绍 2.图形理解 3.算法分析 ●二、案例实现 1.案例一 ● 总结 前言 排序算法是我们在程序设计中经常见到和使用的一种算法&#xff0c;它主要是将一堆不规则的数据按照递增…

BinaryAI全新代码匹配模型BAI-2.0上线,“大模型”时代的安全实践

导语BinaryAI&#xff08;https://www.binaryai.net&#xff09;科恩实验室在2021年8月首次发布二进制安全智能分析平台—BinaryAI&#xff0c;BinaryAI可精准高效识别二进制文件的第三方组件及其版本号&#xff0c;旨在推动SCA&#xff08;Software Composition Analysis&…