深度学习模型不确定性方法对比

news2024/11/18 0:20:05

©PaperWeekly 原创 · 作者|崔克楠

学校|上海交通大学博士生

研究方向|异构信息网络、推荐系统

本文以 NeurIPS 2019 的 Can You Trust Your Model’s Uncertainty? Evaluating Predictive Uncertainty Under Dataset Shift 论文为主线,回顾近年顶级机器学习会议对于 dataset shift 和 out-of-distribution dataset 问题相关的论文,包括了 Temperature scaling [1],DeepEnsemble [2],Monte-Carlo Dropout [3] 等方法。而 [4] 在统一的数据集上对上述一系列方法,测试了他们在 data shift 和 out-of-distribution 问题上的 accuracy 和 calibration。

Temperature Scaling [1]

在介绍 temperature scaling 之前,首先需要了解什么叫做 calibrated?

神经网络在分类时会输出“置信度”分数和预测结果。理想情况下,这些分数应该与真实正确性的可能性相匹配。例如,如果我们将 80% 的置信度分配给 100 个样本,那么我们就会期望 80% 样本的预测实际上是正确的。如果是这样,我们说模型是经过校准的。

而 Temperature scaling 则是一个非常简单的后处理步骤,能够帮助模型进行校准。一种可视化校准的简单方法是将精度作为置信度的函数绘制(reliability diagram)。下边左边的可靠性图表中,我们可以看到一个在 CIFAR-100 上训练的 DenseNet 是极度自信的。然而,使用 Temperature scaling,模型就得到了校准。

具体怎么做 temperature scaling 呢,对于分类问题,网络最后一层往往会输出 logits,而 logits 进一步传给 softmax 函数来得到各个类别的概率,而 temperature scaling 对这一步骤修改为:

实现层面也很简单,在 PyTorch 的实现如下:

class Model(torch.nn.Module):    def __init__(self):        # ...        self.temperature = torch.nn.Parameter(torch.ones(1))    def forward(self, x):        # ...        # logits = final output of neural network        return logits / self.temperature

但要注意的是,上述方法需要在 validation set 上进行优化,来学习参数 temperature,而不能在 training set 上进行学习,所以 Temperature scaling 是一个 post process,即后处理步骤,这种方法也暂时只能用于分类任务,不能用于回归。

Deep Ensemble [2]

以往 ensemble 的方法大致分为 randomization-based 的方法,和 boosting based 的方法。前者方法中,ensemble 中的 members 可以并行训练,没有 interaction;后者方法中的 members 之间在训练时是有相互依赖的先后顺序。

而 deep ensemble [2] 属于前者方法。相比于以往的方法使用部分数据去训练 member,deep ensemble 使用整个训练集去训练 M 个独立随机初始化的网络模型。其训练过程如下图算法所示:

M 个独立的模型训练完后,对于模型预测使用如下的 uniformly-weighted 的方法进行融合。后文为方便,统称 deep ensemble 为 ensemble 方法。

MC-Dropout [2]

MC-Dropout 是为模型引入 uncertainty 特性的最为简单有效的方法之一。以往我们经常在训练时对模型参数使用 dropout,以防止模型过拟合,在 inference 阶段,往往会关闭 dropout。而 MC-dropout 则强调,在 inference 阶段,也要对模型参数进行 dropout。对于一个样本的 inference,MC-Dropout 要求随机进行 K 次 dropout,进行 K 次前传,得到 K 个输出结果。而 K 个输出结果再进行 ensemble。

这么做的目的是因为在贝叶斯网络中,网络模型的参数应当服从特定的分布。模型在预测结果时,应当对模型的参数分布进行积分,而对于如今庞大的模型来说这显然是不可能的。MC-dropout 相当从模型参数的变分分布当中随机采样,将这一“积分”过程变得简单,容易实现。

实验设置

实验主要探讨了上述方法在不同 data shift 和 out of distribution 下,在 accuracy,calibration 等 metric 上的表现。其 data shift 如下图所示,对 ImageNet 和 MNIST 的图片施加不同的 image level 的 corruption。

而 Out-of-distribution 指的是,和训练数据分布不一致的数据集,对于 MNIST 数据集来说,NotMNIST 数据集为 out-of-distribution,而对于 CIFAR 数据集来说,SVHN 数据集为 out-of-distribution。所有的方法均采用相同的网络结构,实验设置汇总到下表所示。

实验对比分析

在 MNIST 数据集上的对比如上图所示,其中 Brier score 越小越好,而 confidence 指的是分类器最大概率类别的置信度分数。经 Stochastic Variational Bayesian Inference (SVI) 在各个 metric 上的表现好。同时也能够发现:

1. 从 a 和 b 上能看出,在有了 data shift 之后,各个模型的 accuracy 都逐渐下降;

2. 从 a,b 中的 Brier score 可以看出,使用 Temperature scaling 在 validation 矫正,能够在 test 上保证 calibration,但在 shift data 上无法保证calibration;

  1. 从 c 中可以看出,SVI 在比较高的 confidence 下的 accuracy 最高,说明 SVI 方法比较适合于风险价值较高的应用;

4.从 e 和 f 中可以看出,这些方法在 OOD 数据上都显示了比较低的 entropy,并且在 OOD 的数据上给出了比较高的 confidence,说明他们对于 OOD 数据预测较为错误。

在 ImageNet 数据集上的对比如上图所示,其中 ECE 为 Calibration 指标,越小则代表模型校准的越好。我们可以发现:

1. 所有方法随着 shift 程度的增加(比如图片的模糊程度等),Accuracy 越来越低,ECE 越来越高,代表模型的精确度不断下降,同时校准越来越差;

2. 所有模型在不同 shift 上的 Accuracy 表现差别不大,但是 ensemble 优于所有的模型;

3. 同样,ensemble 在不同的 shift 下,模型仍然保持较好的 calibration 能力;

4. 在 CIFAR 的 OOD 实验上,从 c 图中可以看出,tempreture scaling 的 entropy 最高,ensemble 次之。同样而在 ensemble 方法在 OOD 样本的 confidence 比较低,说明 ensemble 能够保持比较好的 uncertainty 特性。

同时作者还发现,在 CIFAR-10 以及文本数据 20Newsgroup 上,ensemble 的表现仍然要优于其他方法,和在 ImageNet 上的表现一致(除了 MNIST 数据集)。而我们也会考虑是否因为 ensemble 方法集成了几个模型,capacity 较大,所以表现较为优异,因此做了如下探究实验。

如上图所示,作者考虑增加出了 ensemble 外,其他方法所使用的网络的 capacity,得到一系列其他方法在 wide architecture 上的表现,可以看到,增加模型的 capacity 并不能带来在 Accuracy 和 ECE 上的提升。

在上图中,作者展示了 ensemble size 对于模型 calibration 的影响,可以看到随着 ensemble size 的提升,Brier Score 是逐渐缩小的,这说明 ensemble size 越大,模型的 calibration 能力是越好的,但超过 50 之后不会再有提升。但考虑到计算负担,一般设置为 5 比较恰当。

作者也探讨了 sample size 对于采样类方法的影响,可以看到 MC dropout 和 SVI 的 Brier score 随着 sample size 的提升而下降,说明较大的 sample size 对于模型的 calibration 是有帮助的,但也要考虑到计算负担的影响。

最终作者给出了各个方法的计算和储存方面的效率,可以看到 Ensemble 虽然通常来说表现较好,但是开销往往也是最大的。

总结

1. 模型的 Accuracy 和 Calibration 会随着 data shift 逐渐下降;

2. Temperature scaling 虽然能够在 test set 上保持 calibration,但是在 shift dataset 上却无法达到同样的效果;

3. SVI 在 MNIST 上表现最好,但是在其他所有数据集上,ensemble 表现最为优异。并且他们表现得相对顺序也是一致的;

4. Ensemble 虽然表现较好,但是在计算负担方面不占优势,仍要考虑是否有其他鲁棒的方法。

参考文献

[1] Guo, C., Pleiss, G., Sun, Y. and Weinberger, K.Q. On Calibration of Modern Neural Networks. In International Conference on Machine Learning, 2017.

[2] Lakshminarayanan, Balaji, Alexander Pritzel, and Charles Blundell. “Simple and scalable predictive uncertainty estimation using deep ensembles.” Advances in neural information processing systems. 2017.

[3] Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In ICML, 2016

[4] Snoek, Jasper, et al. “Can you trust your model’s uncertainty? Evaluating predictive uncertainty under dataset shift.” Advances in Neural Information Processing Systems. 2019.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1125087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习笔记 - 特斯拉的占用网络简述

一、简述 ​ 2022 年,特斯拉宣布即将在其车辆中发布全新算法。该算法被称为occupancy networks,它应该是对Tesla 的HydraNet 的改进。 自动驾驶汽车行业在技术上分为两类:基于视觉的系统和基于激光雷达的系统。后者使用激光传感器来确定物体的存在和距离,而视觉系统…

acwing第 126 场周赛 (扩展字符串)

5281. 扩展字符串 一、题目要求 某字符串序列 s0,s1,s2,… 的生成规律如下: s0 DKER EPH VOS GOLNJ ER RKH HNG OI RKH UOPMGB CPH VOS FSQVB DLMM VOS QETH SQBsnDKER EPH VOS GOLNJ UKLMH QHNGLNJ Asn−1AB CPH VOS FSQVB DLMM VOS QHNG Asn−1AB,其…

day10_面向对象_抽象_接口

今日内容 1.作业 2.final 3.抽象 4.接口 零、复习 按从大到小的顺序写出访问修饰符 public > protected > package (default)> private static修饰属性和方法的特点在内存的特点: 在方法区(不是在堆,也不是在栈)初始化的特点: 随类(字节码文件)加载到内存已经初始化使…

基于大数据的时间序列股价预测分析与可视化 - lstm 计算机竞赛

文章目录 1 前言2 时间序列的由来2.1 四种模型的名称: 3 数据预览4 理论公式4.1 协方差4.2 相关系数4.3 scikit-learn计算相关性 5 金融数据的时序分析5.1 数据概况5.2 序列变化情况计算 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 &…

Redis不止能存储字符串,还有List、Set、Hash、Zset,用对了能给你带来哪些优势?

文章目录 🌟 Redis五大数据类型的应用场景🍊 一、String🍊 二、Hash🍊 三、List🍊 四、Set🍊 五、Zset 📕我是廖志伟,一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO…

1300*B. Road Construction(构造菊花图)

Problem - 330B - Codeforces 解析&#xff1a; 1到任一点距离不超过二&#xff0c;并且有部分点不可以连边&#xff0c;直接统计所有不能连边的点&#xff0c;从之外的点中选一个点当作中心&#xff0c;构造菊花图即可。 #include<bits/stdc.h> using namespace std; i…

CSS常见选择器总结

1.简单选择器 简单选择器是开发中使用最多的选择器&#xff0c;包含&#xff1a; 元素选择器&#xff0c;使用元素的名称 类选择器&#xff0c;使用.类名 id选择器&#xff0c;使用#id id注意事项&#xff1a; 一个HTML文档里面的id值 是唯一的&#xff0c;不能重复 id值如…

阿里云服务器x86计算架构ECS实例规格汇总

阿里云企业级服务器基于X86架构的实例规格&#xff0c;每一个vCPU都对应一个处理器核心的超线程&#xff0c;基于ARM架构的实例规格&#xff0c;每一个vCPU都对应一个处理器的物理核心&#xff0c;具有性能稳定且资源独享的特点。阿里云服务器网aliyunfuwuqi.com分享阿里云企业…

特约|数码转型思考:Web3.0与银行

日前&#xff0c;欧科云链研究院发布重磅报告&#xff0c;引发银行界及金融监管机构广泛关注。通过拆解全球70余家银行的加密布局&#xff0c;报告认为&#xff0c;随着全球采用率的提升与相关技术的成熟&#xff0c;加密资产已成为银行业不容忽视也不能错过的创新领域。 作为…

尚硅谷kafka3.0.0

目录 &#x1f483;概述 ⛹定义 ​编辑⛹消息队列 &#x1f938;‍♂️消息队列应用场景 ​编辑&#x1f938;‍♂️两种模式&#xff1a;点对点、发布订阅 ​编辑⛹基本概念 &#x1f483;Kafka安装 ⛹ zookeeper安装 ⛹集群规划 ​编辑⛹流程 ⛹原神启动 &#x1f938;‍♂️…

gRPC之gateway集成swagger

1、gateway集成swagger 1、为了简化实战过程&#xff0c;gRPC-Gateway暴露的服务并未使用https&#xff0c;而是http&#xff0c;但是swagger-ui提供的调用服 务却是https的&#xff0c;因此要在proto文件中指定swagger以http调用服务&#xff0c;指定的时候会用到文件 prot…

WebService SOAP1.1 SOAP1.12 HTTP PSOT方式调用

Visual Studio 2022 新建WebService项目 创建之后启动运行 设置默认文档即可 经过上面的创建WebService已经创建完成&#xff0c;添加HelloWorld3方法&#xff0c; [WebMethod] public string HelloWorld3(int a, string b) { //var s a b; return $"Hello World ab{a …

Markdown语法详解

文章目录 [toc] 一、简介二、样式1. 标题2. 字体3. 引用4. 分割线5. 图片6. 超链接7. 列表8. 表格9. 代码 一、简介 以前写学习文档常用的软件都是Word或者CSDN自带的编辑器&#xff0c;但Word用起来不太灵活&#xff0c;而CSDN自带编辑器又感觉逼格不够&#xff08;主要原因&…

(自我剖析一下我博客“问答”中的第三个问题)准确率一直居低不上是什么原因引起的?

我提的问题是&#xff1a; “我使用单层GRU训练minist数据集时&#xff0c;准确率一直处于下图的状态是为什么&#xff1f; 什么原因引起的&#xff1f;” 这种debug就比较难受&#xff0c;因为程序是能跑的&#xff0c;任何“error”都没有出。这就表明在程序中有某些小细节没…

【SwiftUI模块】0060、SwiftUI基于Firebase搭建一个类似InstagramApp 3/7部分-搭建TabBar

SwiftUI模块系列 - 已更新60篇 SwiftUI项目 - 已更新5个项目 往期Demo源码下载 技术:SwiftUI、SwiftUI4.0、Instagram、Firebase 运行环境: SwiftUI4.0 Xcode14 MacOS12.6 iPhone Simulator iPhone 14 Pro Max SwiftUI基于Firebase搭建一个类似InstagramApp 3/7部分-搭建Tab…

数据集的特征提取

1、 特征提取 1.1、 将任意数据&#xff08;如文本或图像&#xff09;转换为可用于机器学习的数字特征 注&#xff1a;特征值化是为了计算机更好的去理解数据 字典特征提取(特征离散化)文本特征提取图像特征提取&#xff08;深度学习将介绍&#xff09; 2 特征提取API sklear…

Python OpenCV通过灰度平均值进行二值化处理以减少像素误差

Python OpenCV通过灰度平均值进行二值化处理以减少像素误差 前言前提条件相关介绍实验环境通过灰度平均值进行二值化处理以减少像素误差固定阈值二值化代码实现 灰度平均值二值化代码实现 前言 由于本人水平有限&#xff0c;难免出现错漏&#xff0c;敬请批评改正。更多精彩内容…

数据安全与PostgreSQL:最佳保护策略

在当今数字化时代&#xff0c;数据安全成为了企业不可或缺的一环。特别是对于使用数据库管理系统&#xff08;DBMS&#xff09;的组织来说&#xff0c;确保数据的完整性、保密性和可用性至关重要。在众多DBMS中&#xff0c;PostgreSQL作为一个强大而灵活的开源数据库系统&#…

酒类商城小程序怎么做

随着互联网的快速发展&#xff0c;线上购物越来越普及。酒类商品也慢慢转向线上销售&#xff0c;如何搭建一个属于自己的酒类小程序商城呢&#xff1f;下面就让我们一起来看看吧&#xff01; 一、登录乔拓云平台 首先&#xff0c;我们需要进入乔拓云平台的后台&#xff0c;点击…

Pytorch公共数据集、tensorboard、DataLoader使用

本文将主要介绍torchvision.datasets的使用&#xff0c;并以CIFAR-10为例进行介绍&#xff0c;对可视化工具tensorboard进行介绍&#xff0c;包括安装&#xff0c;使用&#xff0c;可视化过程等&#xff0c;最后介绍DataLoader的使用。希望对你有帮助 Pytorch公共数据集 torc…