笔者在学习Softmax实现时遇到了一个问题,很多文章直接将softmax的计算分成了五个过程,而没有解释每个过程的含义,尤其是在阅读这篇文章时,作者想计算最基本的softmax的效率,以展示可行的优化空间:
贴一个gpt写的解释,个人感觉还蛮清楚的,至于为什么要减去每行的最大值,是为了使大整数运算转化为0-1的小数运算,以避免溢出。
Softmax 函数是深度学习和机器学习中常用的一种激活函数,它将输入向量转换为概率分布。具体地,softmax 函数接受一个输入向量并将其转换为概率向量,输出向量的每个元素表示输入向量中对应元素的概率。Softmax 函数的计算涉及以下五个步骤:
1. ReduceMax
目的:找到每一行的最大值。
操作:
- 对于每一行,找到该行的最大值。这用于后续步骤的数值稳定性。
内存访问:
- 读取:需要读取整个矩阵的所有元素,总共是
D
次读取。 - 写入:需要将每一行的最大值写入到一个数组中,总共是
num_rows
次写入。
总内存访问:D + num_rows
2. BroadcastSub
目的:将每个元素减去其所在行的最大值。
操作:
- 将每个元素减去其所在行的最大值,这样可以避免数值溢出(数值稳定性技巧)。
内存访问:
- 读取:需要读取整个矩阵的所有元素(
D
次),以及每一行的最大值(num_rows
次),总共是D + num_rows
次读取。 - 写入:需要将结果写回到矩阵中,总共是
D
次写入。
总内存访问:2 * D + num_rows
3. Exp
目的:对每个元素计算指数函数(exp)。
操作:
- 计算每个元素的指数函数值
exp(x)
。
内存访问:
- 读取:需要读取整个矩阵的所有元素,总共是
D
次读取。 - 写入:需要将结果写回到矩阵中,总共是
D
次写入。
总内存访问:2 * D
4. ReduceSum
目的:计算每一行的总和。
操作:
- 对于每一行,计算其所有元素的和。这用于将指数函数值标准化为概率。
内存访问:
- 读取:需要读取整个矩阵的所有元素,总共是
D
次读取。 - 写入:需要将每一行的总和写入到一个数组中,总共是
num_rows
次写入。
总内存访问:D + num_rows
5. BroadcastDiv
目的:将每个元素除以其所在行的总和。
操作:
- 将每个元素除以其所在行的总和,使得每一行的元素之和为 1,从而形成概率分布。
内存访问:
- 读取:需要读取整个矩阵的所有元素(
D
次),以及每一行的总和(num_rows
次),总共是D + num_rows
次读取。 - 写入:需要将结果写回到矩阵中,总共是
D
次写入。
总内存访问:2 * D + num_rows
Softmax 过程总结
这五个步骤构成了 softmax 函数的计算过程,具体如下:
- ReduceMax:找到每一行的最大值以进行数值稳定性处理。
- BroadcastSub:将每个元素减去所在行的最大值。
- Exp:对每个元素计算指数函数。
- ReduceSum:计算每一行的元素和。
- BroadcastDiv:将每个元素除以所在行的和,形成概率分布。
每个步骤的内存访问模式和计算方式都是为了确保整个 softmax 计算过程的数值稳定性和准确性,同时在尽量减少全局内存访问次数以优化性能。