深度学习中的数值稳定性处理详解:以SimCLR损失为例

news2025/4/16 18:04:00

文章目录

    • 1. 问题背景
      • SimCLR的原始公式
    • 2. 数值溢出问题
      • 为什么会出现数值溢出?
      • 浮点数的表示范围
    • 3. 数值稳定性处理方法
      • 核心思想
      • 数学推导
    • 4. 代码实现分解
      • 代码与公式的对应关系
    • 5. 具体数值示例
      • 示例:相似度矩阵
      • 方法1:直接计算exp(x)
      • 方法2:减去最大值后计算
      • 验证结果等价性
    • 6. 为什么减去最大值有效?
      • 关键原理
    • 7. 实际应用场景
    • 8. 实现建议
    • 总结

在深度学习实现中,特别是涉及指数和对数运算的损失函数计算过程中,数值稳定性是一个核心问题。本文以SimCLR对比学习损失为例,详细解析数值稳定性处理的原理、实现和重要性。

1. 问题背景

SimCLR是一种自监督学习方法,其核心是InfoNCE损失函数。这个损失函数的计算涉及大量指数运算,容易导致数值溢出或下溢问题。

SimCLR的原始公式

SimCLR的核心损失函数(InfoNCE损失)公式为:

L i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ ) ⋅ 1 k ≠ i L_i = -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} Li=logk=12Nexp(sim(zi,zk)/τ)1k=iexp(sim(zi,zj)/τ)

其中:

  • z i z_i zi是锚点特征
  • z j z_j zj是与 z i z_i zi对应的正样本特征
  • τ \tau τ是温度参数
  • s i m ( ) sim() sim()是相似度函数(通常是点积)
  • 1 k ≠ i \mathbf{1}_{k \neq i} 1k=i表示排除自身对比的指示函数

2. 数值溢出问题

为什么会出现数值溢出?

当我们计算 exp ⁡ ( x ) \exp(x) exp(x)时:

  • 如果 x x x很大(如 x = 100 x = 100 x=100), exp ⁡ ( 100 ) ≈ 2.7 × 1 0 43 \exp(100) \approx 2.7 \times 10^{43} exp(100)2.7×1043,可能超出浮点数表示范围
  • 如果 x x x是很小的负数(如 x = − 100 x = -100 x=100), exp ⁡ ( − 100 ) ≈ 3.7 × 1 0 − 44 \exp(-100) \approx 3.7 \times 10^{-44} exp(100)3.7×1044,可能导致下溢为0

在SimCLR中, s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi,zk)/τ可能很大,特别是当:

  • 特征向量高度相似( s i m sim sim接近1)
  • 温度参数 τ \tau τ很小(如0.07)

浮点数的表示范围

浮点数的表示范围是有限的:

  • 单精度浮点数(32位):约 ± 3.4 × 1 0 38 \pm 3.4 \times 10^{38} ±3.4×1038
  • 双精度浮点数(64位):约 ± 1.8 × 1 0 308 \pm 1.8 \times 10^{308} ±1.8×10308

3. 数值稳定性处理方法

SimCLR实现中使用了一种简单而有效的数值稳定性处理技术,代码如下:

# 数值稳定性处理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

核心思想

这种处理的核心思想是:

  1. 找出每行相似度的最大值
  2. 将每行的所有值减去这个最大值
  3. 然后再进行指数计算

数学推导

这种操作是数学等价的。对原始公式进行变换:

L i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ ) ⋅ 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} \\ \end{align} Li=logk=12Nexp(sim(zi,zk)/τ)1k=iexp(sim(zi,zj)/τ)

引入最大值 M i = max ⁡ k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi=maxk(sim(zi,zk)/τ)

L i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ − M i + M i ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ − M i + M i ) ⋅ 1 k ≠ i = − log ⁡ exp ⁡ ( M i ) ⋅ exp ⁡ ( s i m ( z i , z j ) / τ − M i ) exp ⁡ ( M i ) ⋅ ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i = − log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ − M i ) ∑ k = 1 2 N exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i + M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i + M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(M_i) \cdot \exp(sim(z_i, z_j)/\tau - M_i)}{\exp(M_i) \cdot \sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \end{align} Li=logk=12Nexp(sim(zi,zk)/τMi+Mi)1k=iexp(sim(zi,zj)/τMi+Mi)=logexp(Mi)k=12Nexp(sim(zi,zk)/τMi)1k=iexp(Mi)exp(sim(zi,zj)/τMi)=logk=12Nexp(sim(zi,zk)/τMi)1k=iexp(sim(zi,zj)/τMi)

因为分子和分母中的 exp ⁡ ( M i ) \exp(M_i) exp(Mi)相互抵消,所以最终结果不变。

4. 代码实现分解

完整的SimCLR损失计算代码(包含数值稳定性处理):

# 计算相似度矩阵并除以温度系数
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    self.temperature)

# 数值稳定性处理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

# 创建和应用掩码
mask = mask.repeat(anchor_count, contrast_count)
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask

# 计算损失
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()

代码与公式的对应关系

  1. anchor_dot_contrast s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi,zk)/τ
  2. logits_max M i = max ⁡ k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi=maxk(sim(zi,zk)/τ)
  3. logits s i m ( z i , z k ) / τ − M i sim(z_i, z_k)/\tau - M_i sim(zi,zk)/τMi
  4. exp_logits exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i} exp(sim(zi,zk)/τMi)1k=i
  5. log_prob log ⁡ exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ∑ k exp ⁡ ( s i m ( z i , z k ) / τ − M i ) ⋅ 1 k ≠ i \log \frac{\exp(sim(z_i, z_k)/\tau - M_i)}{\sum_{k} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} logkexp(sim(zi,zk)/τMi)1k=iexp(sim(zi,zk)/τMi)

5. 具体数值示例

为了直观理解,我们用一个简化的例子来说明为什么减去最大值能防止数值溢出。

示例:相似度矩阵

假设有一个计算得到的相似度矩阵(已除以温度τ=0.07):

sim(z_i, z_k)/τ = [
    [80, 50, 60, 70, 40],
    [60, 90, 70, 80, 50],
    [70, 60, 85, 75, 55],
    [50, 40, 60, 75, 45]
]

方法1:直接计算exp(x)

直接计算exp(sim(z_i, z_k)/τ)

exp(sim(z_i, z_k)/τ) ≈ [
    [5.54e+34, 5.18e+21, 1.14e+26, 2.51e+30, 2.35e+17],
    [1.14e+26, 1.22e+39, 2.51e+30, 5.54e+34, 5.18e+21],
    [2.51e+30, 1.14e+26, 5.91e+36, 3.58e+32, 1.14e+24],
    [5.18e+21, 2.35e+17, 1.14e+26, 3.58e+32, 3.49e+19]
]

这些值极其巨大,相加时很容易溢出。例如第一行的和约为5.54e+34,已经接近单精度浮点数的上限。

方法2:减去最大值后计算

找出每行的最大值:

max_values = [80, 90, 85, 75]

减去最大值:

adjusted_logits = [
    [0, -30, -20, -10, -40],
    [-30, 0, -20, -10, -40],
    [-15, -25, 0, -10, -30],
    [-25, -35, -15, 0, -30]
]

计算exp(adjusted_logits)

exp(adjusted_logits) ≈ [
    [1.0, 9.36e-14, 2.06e-9, 4.54e-5, 4.25e-18],
    [9.36e-14, 1.0, 2.06e-9, 4.54e-5, 4.25e-18],
    [3.06e-7, 1.39e-11, 1.0, 4.54e-5, 9.36e-14],
    [1.39e-11, 6.31e-16, 3.06e-7, 1.0, 9.36e-14]
]

这些值都在[0,1]范围内,完全避免了溢出问题。同时,正样本对和负样本对之间的相对比例关系保持不变。

验证结果等价性

例如,对于第一行计算最终的归一化概率:

原始方法:

P(z_0 -> z_0) = exp(80) / sum(exp(row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(50) / sum(exp(row_0)) ≈ 9.35e-14
...

减去最大值后:

P(z_0 -> z_0) = exp(0) / sum(exp(adjusted_row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(-30) / sum(exp(adjusted_row_0)) ≈ 9.35e-14
...

两种计算方法得到的概率分布是相同的,但后者避免了数值溢出风险。

6. 为什么减去最大值有效?

关键原理

减去最大值的处理之所以有效,是因为:

  1. 将范围控制在安全区间

    • 减去最大值后,所有值都≤0
    • 因此所有exp(x)的结果都≤1,避免了上溢
    • 同时最大值对应的exp(0)=1,避免了整体下溢为0
  2. 保持相对比例关系

    • 对每行减去相同的常数不改变值之间的相对大小
    • 对于exp()函数来说,这等价于同时除以一个常数因子
    • 在计算Softmax或对数概率时,这个常数因子在分子和分母中抵消
  3. 数学等价性

    • exp(a-b) = exp(a)/exp(b)的性质保证了结果的正确性
    • 这相当于将原始公式的分子和分母同时除以exp(max_value)

7. 实际应用场景

这种数值稳定性技术不仅适用于SimCLR,还广泛应用于:

  1. Softmax计算:几乎所有需要计算Softmax的地方都需要
  2. 交叉熵损失:分类任务中常用
  3. 注意力机制:Transformer中的attention计算
  4. 所有对比学习方法:MoCo、BYOL、CLIP等

8. 实现建议

在实现涉及指数计算的函数时,建议:

  1. 始终使用数值稳定性处理
  2. 对每个batch/样本独立进行处理(找到每行/每个样本的最大值)
  3. 使用.detach()阻止梯度通过最大值操作传播
  4. 注意掩码操作,确保不包括自身对比或特定的负样本

总结

数值稳定性处理是深度学习实现中一个看似简单但至关重要的技术。通过简单地减去每行的最大值,我们可以有效防止数值溢出/下溢问题,同时保持计算结果的数学等价性。这种技术尤其重要,因为随着模型和批量大小的增加,数值问题更容易出现,而且往往难以诊断。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2335198.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

散户使用算法交易怎么做?

智能算法交易是量化交易里面最常见的一种,也是大多数散户被套住的股票,想要解套,降低成本最直接有效的方式。但是往往这种波动速度小,担心速度跟不上的情况,我们就要叠加快速通道。 第一:算法交易的应用场…

mongodb 安装配置

1.官网下载地址:MongoDB Community Download | MongoDB 2.解压包安装:https://pan.baidu.com/s/1Er56twK9UfxoExuCPlJjhg 提取码: 26aj 3.配置环境: (1)mongodb安装包位置: (2)复…

榕壹云酒水定制系统:基于THinKPHP+MySQL+UniApp打造数字化时代的个性化购酒新体验

数字化浪潮下的酒水定制新机遇 在消费升级与个性化需求崛起的背景下,传统酒水行业正面临数字化转型的迫切需求。为此,我们团队基于ThinkPHP+MySQL+UniApp技术栈,开发了一套榕壹云酒水定制系统,旨在通过数字化手段解决消费者个性化购酒痛点,为酒类品牌提供全链路数字化解决…

Leetcode——137 260找出只出现一次的数

文章目录 找出只出现一次的数引入Leetcode 260Leetcode 137 找出只出现一次的数 对于数组中有一类题,即某些数据在数组中只出现一遍,需要我们找出,今天我们来看看这个类型的题。 引入 想必大家应该见过这么一道题: 现给定一个数…

OpenTiny使用指南

最近项目里用到了一个新的组件库——OpenTiny,但是官方文档的使用指南的描述很复杂,花了一些时间尝试才正常使用。下面是一个使用步骤的描述,可放心食用: 一、安装 TinyVue 组件库同时支持 Vue 2.0 和 Vue 3.0 框架,…

KingbaseES之KDts迁移SQLServer

项目适配迁移SQLServer至金仓,今天写写KDts-WEB版迁移工具迁移SQLServer至KingbaseES的步骤,以及迁移注意事项. SQLServer版本:SQLServer2012 KingbaseES版本:V009R004C011(SQLServer兼容版) --1.进入数据库客户端工具KDTS工具目录,启动KDts服务: [king…

代码随想录动态规划part02

动态规划part02 62.不同路径 代码随想录 视频讲解:动态规划中如何初始化很重要!| LeetCode:62.不同路径_哔哩哔哩_bilibili 递归法 动态规划,当前状态是由上一个状态转化来的 这里初始化错误了,想法是对的右一和…

详解如何复现DeepSeek R1:从零开始利用Python构建

DeepSeek R1 的整个训练过程,说白了就是在其基础模型(也就是 deepseek V3)之上,用各种不同的强化学习方法来“雕琢”它。 咱们从一个小小的本地运行的基础模型开始,一边跟着 DeepSeek R1 技术报告 的步骤,…

Java集合框架 源码分析 迭代器 并发修改异常底层原理

迭代器 Java中的Iterator(迭代器)是集合框架中用于遍历容器元素的统一接口,提供了一种标准化的元素访问方式,无需依赖具体集合类型的实现细节。以下是其核心要点: 一、核心方法与使用步骤 获取迭代器 通过集合的 it…

Cannot find module ‘vue‘ or its corresponding type declarations

在使用vue3vite创建新的工程时&#xff0c;在新增.vue文件时会出现Cannot find module vue这个错误。 只需要我们在项目中的.d.ts文件中添加以下代码即可 declare module *.vue {import { defineComponent } from vue;const component: ReturnType<typeof defineComponent&…

【Python爬虫】详细工作流程以及组成部分

目录 一、Python爬虫的详细工作流程 确定起始网页 发送 HTTP 请求 解析 HTML 处理数据 跟踪链接 递归抓取 存储数据 二、Python爬虫的组成部分 请求模块 解析模块 数据处理模块 存储模块 调度模块 反爬虫处理模块 一、Python爬虫的详细工作流程 在进行网络爬虫工…

欧拉服务器操作系统部署deekseep(Ollama+DeekSeep+open WebUI)

​​一、解压并安装 Ollama​​ # 1. 解压文件&#xff08;默认会得到一个二进制文件&#xff09; tar -xzvf ollama-linux-amd64.tgz# 2. 将二进制文件安装到系统路径 sudo mv ollama /usr/local/bin/ sudo chmod x /usr/local/bin/ollama# 3. 验证安装 ollama --version链接…

#4 我们为什么使用物联网? 以及 物联网的整体结构

设备不物联是否可以&#xff1f; 答案 是可以的&#xff0c;从项目实战的角度&#xff0c;还是有很多包括分拣&#xff0c;控制&#xff0c;检测等应用是分立的&#xff0c;这个和成本&#xff0c;场景&#xff0c;客户接受度等因素有关。 局部看&#xff0c;一些系统的确很简…

3D版的VLA——从3D VLA、SpatialVLA到PointVLA(不动VLM,仅动作专家中加入3D数据)

前言 之前写这篇文章的时候&#xff0c;就想解读下3D VLA来着&#xff0c;但一直因为和团队并行开发具身项目&#xff0c;很多解读被各种延后 更是各种出差&#xff0c;比如从25年3月下旬至今&#xff0c;连续出差三轮&#xff0c;绕中国半圈&#xff0c;具身占八成 第一轮 …

linux Shell编程之循环语句(三)

目录 一. for 循环语句 1. for语句的结构 2. for 语句应用示例 (1) 根据姓名列表批量添加用户 (2) 根据 IP 地址列表检查主机状态 二. 使用 while 循环语句 1. while 语句的结构 2. while 语句应用示例 (1) 批量添加规律编号的用户 (2) 猜价格游戏 三. until 循环语…

C#容器源码分析 --- Queue<T>

Queue<T> 是 System.Collections.Generic 命名空间下的先进先出&#xff08;FIFO&#xff09;动态集合&#xff0c;其核心实现基于​​循环数组​​&#xff0c;通过维护头尾指针实现高效入队和出队操作。 .Net4.8 Queue<T>源码地址&#xff1a;queue.cs (microso…

ViT 模型讲解

文章目录 一、模型的诞生背景1.1 背景1.2 ViT 的提出&#xff08;2020年&#xff09; 二、模型架构2.1 patch2.2 模型结构2.2.1 数据 shape 变化2.2.2 代码示例2.2.3 模型结构图 2.3 关于空间信息 三、实验3.1 主要实验3.2 消融实验 四、先验问题4.1 归纳偏置4.2 先验or大数据&…

IntelliJ IDEA 中安装和使用通义灵码 AI 编程助手教程

随着人工智能技术的发展&#xff0c;AI 编程助手逐渐成为提升开发效率的强大工具。通义灵码是阿里云推出的一款 AI 编程助手&#xff0c;它能够帮助开发者实现智能代码补全、代码解释、生成单元测试等功能&#xff0c;极大地提升了编程效率和代码质量。 IntelliJ IDEA 是一款广…

FreeRTOS入门与工程实践-基于STM32F103(一)(单片机程序设计模式,FreeRTOS源码概述,内存管理,任务管理,同步互斥与通信,队列,信号量)

裸机程序设计模式 裸机程序的设计模式可以分为&#xff1a;轮询、前后台、定时器驱动、基于状态机。前面三种方法都无法解决一个问题&#xff1a;假设有A、B两个都很耗时的函数&#xff0c;无法降低它们相互之间的影响。第4种方法可以解决这个问题&#xff0c;但是实践起来有难…

can‘t set boot order in virtualbox

Boot order setting is ignored if UEFI is enabled https://forums.virtualbox.org/viewtopic.php?t99121 如果勾选EFI boot order就是灰色的 传统BIOS就是可选的 然后选中任意介质&#xff0c;通过右边的上下箭头调节顺序&#xff0c;最上面的应该是优先级最高的 然后就…