论文:Multi-level Wavelet-CNN for Image Restoration
源码:GitHub - lpj0/MWCNN: Multi-level Wavelet-CNN for Image Restoration
目录
一、背景和出发点
二、创新点
三、MWCNN具体实现
四、DWT与池化运算和膨胀卷积相关性证明
五、DWT、IWT代码实现
六、实验
七、总结
一、背景和出发点
在低水平视觉中,感受野大小和效率之间的权衡是一个至关重要的问题。普通卷积网络(CNN)通常以牺牲计算成本为代价来扩大感受野。最近,扩张滤波被用来解决这个问题。
在本文中,提出了一种新的多层小波CNN(MWCNN)模型,以更好地权衡感受野大小和计算效率。
二、创新点
1. 提出了一种新的MWCNN模型,在效率和恢复性能之间取得了更好的平衡,扩大了接受野。
2. 由于DWT具有良好的时频局域性,因此具有良好的细节保留能力。
3. 在图像去噪、SIS-R和JPEG图像deblocking方面具有最新性能
三、MWCNN具体实现
MWCNN结构图如下:
步骤:使用DWT变换代替每一级下采样,IWT还原代替每一级上采样。其余与UNet基本一致。
每一层由3×3滤波器卷积(Conv)、批量归一化(BN)和校正线性单元(ReLU)操作组成。对于最后一个CNN块的最后一层,采用不含BN和ReLU的Conv对残差图像进行预测。
MWCNN的目标函数,如下所示(新的loss):
四、DWT与池化运算和膨胀卷积相关性证明
(1) 四种波滤器
低频LL,高频LH、HL、HH波滤器分别定义为:
证明1:可以看出,其实就是求和池化操作。
(2) 四种高频子带
四种高频子带通过以下公式可得:
实际上,是由特征图x分别上述滤波器以不同次序相乘,再根据公式进行相加。
上式合并,经过Dwt的特征图X的第(i,j)个值可写为可写作:
其中k是3×3卷积核。
膨胀因子为2的卷积可写作:
可证明,膨胀因子为2的卷积进行的膨胀滤波可表示为:首先将一个图像通过dwt分解为四个子图像,然后在这些子图像上使用共享的标准卷积核。
五、DWT、IWT代码实现
DWT和IWT代码实现(使用哈尔小波变换):
def dwt_init(x):
x01 = x[:, :, 0::2, :] / 2
x02 = x[:, :, 1::2, :] / 2
x1 = x01[:, :, :, 0::2]
x2 = x02[:, :, :, 0::2]
x3 = x01[:, :, :, 1::2]
x4 = x02[:, :, :, 1::2]
x_LL = x1 + x2 + x3 + x4
x_HL = -x1 - x2 + x3 + x4
x_LH = -x1 + x2 - x3 + x4
x_HH = x1 - x2 - x3 + x4
return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
def iwt_init(x):
r = 2
in_batch, in_channel, in_height, in_width = x.size()
# print([in_batch, in_channel, in_height, in_width])
out_batch, out_channel, out_height, out_width = in_batch, int(
in_channel / (r ** 2)), r * in_height, r * in_width
x1 = x[:, 0:out_channel, :, :] / 2
x2 = x[:, out_channel:out_channel * 2, :, :] / 2
x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
return h
class DWT(nn.Module):
def __init__(self):
super(DWT, self).__init__()
self.requires_grad = False
def forward(self, x):
return dwt_init(x)
class IWT(nn.Module):
def __init__(self):
super(IWT, self).__init__()
self.requires_grad = False
def forward(self, x):
return iwt_init(x)
分解(代替下采样):
dwt_module=DWT()
x=Image.open('./test.png')
# x=Image.open('./mountain.png')
x=transforms.ToTensor()(x)
x=torch.unsqueeze(x,0)
x=transforms.Resize(size=(256,256))(x)
subbands=dwt_module(x)
# 分解
title=['LL','HL','LH','HH']
plt.figure()
for i in range(4):
plt.subplot(2,2,i+1)
temp=torch.permute(subbands[0,3*i:3*(i+1),:,:],dims=[1,2,0])
plt.imshow(temp)
plt.title(title[i])
plt.axis('off')
plt.show()
重构(代替上采样):
dwt_module=DWT()
x=Image.open('./test.png')
# x=Image.open('./mountain.png')
x=transforms.ToTensor()(x)
x=torch.unsqueeze(x,0)
x=transforms.Resize(size=(256,256))(x)
subbands=dwt_module(x)
# 重构
title=['Original Image','Reconstruction Image']
reconstruction_img=IWT()(subbands).cpu()
ssim_value=ssim(x,reconstruction_img) # 计算原图与重构图之间的结构相似度
print("SSIM Value:",ssim_value) # tensor(1.)
show_list=[torch.permute(x[0],dims=[1,2,0]),torch.permute(reconstruction_img[0],dims=[1,2,0])]
plt.figure()
for i in range(2):
plt.subplot(1,2,i+1)
plt.imshow(show_list[i])
plt.title(title[i])
plt.axis('off')
plt.show()
六、实验
数据集:Berkeley Segmentation dataset、DIV2K 和 Waterloo Exploration Database 。
1. 去噪声实验
表明MWCNN去噪效果最好。
2. 性能对比
可见MWCNN在PSNR和SSIM指标方面都表现良好。
七、总结
提出了一种用于图像恢复的多层小波cnn(MWCNN)结构,该结构由收缩子网络和扩展子网络组成。收缩子网由多级 D WT 和 C NN 块组成,扩展子网由多级IWT和CNN 块组 成。由于 D WT 的可逆性、频率性和位置性,MWCNN可以安全地进行子采样而不丢失信息,并且可以有效地从退化的观测中恢复细节纹理和尖锐结构。结果表明, MWCNN可以在效率和性能之间取得更好的平衡,从而扩大接收域。大量的实验证明了MWCNN在图像去噪、SISR 和JPEG压缩、伪影去除、恢复三个任务上的有效性和效率。