白话 Dropout
文章目录
- 什么是Dropout
- 理解缩放
- 举个例子
什么是Dropout
Dropout 是神经网络的一种正则化技术,它在训练时以指定的概率 p p p(常见值为 p = 0.5 p=0.5 p=0.5)丢弃一个单元(连同连接)。在测试时,所有单元都存在,但权重按 p p p 缩放(即变为 p w pw pw)。
该思想是为了防止共同适应(co-adaptation),神经网络变得过于依赖特定的连接,因为这可能是过拟合的症状。直观上,Dropout 可以被认为是创建一个隐含的神经网络集合。
按照这个定义,PyTorch 的 nn.Dropout
“使用来自伯努利分布的样本以概率
p
p
p 随机将输入张量的一些元素归零。 每个通道都将在每次转发呼叫时独立清零。”
Dropout 可以被认为是根据给定的概率 p p p 将输入张量中的一些元素随机归零。发生这种情况时,一部分输出将丢失。考虑到这一点,输出也按比例缩放 1 1 − p \frac{1}{1-p} 1−p1。
缩放使输入均值和输出均值大致相等。
理解缩放
可能很多人会对 Dropout 层如何缩放输入以及为什么要缩放输入感到困惑。这里详细解释一下。
PyTorch 官方文档 有指出:
此外,在训练期间,输出按 1 1 − p \frac{1}{1-p} 1−p1 的比例缩放。这意味着在评估时,模块只是计算一个恒等函数。
那么这是如何完成的呢?为什么要这么做呢?让我们看一下 Pytorch 中的一些代码。
创建一个丢弃率
p
=
0.4
p=0.4
p=0.4 的 Dropout 层 m
:
import torch
import numpy as np
p = 0.4
m = torch.nn.Dropout(p)
PyTorch 文档中解释道:
在训练期间,使用来自伯努利分布的样本以概率 p p p 随机将输入张量的某些元素归零。归零的元素在每次前向调用时随机化。
向 Dropout 层放置一个随机输入,并确认约 40% ( p = 0.4 p=0.4 p=0.4) 的元素已变为 0:
nbig = 5000000
inp = torch.rand(nbig, 10)
outp = m(inp)
print(f'输入中0元素的比例为: {(outp==0).numpy().mean():.5f}, p={p}')
上面代码运行后输出:
$ 输入中0元素的比例为: 0.40007, p=0.4
我们接着看一下缩放部分。
创建一个较小的随机输入并将其放入 Dropout 层。比较输入和输出:
np.random.seed(42)
inp = torch.rand(5, 4)
inp
上面代码创建一个5行4列的随机张量,输出如下:
$ tensor([[0.6485, 0.3114, 0.1626, 0.1022],
[0.7352, 0.4634, 0.8206, 0.4228],
[0.0322, 0.9399, 0.9163, 0.4169],
[0.2574, 0.0467, 0.2213, 0.6171],
[0.4146, 0.2288, 0.0388, 0.7752]])
我们可以在下面看到,通过比较下面两个张量中的非零元素,在训练期间输出按 1 1 − p \frac{1}{1-p} 1−p1 倍缩放:
outp = m(inp)
inp/(1-p)
$ tensor([[1.0808, 0.5191, 0.2710, 0.1703],
[1.2254, 0.7723, 1.3676, 0.7046],
[0.0537, 1.5665, 1.5272, 0.6948],
[0.4290, 0.0778, 0.3689, 1.0284],
[0.6909, 0.3813, 0.0646, 1.2920]])
输出 output
$ tensor([[1.0808, 0.5191, 0.2710, 0.0000],
[0.0000, 0.7723, 0.0000, 0.0000],
[0.0000, 1.5665, 1.5272, 0.6948],
[0.4290, 0.0778, 0.3689, 1.0284],
[0.6909, 0.0000, 0.0646, 0.0000]])
我们可以在代码中断言该观察:
idx_nonzero = outp!=0
assert np.allclose(outp[idx_nonzero].numpy(), (inp/(1-p))[idx_nonzero].numpy())
那么为什么要这样做呢?
基本上,在评估/测试/推理期间,Dropout 层变成了一个恒等函数并且不改变它的输入。由于 Dropout 仅在训练期间处于活动状态而不在推理期间处于活动状态,因此在没有缩放的情况下,预期输出在推理期间会更大,因为元素不再被随机丢弃(设置为 0)。但是我们希望不管通过还是不通过 Dropout 层,预期输出都相同。因此,在训练期间,我们通过将 Dropout 层的输出放大 1 1 − p \frac{1}{1−p} 1−p1 的比例因子来进行补偿。 p p p 越大意味着 Dropout 越激进,这意味着我们需要的补偿越多,即比例因子 1 1 − p \frac{1}{1−p} 1−p1 越大。
下面的代码演示了比例因子如何将输出恢复到与输入相同的比例。
inp = torch.rand(nbig, 10)
outp = m(inp)
print(f'dropout 层的平均输出 ({outp.mean():.4f}) 接近平均输入 ({inp.mean():.4f})')
$ dropout 层的平均输出 (0.5000) 接近平均输入 (0.5000)
举个例子
下面通过一个包含 100 个张量的例子,演示 Dropout 及其缩放比例如何影响输入。
import torch
import torch.nn as nn
# 生成 100 个 1
x = torch.ones(100)
$ tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
当丢弃率为
p
=
0.1
p = 0.1
p=0.1 时,大约有 10 个值应该置为 0。缩放率为
1
1
−
0.1
=
1
0.9
=
1.
1
˙
\frac{1}{1-0.1} = \frac{1}{0.9} = 1.\dot{1}
1−0.11=0.91=1.1˙
# 输入 Dropout 层
output = nn.Dropout(p=0.1)(x)
$ tensor([1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 1.1111, 1.1111, 0.0000, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 0.0000, 0.0000, 1.1111, 0.0000, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 0.0000, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 1.1111, 0.0000, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 0.0000, 1.1111, 1.1111, 1.1111, 1.1111, 0.0000, 0.0000,
1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 1.1111,
1.1111, 1.1111, 1.1111, 1.1111, 1.1111, 0.0000, 1.1111, 1.1111, 1.1111,
1.1111])
结果如我们预期所示,其中 10 个值被完全归零,结果被缩放以确保输入和输出具有相同的均值——或尽可能接近。
print(x.mean(), output.mean())
$ tensor(1.) tensor(1.0000)
在本例中,输入和输出的均值为 1.0。