PyTorch学习笔记:nn.PReLU——PReLU激活函数
torch.nn.PReLU(num_parameters=1, init=0.25, device=None, dtype=None)
功能:逐元素对数据应用如下函数公式进行激活
PReLU
(
x
)
=
max
(
0
,
x
)
+
a
∗
min
(
0
,
x
)
\text{PReLU}(x)=\max(0,x)+a*\min(0,x)
PReLU(x)=max(0,x)+a∗min(0,x)
或者
PReLU
(
x
)
=
{
x
,
i
f
x
≥
0
a
x
,
otherwise
\begin{aligned} \text{PReLU}(x)=\left\{ \begin{matrix} x,\quad &if\quad x ≥0\\ ax,&\text{otherwise} \end{matrix} \right. \end{aligned}
PReLU(x)={x,ax,ifx≥0otherwise
此激活函数与LeakyReLU激活函数非常相似,都可以保留负激活数据,但与LeakyReLU最大的不同在于PReLU中的参数
a
a
a是可学习的,而LeakyReLU中的
a
a
a是一个定值。
函数图像:
这里与LeakyReLU图像非常相似。
输入:
num_parameters
(整数):可学习参数 a a a的数量,只有两种选择,要么定义成1,表示在所有通道上应用相同的 a a a进行激活,要么定义成输入数据的通道数,表示在所有通道上应用不同的 a a a进行激活,默认1。init
(float): a a a的初始值
注意:
- 输入数据的第二维度表示为通道维度,当输入维度小于2时,不存在通道维度,此时默认通道数为1
- 可以通过调用
.weight
方法来取出参数 a a a - 即使有多个
a
a
a,
init
也还是只能输入一个float
类型的数
代码案例
一般用法
import torch.nn as nn
import torch
PReLU = nn.PReLU()
x = torch.randn(10)
value = PReLU(x)
print(x)
print(value)
输出
# 输入
tensor([ 0.2399, -0.3208, -0.7234, 1.6305, 0.5196, -0.7686, 0.1195, -0.2320,
1.2424, -0.7216])
# 激活值
tensor([ 0.2399, -0.0802, -0.1809, 1.6305, 0.5196, -0.1922, 0.1195, -0.0580,
1.2424, -0.1804], grad_fn=<PreluBackward>)
有多个 a a a时
import torch.nn as nn
import torch
PReLU = nn.PReLU(num_parameters=3, init=0.1)
x = torch.randn(12).reshape(4,3)
value = PReLU(x)
print(x)
print(value)
print(PReLU.weight)
输出
# 输入
tensor([[-0.5554, 0.2285, 1.0417],
[ 0.0180, 0.1619, 2.1579],
[ 0.1636, -1.1147, -1.9901],
[-0.4662, 1.5423, 0.0380]])
# 输出
tensor([[-0.0555, 0.2285, 1.0417],
[ 0.0180, 0.1619, 2.1579],
[ 0.1636, -0.1115, -0.1990],
[-0.0466, 1.5423, 0.0380]], grad_fn=<PreluBackward>)
# 参数a
Parameter containing:
tensor([0.1000, 0.1000, 0.1000], requires_grad=True)
注:绘图代码
import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as plt
PReLU = nn.PReLU()
x = torch.tensor(np.linspace(-5,5,100), dtype=torch.float32)
value = PReLU(x)
plt.plot(x, value.detach().numpy())
plt.savefig('PReLU.jpg')
官方文档
nn.PReLU:https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU