一文彻底搞懂 Softmax 函数,数学原理分析和 PyTorch 验证

news2025/2/9 12:18:27

文章目录

1. Softmax 的定义

softmax函数又称归一化指数函数,是基于 sigmoid 二分类函数在多分类任务上的推广;在多分类网络中,常用 Softmax 作为最后一层进行分类。

Softmax 的计算公式如下:

S o f t m a x ( x i ) = e x i ∑ i = 1 n e x i ∈ ( 0 , 1 ) (1) Softmax(x_i)=\frac{e^{x_i}}{\displaystyle \sum^{n}_{i=1}{e^{x_i}}} \in (0,1) \tag{1} Softmax(xi)=i=1nexiexi(0,1)(1)

2. Softmax 使用 e 的幂次的作用

对比普通的 max() 方法,Softmax 的独特之处就是使用的 e 的幂函数,其目的是为了两极化:

Softmax 可以使正样本(正数)的结果趋近于 1,使负样本(负数)的结果趋近于 0;且样本的绝对值越大,两极化越明显。

2.1 代码验证

(1)先用 numpy 来验证一下:

import numpy as np

# 计算向量 x 的 softmax
def softmax(x: list) -> list:
    exps = np.exp(x)
    return list(exps / np.sum(exps))

if __name__ == '__main__':
    input = [-2, -1, 0, 1, 2]
    output = softmax(input)
    output = [float('{:.4f}'.format(i)) for i in output]
    print(f"{output}")

对比两组输入输出:

(1) input = [-0.5, -0.2, 0, 0.2, 0.5]     output = [0.1145, 0.1546, 0.1888, 0.2307, 0.3114]
(2) input = [-5,   -2,   0, 2,   5]       output = [0.0, 0.0009, 0.0064, 0.0471, 0.9456]

可以明显看到, x 的数值分布越不均匀,则 S o f t m a x ( x ) Softmax(x) Softmax(x) 的两极化越明显 在上面第二个 input 中, -5 对应的输出已经非常接近0,而 5 对应的输出已经接近 0.95

Softmax 可以使数值较大的值获得更大的概率

(2)再看看 PyTorch 中的 Softmax 函数:

import torch
import torch.nn as nn

input = torch.Tensor([-0.5, -0.2, 0, 0.2, 0.5])
softmax = nn.Softmax(dim=0)
output = softmax(input)
print(output)    # tensor([0.1145, 0.1546, 0.1888, 0.2307, 0.3114])

可以看到 PyTorch 的计算结果与我们自己用 numpy 算的是一致的。

2.2 数学原理分析

从数学原理上分析,是因为当 x 的数值分布越不均匀时, e m a x ( x i ) e^{max(x_i)} emax(xi) ∑ i = 1 n x i \displaystyle \sum^{n}_{i=1}{x_i} i=1nxi 非常接近,导致 S o f t m a x ( m a x ( x i ) ) → 1 Softmax(max(x_i)) \rightarrow 1 Softmax(max(xi))1 ,而 S o f t m a x ( m i n ( x i ) ) → 0 Softmax(min(x_i)) \rightarrow 0 Softmax(min(xi))0

3. 解决 Softmax 的数值溢出问题

3.1 什么是数值溢出?

数值溢出是 Softmax 函数经常遇到的问题,数值溢出包括数值上溢和下溢两张情况:

(1)上溢:数值较大的数据经过一些运算后其数值非常大,以至于超过计算机的存储范围而无法继续运算,在程序中表现为 NAN

(2)下溢:非常接近0 的数据被四舍五入为 0,从而产生毁灭性的舌入误差。

3.2 解决数值上溢问题: x i − m a x ( x ) x_i-max(x) ximax(x)

由于 Softmax 中存在 e 的幂次,这将很容易导致数值溢出问题:

(1)当 x i → − ∞ x_i \rightarrow -\infty xi时, S o f t m a x ( x ) Softmax(x) Softmax(x) 的分母将接近 0,导致 S o f t m a x ( x ) → 0 Softmax(x) \rightarrow 0 Softmax(x)0,会出现数值下溢问题。

(2)当 x i → + ∞ x_i \rightarrow +\infty xi+时, S o f t m a x ( x ) Softmax(x) Softmax(x) 的分子和分母都接近正无穷大,导致 S o f t m a x ( x ) Softmax(x) Softmax(x) 的结果是未定的。

依然首先通过代码来说明,可以看到:当输入数值较小时,Softmax 的输出为 0;而当输入数值较大时,Softmax 的输出为 nan

import numpy as np

# 计算向量 x 的 softmax
def softmax(x: list) -> list:
    exps = np.exp(x)
    return list(exps / np.sum(exps))

if __name__ == '__main__':
    input = [-1000, -200, 0, 200, 1000]
    output = softmax(input)
    print(f"{output}")     # [0.0, 0.0, 0.0, 0.0, nan]

上述两个问题可以通过公式 (2) 同时解决:

S o f t m a x ( x i ) = S o f t m a x ( x i − m a x ( x ) ) = e x i − m a x ( x ) ∑ i = 1 n e x i − m a x ( x ) (2) Softmax(x_i)=Softmax(x_i-max(x))=\frac{e^{x_i-max(x)}}{\displaystyle \sum^{n}_{i=1}{e^{x_i-max(x)}}} \tag{2} Softmax(xi)=Softmax(ximax(x))=i=1neximax(x)eximax(x)(2)

简单推导一下就知道, S o f t m a x ( x i ) = S o f t m a x ( x i − m a x ( x ) ) Softmax(x_i)=Softmax(x_i-max(x)) Softmax(xi)=Softmax(ximax(x)) 是成立的;因为 Softmax 的函数值不会因为输入向量减去或加上一个标量而改变(标量在分子和分母中会抵消)。

x i x_i xi 减去 m a x ( x ) max(x) max(x) 使得 exp 指数的最大参数 x i − m a x ( x ) x_i-max(x) ximax(x) 为 0 ,这避免了数值上溢的可能。同时,分母中有一项是固定的 e m a x ( x ) − m a x ( x ) = 1 e^{max(x)-max(x)} =1 emax(x)max(x)=1,这保证分母不会为 0 ,避免出现分母奇异的情况。但公式 (2) 并不能避免分子为 0 从而导致数值下溢的情况。

3.3 解决数值下溢问题:log_softmax

使用 x i − m a x ( x ) x_i-max(x) ximax(x) 可以避免数值上溢,但不能完全解决数值下溢的问题。log_softmax 正是为了解决 softmax 中的数值下溢的情况;对公式(2)取对数得到 log_softmax 的表达式:

l o g [ S o f t m a x ( x i ) ] = l o g e x i − m a x ( x ) ∑ i = 1 n e x i − m a x ( x ) = x i − m a x ( x ) − l o g ( ∑ i = 1 n e x i − m a x ( x ) ) (3) log[Softmax(x_i)]=log \frac{e^{x_i-max(x)}}{\displaystyle \sum^{n}_{i=1}{e^{x_i-max(x)}}}=x_i-max(x)-log(\displaystyle \sum^{n}_{i=1}{e^{x_i-max(x)}}) \tag{3} log[Softmax(xi)]=logi=1neximax(x)eximax(x)=ximax(x)log(i=1neximax(x))(3)

l o g [ S o f t m a x ( x i ) ] log[Softmax(x_i)] log[Softmax(xi)] 中都是常数项,因此不会出现数值溢出问题。

4. PyTorch 中 CrossEntropyLoss 与 Softmax 的关系

PyTorch 中 CrossEntropyLoss 的接口是 torch.nn.CrossEntropyLoss()

先说结论1:

nn.CrossEntropyLoss() 中已经集成了 Softmax,因此如果使用nn.CrossEntropyLoss() 作为损失函数,则网络的最后一层不需要也不能加 Softmax 层

nn.CrossEntropyLoss() 的官方介绍为 torch.nn.CrossEntropyLoss(),其计算公式为:
在这里插入图片描述
再说结论2:

nn.CrossEntropyLoss 是 nn.LogSoftmax 和 nn.NLLLoss 的组合

nn.LogSoftmax 就是 3.3 中讲的 log_softmax,nn.NLLLoss 其实就是先求和,再取负数。所以先做 LogSoftmax 再做 NLLLoss 其实就等价于直接做CrossEntropyLoss

使用 PyTorch 验证一下:

import torch
import torch.nn as nn

# 输入数据和 label
input = torch.Tensor([[-0.5, -0.2, 0, 0.2, 0.5]])
target = torch.tensor([0.35]).long()

log_softmax = nn.LogSoftmax(dim=1)
CEL = nn.CrossEntropyLoss()
NLL = nn.NLLLoss()

# CrossEntropyLoss
output_CEL = CEL(input, target)
print(f"output_CEL = {output_CEL}")

# LogSoftmax + NLLLoss
logSM_input = log_softmax(input)
output_NLL = NLL(logSM_input, target)
print(f"output_NLL = {output_NLL}")


"""
output_CEL = 2.1668357849121094
output_NLL = 2.1668357849121094
"""

可以看到,nn.CrossEntropyLoss 的计算结果与 nn.LogSoftmax + nn.NLLLoss 的组合计算结果完全相同。

本节参考资料:

Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)

Pytorch 中使用nn.CrossEntropyLoss的注意点(不需要额外的softmax)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/646088.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python篇——数据结构与算法(第六部分:哈希表)

目录 1、直接寻址表 2、直接寻址表缺点 3、哈希 4、哈希表 5、解决哈希冲突 6、拉链法 7、常见哈希函数 8、哈希表的实现 8.1迭代器iter()和__iter__ 8.2str()和repr() 8.3代码实现哈希表 8.4哈…

【数据库】Mysql数据库管理

文章目录 引言一、Mysql数据库管理1. 库和表2. 常用的数据类型3. char和varchar区别 二、SQL语句1. SQL语句分类2. 查看数据库结构3. DDL数据定义语言3.1 创建新的数据库3.2 创建新的表3.3 删除指定数据表3.4 删除指定数据库 4. DML数据操控语言4.1 向数据表中插入新的内容4.2 …

连以太网接口和串口傻傻分不清?看完本文就懂了

概要 路由器是一种网络设备,它的主要功能是在不同的网络之间转发数据包,实现网络互联。路由器根据数据包的目的地址,选择最佳的路径,将数据包发送到下一跳。路由器可以连接不同的网络类型,如以太网、帧中继、PPP等。 …

ChatGPT读PDF、生成思维导图的几种方案

大家好,我是可夫小子,《小白玩转ChatGPT》专栏作者,关注AIGC、读书和自媒体。 日常办公,我们离不开pdf文档读取,思维导图制作,那么ChatGPT能够给我们什么帮助呢? 通常的方法是:我们…

14、Nginx---缓存服务

一、缓存类型 1、服务器端缓存 2、代理缓存 3、客户端缓存 代理缓存的原理: 二、代理缓存配置语法 2.1、代理缓存路径 proxy_cache_path path [levelslevels] [use_temp_pathon|off] keys_zonename:size [inactivetime] [max_sizesize] [manager_filesnumber] [mana…

如何让你的allure报告测试步骤更清晰,更具吸引力?

引言 在软件测试中,清晰的测试步骤对于团队的协作和问题跟踪至关重要,Allure报告是一种强大的工具,能够将测试结果以直观和易于理解的方式呈现给您的团队和客户。 想要让Allure报告更具吸引力和可读性吗?那就不要错过我的精彩建…

MIT6.024学习笔记(三)——图论(2)

科学是使人变得勇敢的最好途径。——布鲁诺 文章目录 通信网络问题二叉树型直径路由器规模路由器数量拥挤程度 二维数组型直径路由器规模路由器数量拥挤程度 蝴蝶型直径路由器规模路由器数量拥挤程度 benes型直径路由器规模路由器数量拥挤 通信网络问题 在通信网络中&#xff…

Redis基础知识(安装基础指令等)

Redis 基础知识 相关资料 官网: https://redis.io/中文地址: http://redis.cn/下载地址: https://redis.io/download 为什么需要Redis 企业需求 高并发 高可用 高性能 海量用户 关系型数据库(如MySQL)-问题 性能瓶颈:磁盘IO 性能低下 扩展瓶颈&#xff1a…

Java基础小项目——【源码】控制台的类似BOSS招聘的一个应聘者用户和公司用户的就业项目【应聘+招聘】

目录 引出题目要求--云就业平台相关的java基础知识项目分层设计 核心业务图解源码总结 引出 类似BOSS招聘的一个应聘者用户和公司用户的就业项目,控制台项目 题目要求–云就业平台 类似BOSS招聘的一个应聘者用户和公司用户的就业项目 第3章 应用系统功能介绍 3…

【Jetpack】使用 Room Migration 升级数据库并导出 Schema 文件 ( Schema 文件简介 | 生成 Schema 文件配置 | 生成 Schema 文件过程 )

文章目录 一、Schema 文件简介二、生成 Schema 文件配置三、生成 Schema 文件过程1、数据库版本 1 - 首次运行应用2、数据库版本 1 升级至 数据库版本 2 - 第二次运行应用3、数据库版本 2 升级至 数据库版本 3 - 第三次运行应用 一、Schema 文件简介 使用 Room Migration 升级数…

Windows Subsystem for Android (WSA) 下载:在 Windows 11 上运行 Android 应用 (June 2023)

适用于 Android™️ 的 Windows 子系统,2023 年 6 月更新 请访问原文链接:https://sysin.org/blog/wsa/,查看最新版。原创作品,转载请保留出处。 作者主页:sysin.org 适用于 Android™️ 的 Windows 子系统使你的 Wi…

Linux系统之ifconfig命令的基本使用

Linux系统之ifconfig命令的基本使用 一、ifconfig命令介绍1. ifconfig简介2. ifconfig注意事项3. ifconfig命令特点 二、ifconfig命令的使用方法1. 查看ifconfig的帮助信息2. ifconfig的使用帮助 三、安装ifconfig命令工具1. 安装net-tools软件包2. 查看ifconfig工具的版本 四、…

至暗时刻,显卡销量腰斩,NVIDIA提前掏出2000元档4060救场

不知道大家有没有感觉,自从 RTX 40 系显卡面世后,玩家们对于装机热情却是反常理的不增反降。 以往每代新显卡出来,哪次不是掀起一阵装机热潮。 然而这次小忆听到最多的声音就是:手里 750Ti 还能再战、GTX 1060 永远滴神等。 当然…

pandas链式操作与SettingWithCopyWarning详解

1.SettingWithCopyWarning问题 SettingWithCopyWarning是pandas中一个经典问题,也是pandas库中位数不多的坑之一。关于这个问题,我们先看下面的一个例子。 import pandas as pddef t1():data {name: [a, b, c, d, e, f],num: [1, 2, 3, 4, 5, 6],ss: …

Linux系统命令与网络、磁盘参数和日志监控

文章目录 1、grep搜索命令2、wc命令3、 uptime机器启动时间负载4、ulimit用户资源5、scp远程拷贝6、dos2unix和unix2dos7、sed 1、grep搜索命令 grep命令用于在文件中搜索,并显示匹配效果 # 1、在指定文件查找,查找int main grep int main server.c# 2…

接口自动化测试丨如何处理 Header cookie

Cookie(复数形态:Cookies)是某些网站为了辨别用户身份而储存在用户本地终端上的数据。在接口测试过程中,如果网站采取了 Cookie 认证的方式,那么发送的请求需要附带 Cookie,才会得到正常的响应的结果。接口…

C#中List<T>的排序相关的使用方法总结

C#中List<>的排序相关的使用方法 list的排序一般使用Sort和LINQ的Orderby方法&#xff0c;本文主要介绍其如何使用。 &#x1f32e;1.Sort和实现Comparable接口 此方式需要类去实现IComparable接口 public class OrderTest {[Test]public void OraderTest(){List<E…

【Excel】如何给Excel表格(文档)加密/上锁

目录 0.环境 1.操作步骤 若只输入了【打开权限密码】 若只输入了【修改权限密码】 若同时输入两种密码 0.环境 windows excel2021版 适用情景&#xff1a; 希望给别人提供文档时&#xff0c;需要用密码打开这个文档&#xff0c;加密又分为【打开时加密】和【修改时加密…

【sentinel】Sentinel规则的持久化

Sentinel规则的推送有下面三种模式: 推送模式说明优点缺点原始模式API将规则推送至客户端并直接更新到内存中简单&#xff0c;无任何依赖不保证一致性&#xff1b;规则保存在内存中&#xff0c;重启即消失。严重不建议用于生产环境Pull模式扩展写数据源&#xff08;WritableDa…

【Flutter】MAC环境下打包APK

1、打开终端生成签名文件 keytool -genkey -v -keystore ~/sign.jks -keyalg RSA -keysize 2048 -validity 10000 -alias sign 结果 输入**库口令: 再次输入新口令: 您的名字与姓氏是什么?[Unknown]: yuanzhiying 您的组织单位名称是什么?[Unknown]: gongsi 您的组织…