文章目录
- 1、没有使用残差连接的网络难以训练
- 2、loss 不下降的原因
- 3、使用了残差连接的网络可以高效训练
1、没有使用残差连接的网络难以训练
经典的 SegNet 网络结构如下:
在使用上图所示的 SegNet 作为噪声预测网络训练扩散模型(DDPM)时,遇到了 loss 无法下降的问题:
可以看到,loss 值快速下降到一个固定值就不再下降了。我尝试多次调整学习率依然如此。
2、loss 不下降的原因
从模型训练的本质上来看,loss 值固定就表示模型参数没有变化,即参数没有更新。那么参数为什么或没有更新呢?
梯度下降法更新参数的原理为:
所以参数没有更新的原因极有可能是梯度为0,即出现了 “梯度消失” 的现象。事实上,SegNet 有 20个卷积层和10个上/下采样层,总共有30层。这是一个层数比较多的网络,因此出现梯度消失是很正常的。
为了解决 “梯度消失” 的问题,我们自然能想到使用残差连接,这是解决梯度消失最有效的方法之一。
3、使用了残差连接的网络可以高效训练
下图是我在 SegNet 加了四个残差连接(红色箭头)的网络结构:
loss 变化为:
可以明显看到,仅仅加了几个残差连接,网络就可以顺利地训练了。