深入理解Softmax函数及其在PyTorch中的实现

news2025/4/13 3:28:32

Softmax函数简介

Softmax函数在机器学习和深度学习中,被广泛用于多分类问题的输出层。它将一个实数向量转换为概率分布,使得每个元素介于0和1之间,且所有元素之和为1。

Softmax函数的定义

给定一个长度为 K K K的输入向量 z = [ z 1 , z 2 , … , z K ] \boldsymbol{z} = [z_1, z_2, \dots, z_K] z=[z1,z2,,zK],Softmax函数 σ ( z ) \sigma(\boldsymbol{z}) σ(z)定义为:

σ ( z ) i = e z i ∑ j = 1 K e z j , 对于所有  i = 1 , 2 , … , K \sigma(\boldsymbol{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \quad \text{对于所有 } i = 1, 2, \dots, K σ(z)i=j=1Kezjezi,对于所有 i=1,2,,K

其中:

  • e e e是自然对数的底数,约为2.71828。
  • σ ( z ) i \sigma(\boldsymbol{z})_i σ(z)i是输入向量第 i i i个分量对应的Softmax输出。

Softmax函数的特点

  1. 将输出转换为概率分布:Softmax的输出向量中的每个元素都在 ( 0 , 1 ) (0, 1) (0,1)之间,并且所有元素的和为1,这使得输出可以视为各类别的概率。

  2. 强调较大的值:Softmax函数会放大输入向量中较大的元素对应的概率,同时压缩较小的元素对应的概率。这种特性有助于突出模型认为更有可能的类别。

  3. 可微性:Softmax函数是可微的,这对于基于梯度的优化算法(如反向传播)非常重要。


数值稳定性的问题

在实际计算中,为了防止指数函数计算过程中可能出现的数值溢出,通常会对输入向量进行调整。常见的做法是在计算Softmax之前,从输入向量的每个元素中减去向量的最大值:

σ ( z ) i = e z i − z max ∑ j = 1 K e z j − z max \sigma(\boldsymbol{z})_i = \frac{e^{z_i - z_{\text{max}}}}{\sum_{j=1}^{K} e^{z_j - z_{\text{max}}}} σ(z)i=j=1Kezjzmaxezizmax

其中, z max = max ⁡ { z 1 , z 2 , … , z K } z_{\text{max}} = \max\{z_1, z_2, \dots, z_K\} zmax=max{z1,z2,,zK}。这种调整不会改变Softmax的输出结果,但能提高计算的数值稳定性。


Softmax函数的应用场景

  1. 多分类问题:在神经网络的最后一层,Softmax函数常用于将模型的线性输出转换为概率分布,以进行多分类预测。

  2. 注意力机制:在深度学习中的注意力模型中,Softmax用于计算注意力权重,以突显重要的输入特征。

  3. 语言模型:在自然语言处理任务中,Softmax函数用于预测下一个词的概率分布。


Softmax函数的示例计算

假设有一个三类别分类问题,神经网络的输出为一个长度为3的向量:

z = [ z 1 , z 2 , z 3 ] = [ 2.0 , 1.0 , 0.1 ] \boldsymbol{z} = [z_1, z_2, z_3] = [2.0, 1.0, 0.1] z=[z1,z2,z3]=[2.0,1.0,0.1]

我们想使用Softmax函数将其转换为概率分布。

步骤1:计算每个元素的指数

e z 1 = e 2.0 = 7.3891 e z 2 = e 1.0 = 2.7183 e z 3 = e 0.1 = 1.1052 \begin{align*} e^{z_1} &= e^{2.0} = 7.3891 \\ e^{z_2} &= e^{1.0} = 2.7183 \\ e^{z_3} &= e^{0.1} = 1.1052 \end{align*} ez1ez2ez3=e2.0=7.3891=e1.0=2.7183=e0.1=1.1052

步骤2:计算指数和

sum = e z 1 + e z 2 + e z 3 = 7.3891 + 2.7183 + 1.1052 = 11.2126 \text{sum} = e^{z_1} + e^{z_2} + e^{z_3} = 7.3891 + 2.7183 + 1.1052 = 11.2126 sum=ez1+ez2+ez3=7.3891+2.7183+1.1052=11.2126

步骤3:计算Softmax输出

σ 1 = e z 1 sum = 7.3891 11.2126 = 0.6590 σ 2 = e z 2 sum = 2.7183 11.2126 = 0.2424 σ 3 = e z 3 sum = 1.1052 11.2126 = 0.0986 \begin{align*} \sigma_1 &= \frac{e^{z_1}}{\text{sum}} = \frac{7.3891}{11.2126} = 0.6590 \\ \sigma_2 &= \frac{e^{z_2}}{\text{sum}} = \frac{2.7183}{11.2126} = 0.2424 \\ \sigma_3 &= \frac{e^{z_3}}{\text{sum}} = \frac{1.1052}{11.2126} = 0.0986 \end{align*} σ1σ2σ3=sumez1=11.21267.3891=0.6590=sumez2=11.21262.7183=0.2424=sumez3=11.21261.1052=0.0986

因此,经过Softmax函数后,输出概率分布为:

σ ( z ) = [ 0.6590 , 0.2424 , 0.0986 ] \sigma(\boldsymbol{z}) = [0.6590, 0.2424, 0.0986] σ(z)=[0.6590,0.2424,0.0986]

这表示模型预测第一个类别的概率约为65.9%,第二个类别约为24.24%,第三个类别约为9.86%。


使用PyTorch实现Softmax函数

在PyTorch中,可以通过多种方式实现Softmax函数。以下将通过示例演示如何使用torch.nn.functional.softmaxtorch.nn.Softmax

创建输入数据

首先,创建一个示例输入张量:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 创建一个输入张量,形状为 (batch_size, features)
input_tensor = torch.tensor([[2.0, 1.0, 0.1],
                             [1.0, 3.0, 0.2]])
print("输入张量:")
print(input_tensor)

输出:

输入张量:
tensor([[2.0000, 1.0000, 0.1000],
        [1.0000, 3.0000, 0.2000]])

方法一:使用torch.nn.functional.softmax

利用PyTorch中torch.nn.functional.softmax函数直接对输入数据应用Softmax。

# 在维度1上(即特征维)应用Softmax
softmax_output = F.softmax(input_tensor, dim=1)
print("\nSoftmax输出:")
print(softmax_output)

输出:

Softmax输出:
tensor([[0.6590, 0.2424, 0.0986],
        [0.1065, 0.8726, 0.0209]])

方法二:使用torch.nn.Softmax模块

也可以使用torch.nn中的Softmax模块。

# 创建一个Softmax层实例
softmax = nn.Softmax(dim=1)

# 对输入张量应用Softmax层
softmax_output_module = softmax(input_tensor)
print("\n使用nn.Softmax模块的输出:")
print(softmax_output_module)

输出:

使用nn.Softmax模块的输出:
tensor([[0.6590, 0.2424, 0.0986],
        [0.1065, 0.8726, 0.0209]])

在神经网络模型中应用Softmax

构建一个简单的神经网络模型,在最后一层使用Softmax激活函数。

class SimpleNetwork(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleNetwork, self).__init__()
        self.layer1 = nn.Linear(input_size, 5)
        self.layer2 = nn.Linear(5, num_classes)
        # 使用LogSoftmax提高数值稳定性
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = self.layer2(x)
        x = self.softmax(x)
        return x

# 定义输入大小和类别数
input_size = 3
num_classes = 3

# 创建模型实例
model = SimpleNetwork(input_size, num_classes)

# 查看模型结构
print("\n模型结构:")
print(model)

输出:

模型结构:
SimpleNetwork(
  (layer1): Linear(in_features=3, out_features=5, bias=True)
  (layer2): Linear(in_features=5, out_features=3, bias=True)
  (softmax): LogSoftmax(dim=1)
)

前向传播:

# 将输入数据转换为浮点型张量
input_data = input_tensor.float()

# 前向传播
output = model(input_data)
print("\n模型输出(对数概率):")
print(output)

输出:

模型输出(对数概率):
tensor([[-1.2443, -0.7140, -1.2645],
        [-1.3689, -0.6535, -1.5142]], grad_fn=<LogSoftmaxBackward0>)

转换为概率:

# 取指数,转换为概率
probabilities = torch.exp(output)
print("\n模型输出(概率):")
print(probabilities)

输出:

模型输出(概率):
tensor([[0.2882, 0.4898, 0.2220],
        [0.2541, 0.5204, 0.2255]], grad_fn=<ExpBackward0>)

预测类别:

# 获取每个样本概率最大的类别索引
predicted_classes = torch.argmax(probabilities, dim=1)
print("\n预测的类别:")
print(predicted_classes)

输出:

预测的类别:
tensor([1, 1])

torch.nn.functional.softmaxtorch.nn.Softmax的区别

函数式API与模块化API的设计理念

PyTorch提供了两种API:

  1. 函数式API (torch.nn.functional)

    • 特点:无状态(Stateless),不包含可学习的参数。
    • 使用方式:直接调用函数。
    • 适用场景:需要在forward方法中灵活应用各种操作。
  2. 模块化API (torch.nn.Module)

    • 特点:有状态(Stateful),可能包含可学习的参数,即使某些模块没有参数(如Softmax),但继承自nn.Module
    • 使用方式:需要先实例化,再在前向传播中调用。
    • 适用场景:构建模型时,统一管理各个层和操作。

具体到Softmax的实现

  • torch.nn.functional.softmax(函数)

    • 使用示例

      import torch.nn.functional as F
      output = F.softmax(input_tensor, dim=1)
      
    • 特点:直接调用,简洁灵活。

  • torch.nn.Softmax(模块)

    • 使用示例

      import torch.nn as nn
      softmax = nn.Softmax(dim=1)
      output = softmax(input_tensor)
      
    • 特点:作为模型的一层,便于与其他层组合,保持代码结构一致。

为什么存在两个实现?

提供两种实现方式是为了满足不同开发者的需求和编程风格。

  • 使用nn.Softmax的优势

    • 在模型定义阶段明确各层,结构清晰。
    • 便于使用nn.Sequential构建顺序模型。
    • 统一管理模型的各个部分。
  • 使用F.softmax的优势

    • 代码简洁,直接调用函数。
    • 适用于需要在forward中进行灵活操作的情况。

使用示例

使用nn.Softmax
import torch
import torch.nn as nn

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = nn.Linear(10, 5)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.layer(x)
        x = self.softmax(x)
        return x

# 实例化和使用
model = MyModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)
使用F.softmax
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = nn.Linear(10, 5)
    
    def forward(self, x):
        x = self.layer(x)
        x = F.softmax(x, dim=1)
        return x

# 实例化和使用
model = MyModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)

总结

Softmax函数在深度学习中起着关键作用,尤其在多分类任务中。PyTorch为了满足不同的开发需求,提供了torch.nn.functional.softmaxtorch.nn.Softmax两种实现方式。

  • F.softmax:函数式API,灵活简洁,适合在forward方法中直接调用。

  • nn.Softmax:模块化API,便于模型结构的统一管理,适合在模型初始化时定义各个层。

在实际开发中,选择适合你的项目和团队的方式。如果更喜欢模块化的代码结构,使用nn.Softmax;如果追求简洁和灵活,使用F.softmax。同时,要注意数值稳定性的问题,尤其是在计算损失函数时,建议使用nn.LogSoftmaxnn.NLLLoss结合使用。


参考文献

  • PyTorch官方文档 - Softmax函数
  • PyTorch官方文档 - nn.Softmax
  • PyTorch官方教程 - 构建神经网络
  • PyTorch论坛 - Softmax激活函数:nn.Softmax vs F.softmax

希望本文能帮助读者深入理解Softmax函数及其在PyTorch中的实现和应用。如有任何疑问,欢迎交流讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2333527.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java导入excel更新设备经纬度度数或者度分秒

文章目录 一、背景介绍二、页面效果三、代码0.pom.xml1.ImportDevice.vue2.ImportDeviceError.vue3.system.js4.DeviceManageControl5.DeviceManageUserControl6.Repeater7.FileUtils8.ResponseModel9.EnumLongitudeLatitude10.词条 四、注意点本人其他相关文章链接 一、背景介…

视频设备轨迹回放平台EasyCVR远程监控体系落地筑牢国土监管防线

一、背景概述 我国土地资源遭违法滥用的现象愈发严峻&#xff0c;各类土地不合理利用问题频发。不当的土地开发不仅加剧了地质危害风险&#xff0c;导致良田受损、森林资源的滥伐&#xff0c;还引发了煤矿无序开采、城市开发区违建等乱象&#xff0c;给国家宝贵的土地资源造成…

Stable Diffusion 四重调参优化——项目学习记录

学习记录还原&#xff1a;在本次实验中&#xff0c;我基于 Stable Diffusion v1.5模型&#xff0c;通过一系列优化方法提升生成图像的质量&#xff0c;最终实现了图像质量的显著提升。实验从基础的 Img2Img 技术入手&#xff0c;逐步推进到参数微调、DreamShaper 模型和 Contro…

我可能用到的网站和软件

我可能用到的网站和软件 程序员交流的网站代码管理工具前端组件库前端框架在线工具人工智能问答工具学习的网站Windows系统电脑的常用工具 程序员交流的网站 csdn博客博客园 - 开发者的网上家园InfoQ - 软件开发及相关领域-极客邦掘金 (juejin.cn) 代码管理工具 GitHub 有时…

FPGA状态机设计:流水灯实现、Modelsim仿真、HDLBits练习

一、状态机思想 1.概念 状态机&#xff08;Finite State Machine, FSM&#xff09;是计算机科学和工程领域中的一种抽象模型&#xff0c;用于描述系统在不同状态之间的转换逻辑。其核心思想是将复杂的行为拆解为有限的状态&#xff0c;并通过事件触发状态间的转移。 2.状态机…

2024年第十五届蓝桥杯CC++大学A组--成绩统计

2024年第十五届蓝桥杯C&C大学A组--成绩统计 题目&#xff1a; 动态规划&#xff0c; 对于该题&#xff0c;考虑动态规划解法&#xff0c;先取前k个人的成绩计算其方差&#xff0c;并将成绩记录在数组中&#xff0c;记录当前均值&#xff0c;设小蓝已检查前i-1个人的成绩&…

Kotlin 学习-集合

/*** kotlin 集合* List:是一个有序列表&#xff0c;可通过索引&#xff08;下标&#xff09;访问元素。元素可以在list中出现多次、元素可重复* Set:是元素唯一的集合。一般来说 set中的元素顺序并不重要、无序集合* Map:&#xff08;字典&#xff09;是一组键值对。键是唯一的…

自动驾驶的未来:多模态感知融合技术最新进展

作为自动驾驶领域的专业人士&#xff0c;我很高兴与大家分享关于多模态感知融合技术的前沿研究和实践经验。在迅速发展的自动驾驶领域&#xff0c;多模态感知融合已成为提升系统性能的关键技术。本文将深入探讨基于摄像头和激光雷达的多模态感知融合技术&#xff0c;重点关注最…

亮相2025全球分布式云大会,火山引擎边缘云落地AI新场景

4 月 9 日&#xff0c;2025 全球分布式云大会暨 AI 基础设施大会在深圳成功举办&#xff0c;火山引擎边缘云产品解决方案高级总监沈建发出席并以《智启边缘&#xff0c;畅想未来&#xff1a;边缘计算新场景落地与 Al 趋势新畅想》为主题&#xff0c;分享了边缘计算在 AI 技术趋…

无损分区管理,硬盘管理的“瑞士军刀”!

打工人们你们好&#xff01;这里是摸鱼 特供版~ 今天给大家带来一款简单易用、功能强大的无损分区软件——分区助手技术员版&#xff0c;让你的硬盘管理变得轻松又高效&#xff01; 推荐指数&#xff1a;★★★★★ 软件简介 分区助手技术员版是一款功能强大的硬盘分区工具&…

VS Code下开发FPGA——FPGA开发体验提升__下

上一篇&#xff1a;IntelliJ IDEA下开发FPGA-CSDN博客 Type&#xff1a;Quartus 一、安装插件 在应用商店先安装Digtal IDE插件 安装后&#xff0c;把其他相关的Verilog插件禁用&#xff0c;避免可能的冲突。重启后&#xff0c;可能会弹出下面提示 这是插件默认要求的工具链&a…

ffmpeg播放音视频流程

文章目录 &#x1f3ac; FFmpeg 解码播放流程概览&#xff08;以音视频文件为例&#xff09;1️⃣ 创建结构体2️⃣ 打开音视频文件3️⃣ 查找解码器并打开解码器4️⃣ 循环读取数据包&#xff08;Packet&#xff09;5️⃣ 解码成帧&#xff08;Frame&#xff09;6️⃣ 播放 / …

SpringCloud微服务: 分布式架构实战

# SpringCloud微服务: 分布式架构实战 第一章&#xff1a;理解SpringCloud微服务架构 什么是SpringCloud微服务架构&#xff1f; 在当今互联网应用开发中&#xff0c;微服务架构已经成为业界的主流趋势。SpringCloud是一个基于Spring Boot的快速开发微服务架构的工具&#xff0…

AI预测3D新模型百十个定位预测+胆码预测+去和尾2025年4月11日第49弹

从今天开始&#xff0c;咱们还是暂时基于旧的模型进行预测&#xff0c;好了&#xff0c;废话不多说&#xff0c;按照老办法&#xff0c;重点8-9码定位&#xff0c;配合三胆下1或下2&#xff0c;杀1-2个和尾&#xff0c;再杀6-8个和值&#xff0c;可以做到100-300注左右。 (1)定…

【models】Transformer 之 各种 Attention 原理和实现

Transformer 之 各种 Attention 原理和实现 本文将介绍Transformer 中常见的Attention的原理和实现&#xff0c;其中包括&#xff1a; Self Attention、Spatial Attention、Temporal Attention、Cross Attention、Grouped Attention、Tensor Product Attention、FlashAttentio…

老硬件也能运行的Win11 IoT LTSC (OEM)物联网版

#记录工作 Windows 11 IoT Enterprise LTSC 2024 属于物联网相关的版本。 Windows 11 IoT Enterprise 是为物联网设备和场景设计的操作系统版本。它通常针对特定的工业控制、智能设备等物联网应用进行了优化和定制&#xff0c;以满足这些领域对稳定性、安全性和长期支持的需求…

Git开发

目录 Linux下Git安装Git基本指令分支管理远程仓库与本地仓库标签管理多人协作同一分支下不同分支下 企业级开发模型 -- git flow 模型 在现实中&#xff0c;当我们完成一个文档的初稿后&#xff0c;后面可能还需要对初稿进行反复修改&#xff0c;从而形成不同版本的文档。显然&…

verilog有符号数的乘法

无符号整数的乘法 1、单周期乘法器&#xff08; 无符号整数 &#xff09; 对于低速要求的乘法器&#xff0c;可以简单的使用 * 实现。 module Mult(input wire [7:0] multiplicand ,input wire [7:0] multipliter ,output wire [7:0] product);as…

DevDocs:抓取并整理技术文档的MCP服务

GitHub&#xff1a;https://github.com/cyberagiinc/DevDocs 更多AI开源软件&#xff1a;发现分享好用的AI工具、AI开源软件、AI模型、AI变现 - 小众AI DevDocs 是一个完全免费的开源工具&#xff0c;由 CyberAGI 团队开发&#xff0c;托管在 GitHub 上。它专为程序员和软件开发…

第十四届蓝桥杯大赛软件赛国赛Python大学B组题解

文章目录 弹珠堆放划分偶串交易账本背包问题翻转最大阶梯最长回文前后缀贸易航线困局 弹珠堆放 递推式 a i a i − 1 i a_ia_{i-1}i ai​ai−1​i&#xff0c; n 20230610 n20230610 n20230610非常小&#xff0c;直接模拟 答案等于 494 494 494 划分 因为总和为 1 e 6 1e6…