torch.nn.Hardtanh
Prototype
CLASS torch.nn.Hardtanh(min_val=-1.0, max_val=1.0, inplace=False, min_value=None, max_value=None)
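Parameters
- min_val (float) – minimum value of the linear region. Default: -1.0
- max_val (float) – maximum value of the linear region. Default: 1.0
- inplace (bool) – whether to perform the operation in place. Default: False
- min_value, max_value – deprecated keyword arguments; use min_val and max_val instead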
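The effect of the inplace flag can be illustrated with a small sketch. The tensor values and variable names here are illustrative assumptions, not part of the original note; the snippet only shows that the default call leaves its input untouched, while inplace=True overwrites it.

```python
import torch
import torch.nn as nn

x = torch.tensor([-3.0, -0.5, 0.5, 3.0])

# Default (inplace=False): a new tensor is returned, the input is left as-is.
out = nn.Hardtanh(min_val=-1.0, max_val=1.0)(x)
input_unchanged = torch.equal(x, torch.tensor([-3.0, -0.5, 0.5, 3.0]))

# inplace=True: the input tensor itself is overwritten with the clipped values.
y = torch.tensor([-3.0, -0.5, 0.5, 3.0])
out_inplace = nn.Hardtanh(min_val=-1.0, max_val=1.0, inplace=True)(y)
shares_storage = out_inplace.data_ptr() == y.data_ptr()

print("x after default call:", x)
print("y after in-place call:", y)
```

In-place operation saves an allocation but discards the original values, so it is usually only appropriate when the unclipped input is not needed afterwards.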
Definition
$$
\text{HardTanh}(x) = \begin{cases}
\text{max\_val} & \text{if } x > \text{max\_val} \\
\text{min\_val} & \text{if } x < \text{min\_val} \\
x & \text{otherwise}
\end{cases}
$$
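The piecewise definition amounts to clipping x to the interval [min_val, max_val]. The following is a minimal sketch that checks this against the module; hardtanh_reference is a hypothetical helper written for illustration, not a PyTorch function.

```python
import torch
import torch.nn as nn


def hardtanh_reference(x, min_val=-1.0, max_val=1.0):
    # Hypothetical helper: a literal translation of the three cases.
    upper = torch.full_like(x, max_val)
    lower = torch.full_like(x, min_val)
    return torch.where(x > max_val, upper, torch.where(x < min_val, lower, x))


x = torch.linspace(-3.0, 3.0, steps=13)
expected = nn.Hardtanh()(x)  # default min_val=-1.0, max_val=1.0

# Both the case-by-case version and a plain clamp agree with the module.
matches_reference = torch.equal(hardtanh_reference(x), expected)
matches_clamp = torch.equal(torch.clamp(x, min=-1.0, max=1.0), expected)
print(matches_reference, matches_clamp)
```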
Figure
Code
import torch
import torch.nn as nn
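# Clip values to the range [-2, 2]
m = nn.Hardtanh(-2, 2)
# Random input, so the printed values differ on every run
input = torch.randn(2)
output = m(input)
print("input: ", input)    # sample run: tensor([2.1926, 0.2211])
print("output: ", output)  # sample run: tensor([2.0000, 0.2211])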
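The sample above draws random input, so its printed values change from run to run. A deterministic variant, sketched below with hand-picked values (an assumption made for illustration), makes the clipping easier to verify. It also calls the functional form torch.nn.functional.hardtanh, which should give the same result as the module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hand-picked values on both sides of the [-2, 2] range.
x = torch.tensor([-3.5, -2.0, 0.25, 2.0, 2.1926])

m = nn.Hardtanh(min_val=-2.0, max_val=2.0)
module_out = m(x)

# Functional equivalent of the module call above.
functional_out = F.hardtanh(x, min_val=-2.0, max_val=2.0)

print("module:    ", module_out)
print("functional:", functional_out)
```

Values below -2 map to -2, values above 2 map to 2, and everything in between passes through unchanged.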
References
Hardtanh — PyTorch 1.13 documentation