本内容主要介绍实现单图像去模糊的 MSSNet 模型。
论文:MSSNet: Multi-Scale-Stage Network for Single Image Deblurring
代码(官方):https://github.com/kky7/MSSNet
1. 背景
单图像去模糊旨在从模糊图像(由相机抖动或被拍摄物体运动引起的模糊)中恢复一张清晰的图像。由于图像模糊会严重降低图像质量和其他任务(如物体检测)的性能,过去几十年来,图像去模糊已经被广泛研究。
在深度学习之前,大部分经典的单图像去模糊方法通过交替优化方式来估计模糊核(模糊核描述了图像是如何被模糊的)和潜在的清晰图像。为了高效并准确估计模糊核和潜在清晰图像,这些经典方法广泛采用了从粗到细方案。从粗到细方案:首先在粗尺度估计一个小的模糊核和清晰图像,然后将它们作为下一尺度的初始解。在粗尺度的小尺寸的图像和模糊使得能够快速估计。此外,粗尺度的小模糊尺寸使得能够更准确估计模糊核和潜在图像。因此,从粗到细方案能够快速为下一个尺度提供准确的初始解,并提高去模糊的质量和效率。
深度学习出现后,成功将深度学习应用到了单图像去模糊领域。基于深度学习的单图像去模糊模型可以分为两大类:
- 沿用传统去模糊模型的方案,首先使用 CNN 网络估计模糊核,然后使用模糊核得到清晰图像。
- 使用深度学习网络以端到端的方式直接从模糊图像恢复得到清晰图像,不需要估计模糊核。
同时,使用深度学习网络直接模糊图像恢复得到清晰图像的模型又可以分为两类:
-
多尺度方法,例如 DeepDeblur、SRN 和 PSS-NSC 等模型。由于从粗到细方案的有效性,这类模型基本都采用从粗到细方案。它们通常采用多尺度神经网络架构,将不同尺度的子网络堆叠在一起,首先估计小尺度的潜在清晰图像,然后使用小尺度的潜在清晰图像作为指导估计大尺度的潜在清晰图像。不管是否估计模糊核,使用从粗到细方案的动机都是一样的,即 由于在粗尺度上图像和模糊尺寸很小,因此可以更有效准确地估计去模糊图像。
-
单尺度方法,例如 DMPHN、MT-RNN、MPRNet、HINet 等模型。与多尺度的主要区别在于,其只会接收一种尺度的模糊图像作为输入。
单尺度方法的作者们指出以前多尺度方案计算耗时,以及粗尺度的结果对最终去模糊质量的贡献比较低。并且这些单尺度方法在质量和计算耗时上都超越了以前的多尺度方法,使得传统的从粗到细方案看起来似乎过时了。
基于这种情况,作者通过重新审视从粗到细方案,分析了以前从粗到细方法的缺陷,这些缺陷降低了模型性能。然后为了解决这些缺陷,提出了 MSSNet(Multi-Scale-Stage Network,多尺度多阶段网络)。接下来我们一起看看作者是如何一步步提出 MSSNet 网络的。
2. 以前从粗到细方法的分析
我们先来了解一下以前从粗到细方法实现单图像去模糊的流程。图 1.1 展示了以前从粗到细方法的网络架构。其中 SRN 在相邻的尺度子网络之间还存在额外的循环连接(这些额外连接可以获得额外的性能增益),在图中没有体现。以前的从粗到细方法基本上采用相同的去模糊流程,详细步骤如下:
- 首先,通过对输入的模糊图像进行下采样,从而构建一个图像金字塔。
- 然后,从最粗尺度开始,从下采样的模糊图像估计出一张去模糊图像,对估计出的去模糊图像进行上采样,并将其输入到下一尺度的子网络中。
- 最后,下一尺度的子网络利用上一尺度的去模糊图像作为指导,从当前尺度的模糊图像中估计出一张去模糊图像。
通过对这些方法所采用的的网络架构进行分析,作者发现这些网络架构存在如下三个缺陷:
- 无视模糊尺度的网络架构。在对图像去模糊时,为了恢复具体像素的像素值,需要比模糊尺寸更大的感受野。因此,更大的模糊尺寸需要更大的感受野或者更深的网络。同样地,从粗到细方法中更细的尺度需要更深的子网络。而以前的从粗到细的模型,在不同尺度中采用相同的网络架构。
- 低效的跨尺度信息传播。以前的从粗到细的方法将去模糊图像的像素值从粗尺度传递到下一个尺度。这导致在粗尺度的特征向量中编码的丰富信息出现显著损失,并最终降低去模糊的性能。
- 下采样导致的信息损失。在生成多尺度的输入模糊图像时,以前的方法都是通过对输入的模糊图像进行反复下采样来构建图像金字塔。但是下采样会导致严重的信息损失。
3. 模型设计
3.1 网络架构
基于前面对以前从粗到细方法的分析,为了解决这些缺陷,作者提出了 MSSNet 网络,一种新的基于深度学习的单图像去模糊方法,它采用了从粗到细方案。图 1.2 展示了 MSSNet 的网络架构。和以前的从粗到细方法一样,MSSNet 由三个尺度组成,从粗到细,分别表示为 S 1 S_1 S1、 S 2 S_2 S2 和 S 3 S_3 S3。MSSNet 会预测一张残差图像 R R R,然后与模糊图像 B B B 相加从而得到一张去模糊图像 L = B + R L = B+R L=B+R。
为了解决前面分析的以前从粗到细方法的缺陷,MSSNet 采用了三种策略:
- 反映模糊尺度的阶段配置。
- 一种跨尺度的信息传播方案。
- 一种基于 Pixel-Shuffle 的多尺度方案。
接下来,我们详细了解一下每种策略的具体实现。
3.1.1 反映模糊尺度的阶段配置
为了反映模糊尺度,MSSNet 的更细尺度的子网络拥有更深的网络架构。具体的实现如下:
- MSSNet 的 S 1 S_1 S1、 S 2 S_2 S2 和 S 3 S_3 S3 分别有 1、2 和 3 个阶段(stage)网络,每个阶段网络由单个轻型的 UNet 模块组成。我们使用 U i j U_i^j Uij 表示每个阶段网络,其中 i i i 和 j j j 分别表示尺度(scale)和阶段(stage)的索引。
- 每个 UNet 模块享有相同的网络架构,但是拥有不用的权重。
- 每个 UNet 模块都可以生成残差特征,这些残差特征可以转换为残差图像,将残差图像与模糊图像相加可以得到一张去模糊图像。
3.1.2 一种跨尺度的信息传输方案
以前的多尺度网络将上采样的去模糊图像从粗到细传递到下一尺度;而 MSSNet 则传递上采样的残差特征,以促进尺度之间的有效信息传播。具体的实现如下:
- 首先,对粗尺度网络输出的残差特征,依次进行双线性上采样和 1 x 1 的卷积操作。
- 然后,将上面的输出与从细尺度模糊图像提取的特征进行拼接,再连接一个 3 x 3 的卷积操作,从而得到融合后的特征。
- 最后,将上面得到的融合特征送入后续的 UNet 网络中。
3.1.3 一种基于 Pixel-Shuffle 的多尺度方案
在生成尺度输入模糊图像时,为了避免因下采样导致的信息丢失,作者提出了一种基于 Pixel-Shuffle 的多尺度方案。我们先设定模糊图像 B B B 的尺寸为 W × H × 3 W \times H \times 3 W×H×3,对 B B B 进行下采样得到 B 2 B_2 B2,其尺寸为 W / 2 × H / 2 × 3 W/2 \times H/2 \times 3 W/2×H/2×3。具体的实现如下:
- 对于最细尺度 S 3 S_3 S3,使用模糊图像 B 3 = B B_3 = B B3=B 作为输入。
- 对于 S 2 S_2 S2,对模糊图像 B B B 进行 unshuffle 操作,从而得到 4 张尺寸为 W / 2 × H / 2 × 3 W/2 \times H/2 \times 3 W/2×H/2×3 的图像。然后,沿着通道方法堆叠这 4 张图像,从而得到一个张量 X 2 X_2 X2,其尺寸为 W / 2 × H / 2 × 12 W/2 \times H/2 \times 12 W/2×H/2×12,将其作为 S 2 S_2 S2 的输入。我们可以看到 X 2 X_2 X2 和 B 2 B_2 B2 有相同的空间尺寸(即 W / 2 × H / 2 W/2 \times H/2 W/2×H/2),但是却拥有和 B 3 B_3 B3 相同的信息。正是由于 X 2 X_2 X2 拥有更丰富的信息,从而让 S 2 S_2 S2 能够生成更准确的结果。
- 对于最粗尺度 S 1 S_1 S1,对 B 2 B_2 B2 进行相同的 unshuffle 处理操作,从而得到尺寸为 W / 4 × H / 4 × 12 W/4 \times H/4 \times 12 W/4×H/4×12 的张量 X 1 X_1 X1,将其作为 S 1 S_1 S1 的输入。大家可以会好奇为什不直接对 B B B 进行 unshuffle 操作,得到尺寸为 W / 4 × H / 4 × 48 W/4 \times H/4 \times 48 W/4×H/4×48 的张量作为输入。其实作者也有测试这种方案,但是发现这种方案会造成轻微的性能损失。
3.1.4 跨阶段和跨尺度的特征融合
MSSNet 采用了跨阶段的特征融合方案。具体地,跨阶段的特征融合方案是指在相邻阶段之间提供额外连接(图 1.2 中粉红色虚线),以促进相邻阶段之间进行更有效的信息传播。图 1.3(a)展示了跨阶段的特征融合方案。
另外,MSSNet 也采用了跨尺度的特征融合方案。同样地,其在相邻尺度之间提供额外连接(图 1.2 中绿色虚线),以促进相邻尺度之间进行更有效的信息传播。图 1.3(b)展示了跨阶段的特征融合方案。
总结:由于每种策略都简单明了,从而决定了 MSSNet 是一种简单的架构网络。尽管 MSSNet 网络架构简单,但是作者通过实验证明,在当时 MSSNet 在模型性能、网络规模和计算耗时方面可以达到最优。
3.2 训练和损失函数
在训练期间,会使用辅助层为 MSSNet 的每个阶段生成一张去模糊图像,即总共会产生 6(1+2+3=6) 张去模糊图像。需要注意的是,在推理阶段,只有图 1.2 中的 U 3 3 U_3^3 U33 才会生成去模糊图像。具体的实现如下:
- 对于 S 3 S_3 S3,在 U 3 j U_3^j U3j 后面连接一个 3 x 3 的卷积层,生成残差图像 R 3 j R_3^j R3j(其尺寸为 W × H × 3 W \times H \times 3 W×H×3),然后将其与 B 3 B_3 B3 相加得到去模糊图像 L 3 j L_3^j L3j。
- 对于 S 2 S_2 S2,在 U 2 j U_2^j U2j 后面添加一个 3 x 3 的卷积层和一个 Pixel-Shuffle 操作层,从而生成残差图像(其尺寸为 W × H × 3 W \times H \times 3 W×H×3),然后将其与 B 3 B_3 B3 相加得到去模糊图像 L 2 j L_2^j L2j。
- 对于 S 1 S_1 S1,在 U 1 j U_1^j U1j 后面添加一个 3 x 3 的卷积层和一个 Pixel-Shuffle 操作层,从而生成残差图像(其尺寸为 W / 2 × H / 2 × 3 W/2 \times H/2 \times 3 W/2×H/2×3),然后将其与 B 2 B_2 B2 相加得到去模糊图像 L 1 j L_1^j L1j。
3.2.1 损失函数
作者在训练 MSSNet 时,损失函数由内容损失 L c o n t \mathcal{L}_{cont} Lcont 和频率重构损失 L f r e q \mathcal{L}_{freq} Lfreq 组成,具体公式如式(1.1)所示:
L t o t a l = L c o n t + λ L f r e q (1.1) \mathcal{L}_{total} = \mathcal{L}_{cont} + \lambda \mathcal{L}_{freq} \tag{1.1} Ltotal=Lcont+λLfreq(1.1)
其中, λ = 0.1 \lambda = 0.1 λ=0.1。
内容损失函数采用 L1 损失,具体公式如式(1.2)所示:
L c o n t = 1 N 1 ∣ ∣ L 1 1 − L g t ↓ ∣ ∣ 1 + ∑ j = 1 2 1 N 2 ∣ ∣ L 2 j − L g t ∣ ∣ 1 + ∑ j = 1 3 1 N 3 ∣ ∣ L 3 j − L g t ∣ ∣ 1 (1.2) \mathcal{L}_{cont} = \frac{1}{N_1} ||L_1^1-L_{gt\downarrow}||_1 +\sum_{j=1}^2\frac{1}{N_2}||L_2^j-L_{gt}||_1 +\sum_{j=1}^3\frac{1}{N_3}||L_3^j-L_{gt}||_1 \tag{1.2} Lcont=N11∣∣L11−Lgt↓∣∣1+j=1∑2N21∣∣L2j−Lgt∣∣1+j=1∑3N31∣∣L3j−Lgt∣∣1(1.2)
其中 L g t L_{gt} Lgt 是真实的清晰图像, L g t ↓ L_{gt\downarrow} Lgt↓ 是 L g t L_{gt} Lgt 的下采样结果; L i j L_i^j Lij 是每个阶段生成的去模糊图像; N 1 N_1 N1、 N 2 N_2 N2 和 N 3 N_3 N3 是归一化因子,分别为 N 1 = W / 2 × H / 2 × 3 N_1=W/2 \times H/2 \times 3 N1=W/2×H/2×3 和 N 2 = N 3 = W × H × 3 N_2=N_3=W \times H \times 3 N2=N3=W×H×3。
使用频率重构损失的目的:通过最小化去模糊图像和真实图像在频域的差异,从模糊图像中恢复高频细节。具体公式如式(1.3)所示:
L f r e q = 1 N 1 ∣ ∣ F ( L 1 1 ) − F ( L g t ↓ ) ∣ ∣ 1 + ∑ j = 1 2 1 N 2 ∣ ∣ F ( L 2 j ) − F ( L g t ) ∣ ∣ 1 + ∑ j = 1 3 1 N 3 ∣ ∣ F ( L 3 j ) − F ( L g t ) ∣ ∣ 1 (1.3) \begin{aligned} \mathcal{L}_{freq} = &\frac{1}{N_1}||\mathcal{F}(L_1^1)-\mathcal{F}(L_{gt\downarrow})||_1 +\sum_{j=1}^2 \frac{1}{N_2}||\mathcal{F}(L_2^j)-\mathcal{F}(L_{gt})||_1 \\ &+\sum_{j=1}^3 \frac{1}{N_3}||\mathcal{F}(L_3^j)-\mathcal{F}(L_{gt})||_1 \end{aligned} \tag{1.3} Lfreq=N11∣∣F(L11)−F(Lgt↓)∣∣1+j=1∑2N21∣∣F(L2j)−F(Lgt)∣∣1+j=1∑3N31∣∣F(L3j)−F(Lgt)∣∣1(1.3)
其中
F
\mathcal{F}
F 为傅里叶变换。
3.3 模型变体
对于 MSSNet 模型,作者提供了三种变体,他们的主要区别:UNet 网络中的通道数进行了不同的设置。MSSNet-small 设置为 20、60 和 100,MSSNet 设置为 54、96 和 138,MSSNet-large 设置为 80、130 和 180。
参考:
[1] MSSNet: Multi-Scale-Stage Network for Single Image Deblurring
[2] https://github.com/kky7/MSSNet