torch.nn.Hardshrink
CLASS torch.nn.Hardshrink(lambd=0.5)
Parameters
- lambd (float) – the λ value in the Hardshrink formula. Default: 0.5
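The small sketch below shows how `lambd` moves the threshold. The input tensor is made up for illustration. Entries with |x| ≤ λ become 0, and every other entry passes through unchanged.

```python
import torch
import torch.nn as nn

# Fixed, illustrative input so the result is reproducible.
x = torch.tensor([-2.0, -1.0, -0.3, 0.0, 0.3, 1.0, 2.0])

default_m = nn.Hardshrink()          # lambd defaults to 0.5
wide_m = nn.Hardshrink(lambd=1.5)    # a larger threshold zeroes more entries

print(default_m(x).tolist())  # [-2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0]
print(wide_m(x).tolist())     # [-2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0]
```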
Definition
\text{HardShrink}(x) = \begin{cases} x, & \text{if } x > \lambda \\ x, & \text{if } x < -\lambda \\ 0, & \text{otherwise} \end{cases}
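Here is a minimal pure-Python sketch of the piecewise definition for a single scalar. The function name `hardshrink` is hypothetical and is not a PyTorch API. The sketch makes one detail visible: both inequalities are strict, so x = ±λ falls into the "otherwise" case and maps to 0.

```python
def hardshrink(x: float, lambd: float = 0.5) -> float:
    """Scalar HardShrink written directly from the piecewise definition (illustrative only)."""
    if x > lambd:       # first case: x > λ  -> keep x
        return x
    if x < -lambd:      # second case: x < -λ -> keep x
        return x
    return 0.0          # otherwise (-λ <= x <= λ) -> 0

values = [-1.4333, -0.5, 0.2078, 0.5, 0.75]
print([hardshrink(v) for v in values])  # [-1.4333, 0.0, 0.0, 0.0, 0.75]
```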
Code
import torch
import torch.nn as nn
m = nn.Hardshrink()
input = torch.randn(2)  # random input, so the printed values below vary from run to run
output = m(input)
print("input: ", input) # input: tensor([ 0.2078, -1.4333])
print("output: ", output) # output: tensor([ 0.0000, -1.4333])
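The module form is not the only way to apply Hardshrink. The sketch below checks, under a fixed random seed, that `nn.Hardshrink` agrees with the functional form `torch.nn.functional.hardshrink`. It also compares both against a hand-written `torch.where` version of the definition, which keeps x where |x| > λ and outputs 0 elsewhere. The `torch.where` variant is only an illustration of the formula, not how PyTorch implements it internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fix the RNG so repeated runs give the same input
x = torch.randn(5)

lambd = 0.5
module_out = nn.Hardshrink(lambd=lambd)(x)
functional_out = F.hardshrink(x, lambd=lambd)
# Vectorized form of the definition: keep x where |x| > λ, else 0.
manual_out = torch.where(x.abs() > lambd, x, torch.zeros_like(x))

print(torch.equal(module_out, functional_out))  # True
print(torch.equal(module_out, manual_out))      # True
```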
【References】
Hardshrink — PyTorch 1.13 documentation