文章目录
- 1. 目的
- 2. tanh ( x ) \tanh(x) tanh(x) 的 naive 实现
- 2.1 数学公式
- 2.2 naive 实现
- 3. tanh ( x ) \tanh(x) tanh(x) 的快速计算
- 3.1 Maple 中的近似公式
- 3.2 tan_c3()
- 3.3 Gauss 连分数公式 (Continued Fraction)
- 4. 最终代码和运行结果
- 代码
- 运行结果
- 5. 其他
- References
1. 目的
用于 LeNet-5 网络中 squashing function 中 tanh()
部分的计算。tanh()
是 hyperbolic tangent 双曲正切三角函数的意思。
LeNet-5 网络的 C1~F6, 每一层都需要对于输出结果应用 squashing function. 后世称作 activation function.
2. tanh ( x ) \tanh(x) tanh(x) 的 naive 实现
2.1 数学公式
tanh ( x ) = e x − e − x e x + e − x \tanh(x) = \frac{e^x-e^{-x}}{e^x+e^{-x}} tanh(x)=ex+e−xex−e−x
2.2 naive 实现
直接翻译公式得到:
static double m_tanh(double x)
{
double ep = m_exp(x); // exponent of positive x
double en = m_exp(-x); // exponent of negative x
double up = ep - en;
double down = ep + en;
return up / down;
}
其中 m_exp
在前一篇博客1 scratch lenet(8): C语言实现 exp(x) 的计算 给出过:
static double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_exp(double x)
{
double res = 1;
double eps = 1e-9;
double up = 1;
double down = 1;
for (int i = 1; ;i++)
{
up *= x;
down *= i;
double delta = up / down;
res += delta;
// printf("i=%d, delta=%lf\n", i, delta);
if (m_fabs(delta) < eps)
break;
}
return res;
}
3. tanh ( x ) \tanh(x) tanh(x) 的快速计算
StackOverFlow 上的一个问答2 给出了好几种近似计算方式。
3.1 Maple 中的近似公式
回答3 给出了一个公式(TL;DR 这一节的公式不靠谱,精度丢失比较多)
The best rational approximation to tanh(x) with numerator and denominator of degree 3 on the interval [0,3.1] (as provided by Maple’s minimax function) is
(-.67436811832e-5+(.2468149110712040+(.583691066395175e-1+.3357335044280075e-1*x)*x)*x)/(.2464845986383725+(.609347197060491e-1+(.1086202599228572+.2874707922475963e-1*x)*x)*x)
This (call it f(x)) has maximum error .2735944241730870e-4, which is considerably less than 2^(-8).
On the interval [−3.1,3.1], use sgn(x)f(|x|
double fast_tanh_by_maple(double x)
{
return (-.67436811832e-5+(.2468149110712040+(.583691066395175e-1+.3357335044280075e-1*x)*x)*x)/(.2464845986383725+(.609347197060491e-1+(.1086202599228572+.2874707922475963e-1*x)*x)*x);
}
zz@Legion-R7000P% ./a.out
Please input a real number: 0.345
tanh(0.345000) = 0.331934
tanh_c3(0.345000) = 0.331935
m_tanh(0.345000) = 0.331934
fast_tanh_by_maple(0.345000) = 0.331907
我尝试后,发现精度差的有点多,并不是所谓的“精度损失小于 .2735944241730870e-4”, 而是肉眼可见的有精度损失:
>>> e1 = .2735944241730870e-4
>>> e2 = 0.331934 - 0.331907
>>> e1 < e2
False
>>>
3.2 tan_c3()
jenkas 给出了一个更好的近似公式和实现4.
float tanh_c3(float v)
{
const float c1 = 0.03138777F;
const float c2 = 0.276281267F;
const float c_log2f = 1.442695022F;
v *= c_log2f;
int intPart = (int)v;
float x = (v - intPart);
float xx = x * x;
float v1 = c_log2f + c2 * xx;
float v2 = x + xx * c1 * x;
float v3 = (v2 + v1);
*((int*)&v3) += intPart << 24;
float v4 = v2 - v1;
return (v3 + v4) / (v3 - v4);
}
暂时没搞懂这个实现对应的公式
v = I + x // 整数部分 + 小数部分
xx = x * x // 小数部分的平方
v1 = c_log2f + c2 * xx
v2 = x + xx * c1 * x
v3 = v2 + v1 = x + xx * c1 * x + c_log2f + c2 * xx
= c_log2f + x + c1 * x * x * x + c2 * x * x
v4 = v2 - v1 = x + xx * c1 * x - c_log2f - c2 * xx
= -c_log2f + x - c2 * x * 2 + c1 * x * x * x
3.3 Gauss 连分数公式 (Continued Fraction)
1812年高斯给出的双曲正切函数 tanh ( x ) \tanh(x) tanh(x) 的连分数展开公式 (continued fraction for the hyperbolic tangent 5)
tanh ( x ) = x 1 + x 2 3 + x 2 5 + . . . \tanh(x) = \frac{x}{1 + \frac{x^2}{3 + \frac{x^2}{5 + ...}}} tanh(x)=1+3+5+...x2x2x
我们使用展开到
9
+
x
2
11
9 + \frac{x^2}{11}
9+11x2 的这一项, 作为 tanh 的近似6:
发现结果非常准确(至少对于 x = 0.345 x=0.345 x=0.345 来说, 和 C 标准库结果一样)
double approx_tanh_by_continues_fraction(double x)
{
double s = x * x;
double y = 9 + s / 11;
y = 7 + s / y;
y = 5 + s / y;
y = 3 + s / y;
y = 1 + s / y;
y = x / y;
return y;
}
4. 最终代码和运行结果
代码
#include <stdio.h>
#include <math.h>
#include <stdbool.h>
double tanh_c3(float v)
{
const float c1 = 0.03138777F;
const float c2 = 0.276281267F;
const float c_log2f = 1.442695022F;
v *= c_log2f;
int intPart = (int)v;
float x = (v - intPart);
float xx = x * x;
float v1 = c_log2f + c2 * xx;
float v2 = x + xx * c1 * x;
float v3 = (v2 + v1);
*((int*)&v3) += intPart << 24;
float v4 = v2 - v1;
return (v3 + v4) / (v3 - v4);
}
static double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_exp(double x)
{
double res = 1;
double eps = 1e-9;
double up = 1;
double down = 1;
for (int i = 1; ;i++)
{
up *= x;
down *= i;
double delta = up / down;
res += delta;
// printf("i=%d, delta=%lf\n", i, delta);
if (m_fabs(delta) < eps)
break;
}
return res;
}
static double m_tanh(double x)
{
double ep = m_exp(x); // exponent of positive x
double en = m_exp(-x); // exponent of negative x
double up = ep - en;
double down = ep + en;
return up / down;
}
double fast_tanh_by_maple(double x)
{
return (-.67436811832e-5+(.2468149110712040+(.583691066395175e-1+.3357335044280075e-1*x)*x)*x)/(.2464845986383725+(.609347197060491e-1+(.1086202599228572+.2874707922475963e-1*x)*x)*x);
}
double approx_tanh_by_continues_fraction(double x)
{
double s = x * x;
double y = 9 + s / 11;
y = 7 + s / y;
y = 5 + s / y;
y = 3 + s / y;
y = 1 + s / y;
y = x / y;
return y;
}
int main()
{
double x;
while (true)
{
printf("Please input a real number: ");
scanf("%lf", &x);
double y1 = tanh(x);
double y2 = tanh_c3(x);
double y3 = m_tanh(x);
double y4 = fast_tanh_by_maple(x);
double y5 = approx_tanh_by_continues_fraction(x);
printf(" tanh(%lf) = %lf\n", x, y1);
printf(" tanh_c3(%lf) = %lf\n", x, y2);
printf(" m_tanh(%lf) = %lf\n", x, y3);
printf(" fast_tanh_by_maple(%lf) = %lf\n", x, y4);
printf(" approx_tanh_by_continues_fraction(%lf) = %lf\n", x, y5);
}
return 0;
}
运行结果
gcc tanh.c -lm
zz@Legion-R7000P% ./a.out
Please input a real number: 0.345
tanh(0.345000) = 0.331934
tanh_c3(0.345000) = 0.331935
m_tanh(0.345000) = 0.331934
fast_tanh_by_maple(0.345000) = 0.331907
approx_tanh_by_continues_fraction(0.345000) = 0.331934
也尝试了其他输入如 x=257
, 整体上看 Gauss 给出的 Continued Fraction 的精度损失更小一些,速度也还算快,打算在 lenet-5 代码中使用它:
double approx_tanh_by_continues_fraction(double x)
{
double s = x * x;
double y = 9 + s / 11;
y = 7 + s / y;
y = 5 + s / y;
y = 3 + s / y;
y = 1 + s / y;
y = x / y;
return y;
}
5. 其他
K-Tanh 7 基于 AVX512 指令给出了5倍加速的实现。
[【Tanh的标量实现】]8 则考虑了 Inf/Nan 等情况, 并使用了
tanh
(
x
)
=
e
2
x
−
1
e
2
x
+
1
=
1
−
2
e
2
x
+
1
\tanh(x) = \frac{e^{2x}-1}{e^{2x} + 1} = 1 - \frac{2}{e^{2x}+1}
tanh(x)=e2x+1e2x−1=1−e2x+12
这一等效公式计算。
References
scratch lenet(8): C语言实现 exp(x) 的计算 ↩︎
Rapid approximation of tanh(x) ↩︎
https://math.stackexchange.com/a/107302 ↩︎
https://math.stackexchange.com/a/3485944 ↩︎
continued fraction for the hyperbolic tangent ↩︎
https://math.stackexchange.com/a/107295 ↩︎
K-TANH: EFFICIENT TANH FOR DEEP LEARNING ↩︎
【Tanh的标量实现】 ↩︎