深入理解Softmax:从“Hard”到“Soft”的转变

news2024/11/24 1:53:13

深入理解Softmax:从“Hard”到“Soft”的转变

在机器学习的分类任务中,Softmax 函数是一个极其重要的工具。它不仅将神经网络的输出转化为概率分布,还能有效处理多分类问题。然而,为了更好地理解Softmax,我们可以先将其拆解为 “soft” 和 “max” 两个部分,并探讨它们各自的意义。

从“Hard”到“Soft”

在某些情况下,我们可能会考虑直接选择输出层中的最大值作为预测结果。这种方式可以被称为 “Hard” 选择,即直接在所有输出中选择最大的那个,忽略其他所有信息。举个例子,假设我们有一个输出向量 ([0.2, 0.3, 0.5]),在这种 Hard 选择方式中,我们会直接选择最大值 (0.5) 对应的类别作为最终的预测结果。

在代码实现上,这种 Hard 选择非常简单:

import numpy as np

# 示例数据
outputs = np.array([0.2, 0.3, 0.5])

# Hard max选择
predicted_class = np.argmax(outputs)
print(predicted_class)  # 输出:2,对应0.5

然而,这种 Hard 方式在实际应用中往往不够合理。原因是,很多情况下输出层的多个值可能非常接近,这样直接选最大值忽略了其他可能的选项。例如,在文本分类中,一个文档可能同时包含多个主题,这时直接选最大值会导致潜在的有意义信息丢失。

为了更好地反映各个类别的可能性,我们引入了 Softmax 函数。与 Hard 选择不同,Softmax 不仅关注最大值,还能衡量其他类别的可能性。它通过将输出层的每个值转换为一个概率分布,给出了每个类别的置信度。

Softmax 的数学原理与实现

Softmax 函数的核心是使用 指数函数,这在数学上可以表示为:

Softmax ( z i ) = e z i ∑ j e z j \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j} e^{z_j}} Softmax(zi)=jezjezi

其中, z i z_i zi 是输出层的第 i i i 个神经元的激活值。通过指数函数的作用,Softmax 会放大较大的值,缩小较小的值,这样就拉开了不同类别之间的差距。
在这里插入图片描述

我们可以通过一个简单的例子来看一下:

假设我们有一个输出向量 [ 1.0 , 2.0 , 3.0 ] [1.0,2.0,3.0] [1.0,2.0,3.0],不使用指数函数和使用指数函数的结果差别如下:

import numpy as np

# 示例数据
outputs = np.array([1.0, 2.0, 3.0])

# 计算 Softmax
exp_outputs = np.exp(outputs)
softmax_outputs = exp_outputs / np.sum(exp_outputs)

print(softmax_outputs)
# 输出:[0.09003057 0.24472847 0.66524096]

我们可以看到,经过 Softmax 变换后,最大的值变得更加显著,而较小的值被进一步减弱。

Softmax 的数值稳定性问题

虽然 Softmax 函数的设计非常巧妙,但它也存在数值稳定性的问题。当输入的值非常大时,指数函数的计算可能会导致数值溢出,使得计算结果不准确。

因此,在实际使用中,我们通常会先找到输出向量中的最大值,然后将每个值减去这个最大值,再应用 Softmax 函数。这种方式不仅可以避免溢出问题,还能提高计算的精度。

# 数值稳定的Softmax实现
outputs = np.array([1000, 1001, 1002])  # 示例数据

# 减去最大值
shifted_outputs = outputs - np.max(outputs)

exp_outputs = np.exp(shifted_outputs)
softmax_outputs = exp_outputs / np.sum(exp_outputs)

print(softmax_outputs)
# 输出:[0.09003057 0.24472847 0.66524096]

为什么 Softmax 使用指数函数?

你可能会问,为什么在 Softmax 中使用的是指数函数而不是其他的函数?这一点非常关键。指数函数具有以下几个特点,使其在 Softmax 中尤其有效:

  1. 非负性:指数函数的值永远是非负的,确保了输出的概率不会出现负数。
  2. 快速增长:指数函数随着输入的增大而迅速增长,这意味着较大的输出值会更突出,有利于区分不同的类别。
  3. 比例缩放:由于指数函数的快速增长特性,当一个输出值远大于其他值时,Softmax 函数会接近于选择这个值对应的类别,从而类似于 Hard Max 的效果,但仍保留了一些对其他选项的考虑。

Softmax 的数值计算效率

在实际应用中,Softmax 函数的计算可以借助向量化操作进一步提高效率。大多数深度学习框架(如TensorFlow、PyTorch)都对Softmax进行了高度优化,使得在大规模数据和复杂网络结构中,它依然可以高效运行。

以下是使用PyTorch实现的一个高效Softmax示例:

import torch
import torch.nn.functional as F

# 示例数据
outputs = torch.tensor([0.2, 0.3, 0.5])

# 使用PyTorch的Softmax函数
softmax_outputs = F.softmax(outputs, dim=0)

print(softmax_outputs)
# 输出:tensor([0.3006, 0.3322, 0.3672])

Softmax与交叉熵损失的关系

当我们使用 Softmax 作为输出层的激活函数时,通常会配合使用 交叉熵损失函数 来优化模型。交叉熵损失函数可以通过衡量预测分布与真实分布之间的差异,指导模型的学习过程。

交叉熵损失的数学表达式为:

Cross-Entropy Loss = − ∑ i y i log ⁡ ( y ^ i ) \text{Cross-Entropy Loss} = -\sum_{i} y_i \log(\hat{y}_i) Cross-Entropy Loss=iyilog(y^i)

其中, y i y_i yi 是真实类别的概率分布(通常是one-hot编码), y ^ i \hat{y}_i y^i 是预测的概率分布。

这种组合的妙处在于,Softmax 的输出是一个概率分布,而 交叉熵损失 正是用来衡量两个概率分布之间的差异。这种配合使得梯度下降过程中的导数计算非常简单且高效,因为在这种组合下,导数的结果与Softmax输出非常相关,这使得反向传播的计算更为直接。

交叉熵的意义

交叉熵 H ( P , Q ) H(P, Q) H(P,Q) 表示了我们使用分布 Q Q Q(模型的预测)去编码来自分布 P P P(真实分布)的数据时,所需的平均编码长度。这个平均长度由两部分组成:

  1. P P P 的熵:这是任何模型都无法避免的最低编码长度。
  2. 相对熵(Kullback-Leibler divergence,简称 KL 散度):这是由于 Q Q Q P P P 之间的差异导致的额外编码长度。

交叉熵可以被分解为:

H ( P , Q ) = H ( P ) + D K L ( P ∥ Q ) H(P, Q) = H(P) + D_{KL}(P \parallel Q) H(P,Q)=H(P)+DKL(PQ)

其中, D K L ( P ∥ Q ) D_{KL}(P \parallel Q) DKL(PQ) 是 KL 散度,衡量的是 Q Q Q P P P 的差异程度。理想情况下,我们希望模型的预测 Q Q Q 完全与真实分布 P P P 一致,这样 KL 散度为零,交叉熵等于 H ( P ) H(P) H(P)

为什么交叉熵损失可以衡量两个概率分布的差异?

在神经网络中,我们通过最小化交叉熵损失函数来训练模型,即希望模型的预测分布 Q Q Q 越来越接近真实的分布 P P P。具体来说,交叉熵损失的最小化过程可以理解为在减少模型的预测分布 Q Q Q 与真实分布 P P P 之间的 KL 散度。

通过最小化交叉熵损失,模型被迫调整参数,使得 Q Q Q 更接近 P P P,即 y ^ i \hat{y}_i y^i 尽量接近 y i y_i yi 。因此,交叉熵损失不仅提供了一个衡量模型预测与真实分布差异的度量,而且在梯度下降过程中,直接引导模型缩小这种差异。

直观解释

考虑分类问题中的一个简单例子,假设真实分布 P P P 为一个 one-hot 编码的向量,表示某个样本真实属于某个类别,即 P = [ 0 , 1 , 0 ] P = [0, 1, 0] P=[0,1,0] 。如果模型的预测分布 Q Q Q 也是准确的 one-hot 编码 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0],那么交叉熵损失就为零,表示没有差异。

但如果模型预测的是 [ 0.3 , 0.4 , 0.3 ] [0.3, 0.4, 0.3] [0.3,0.4,0.3],那么交叉熵损失就会很大,表示模型的预测远离了真实分布。通过最小化交叉熵损失,模型参数被调整,使得 Q Q Q 更加接近 P P P,即减少 D K L ( P ∥ Q ) D_{KL}(P \parallel Q) DKL(PQ) 的值。

Softmax 在实际应用中的局限性

虽然 Softmax 是多分类问题中广泛使用的工具,但它也存在一些局限性。例如,在类别数量非常大的情况下(如在推荐系统或词汇表非常大的 NLP 模型中),Softmax 的计算成本可能会非常高。这时,通常会使用近似方法,如 层次化Softmax(Hierarchical Softmax)或 负采样(Negative Sampling)来降低计算复杂度。

此外,Softmax 在处理类别不平衡问题时也可能表现不佳。如果某些类别的数据量明显少于其他类别,模型可能会倾向于预测常见类别。这时,可能需要对损失函数进行加权处理,或者采用其他改进措施,如 Focal Loss

总结

通过将“Hard”选择转变为“Soft”选择,Softmax 函数不仅为分类任务提供了概率分布,还通过指数函数的巧妙使用,放大了类别之间的差距。然而,在实际应用中,我们需要注意数值稳定性问题,通过适当的变换避免计算溢出。此外,了解 Softmax 的局限性和改进方法,能够帮助我们在更多复杂场景下更好地应用这一函数。最后,当与交叉熵损失函数配合使用时,Softmax 函数展现出了极高的效率和有效性,使得它成为神经网络分类任务中的标准工具。

这篇博客希望帮助你深入理解Softmax的原理与实现,并更好地应用于实际的深度学习任务中。

参考链接:

  • 一文详解Softmax函数 - 知乎
  • [L4]使用LSTM实现语言模型-softmax与交叉熵 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2061970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

漫画小程序源码全开源商业版

介绍: 漫画小程序源码全开源商业版 带漫画资源,带简单安装说明,可以快速发布一个漫画小程序。 代码下载

秋招力扣Hot100刷题总结——链表

1. 反转链表题目连接 题目要求:给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 代码及思路 遍历所有节点,将所有节点的next指向前一个节点由于要改变节点的next指向,而链表是单向的,因此需要…

Spring MVC域对象共享数据

在Spring MVC中,域对象(Domain Object)通常指的是与业务逻辑相关的模型对象,它们代表了应用程序中的核心数据结构。例如,在一个电商应用中,Product、User、Order等类可以被视为域对象。这些对象通常与数据库…

Pod基础使用

POD基本操作 1.Pod生命周期 在Kubernetes中,Pod的生命周期经历了几个重要的阶段。下面是Pod生命周期的详细介绍: Pending(待处理): 调度: Pod被创建后,首先进入“Pending”状态。此时,Kubernetes的调度器…

设计模式24-命令模式

设计模式24-命令模式 写在前面行为变化模式 命令模式的动机定义与结构定义结构 C 代码推导优缺点应用场景总结补充函数对象(Functors)定义具体例子示例:使用函数对象进行自定义排序代码说明输出结果具体应用 优缺点应用场景 命令模式&#xf…

查看 CUDA 和 cuDNN 版本

在安装 onnxruntime-gpu 选择版本时需要查看本机 CUDA 和 cuDNN 版本。 查看 CUDA 和 cuDNN 版本 import platform import torchprint("python.version:", platform.python_version()) print("torch.version:", torch.__version__) print("CUDA.vers…

汽车管理 API 接口:开启高效车辆运营新时代

API(Application Programming Interface)是一种接口,用于不同软件之间的通信。在汽车管理领域,API的应用可以帮助提升车辆运营的效率,让车主和车辆管理者更方便地获取车辆相关信息,进行保养和维修等工作。本…

fastadmin api中无法获取用户信息

控制器使用_initialize方法时,要增加 parent::_initialize(); 这行代码,否则会出现获取不到用户信息的问题: public function _initialize() {// 你的逻辑内容// ...// endparent::_initialize(); }

Chapter 01 Vue入门

前言 Vue 是一个框架,也是一个生态,其功能覆盖了大部分前端开发常见的需求。本文详细讲解了 Vue 的基本概念以及 Vue 开发者工具的安装。 一、Vue简介 ①定义 Vue 是一款用于构建用户界面的渐进式框架。它基于标准 HTML、CSS 和 JavaScript 构建&…

基于RDMA技术的Mayastor解决方案

1. 方案背景和挑战 1.1. Mayastor简介 OpenEBS是一个广受欢迎的开源云原生存储解决方案,托管于CNCF(云原生计算基金会)之下,旨在通过扩展Kubernetes的能力,为有状态应用提供灵活的持久性存储。Mayastor是OpenEBS项目…

maxscale

入门 官网:https://mariadb.com/kb/en/maxscale/ 开发语言:C 是否支持分片:不支持 支持的数据库:MySQL/Mariadb 路由规则:事务包裹的SQL会全部走写库、没有事务包裹SQL读写库通过设置Hint实现。其它功能通过配置文件实…

微服务通信

1、Feign远程调用 Feign是Spring Cloud提供的⼀个声明式的伪Http客户端, 它使得调⽤远程服务就像调⽤本地服务⼀样简单, 只需要创建⼀个接⼝并添加⼀个注解即可。 Nacos很好的兼容了Feign, Feign 默认集为Ribbon, 所以在Nacos下使…

M8020A J-BERT 高性能比特误码率测试仪

M8020A 比特误码率测试仪 J-BERT M8020A 高性能 BERT 产品综述 Keysight J-BERT M8020A 高性能比特误码率测试仪能够快速、准确地表征传输速率高达 16 或 32 Gb/s 的单通道和多通道器件中的接收机。 M8020A 综合了更广泛的功能,可以简化您的测试系统。 自动对信…

AGV导航方法大盘点:3大类,12小类

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 在自动化物流领域,自动导引车(AGV)扮演着至关重要的角色。它们不仅能够提高搬运效率,还能在各种环境中准确无误地完成任务。 而这一切的…

KVM虚拟化之命令行界面创建KVM虚拟机

环境:CentOS8 安装所需软件包 yum groupinstall -y "Virtualization*" 上传一个ISO镜像 使用指令创建KVM虚拟机 给KVM虚拟机创建一个磁盘 -f:指定磁盘类型为qcow2 使用指令创建一个虚拟机 virt-install \ --nameCentos-2 \ --vcpu 1 \ --memory 2048 \ -…

【SpringCloud】(一文通)服务注册/服务发现-Eureka

目 录 一. 背景1.1 问题描述1.2 解决思路1.3 什么是注册中心1.4 CAP理论1.5 常见的注册中心 二. Eureka 介绍三. 搭建Eureka Server3.1 创建 Eureka-server 子模块3.2 引入 eureka-server 依赖3.3 项目构建插件3.4 完善启动类3.5 编写配置文件3.6 启动服务 四. 服务注册4.1 引入…

Docker基础概述、Docker安装、Docker镜像加速、Docker镜像指令

1.为什么学docker 开发环境与测试环境不同,导致错误 因此docker提供解决方法———系统平滑移植,容器虚拟化技术 将代码与软件与配置文件 打包成一个镜像 2.docker的历练 创建一个开发环境内成为镜像文件再用docker使用镜像 3.什么是docker Docke…

泛型篇(Java - 泛型机制)(持续更新迭代)

目录 私聊 一、什么是泛型,泛型有什么用,为什么要用 1. 说法一 1.1 什么是泛型 1.2 泛型的使用 1.3 为什么要用泛型 2. 说法二 2.1 什么是泛型,泛型有什么用,为什么要用 2.2 怎么使用泛型,泛型可以作用在什么…

私有方法加事务注解会导致事务失效

这里idea其实已经提醒了使用事务不能用私有方法,这其实是个常见问题,这里主要就加深印象

XSS复现

目录 XSS简单介绍 一、反射型 1、漏洞逻辑: 为什么有些标签可以触发,有些标签不能触发 可以触发的标签 不能触发的标签 为什么某些标签能触发而某些不能 二、DOM型 1、Ma Spaghet! 要求: 分析: 结果: 2、J…