01、引言
本文方法源于Youssef Mansour和Reinhard Heckel撰写的论文《Zero-Shot Noise2Noise: Efficient Image Denoising without any Data》,该文作者探索了一种不需要任何数据且高效的高效图像去噪方法。
该方法使用两个固定的内核对噪声图像进行卷积,以创建一对降采样的图像。然后用一致性损失训练一个简单的2层CNN,将一个下采样的图像映射到另一个。
本文涵盖主题:方法介绍、代码复现两个主题。
本文主要结果汇总仪表板:
本期内容『数据+代码』已上传百度网盘。
有需要的朋友可以关注公众号【小Z的科研日常】,后台回复关键词[图像去噪]获取。
02、简要介绍
ZS-N2N使用简单的2层网络,在没有任何训练数据或噪声分布知识的情况下,可以以低计算成本实现高质量的图像去噪,对像素独立噪声的去噪效果良好。适合于数据稀缺和计算资源有限的情况下使用。
该方法只对噪声统计进行了最小的假设(像素级的独立性),并且不需要训练数据。不需要明确的噪声模型,因此适用于各种噪声类型,并且可以在噪声分布或水平未知的情况下使用。
对噪声的唯一假设是,它是无结构的,而且平均值为零。
该论文提出的方法是无需数据集和噪声模型,与现有的方法相比,在泛化、去噪质量和计算资源之间实现了更好的权衡。
如上图所示。我们与标准的零拍基线,包括BM3D,以及最近的基于神经网络的算法DIP [UVL18]和S2S [Qua+20]进行比较。
只有BM3D比我们的方法快,但在非高斯噪声上取得的效果很差。
只有S2S有时优于我们的方法,但速度要慢几个数量级,在低噪声水平上经常失败[KLS22],并且需要集合才能达到可接受的性能。
03、代码复现
首先,我们加载一张测试图像并展示:
img_path = "C:/Users/asus/Desktop/Z/Kodak24/kodim07.png"
clean_img = Image.open(img_path)
clean_img = clean_img.convert("RGB")
clean_img = torch.from_numpy(np.array(clean_img)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
#展示加载的测试图像
show_img = clean_img.squeeze(0).permute((1, 2, 0)).numpy()
plt.imshow(show_img)
plt.axis('off')
plt.title('Original Picture')
plt.show()
原图如下:
可以看到原图是一张清晰的彩色图像。接下来,我们添加噪声:
noise_type = 'gauss' # 可选择'gauss' 或 'poiss'
noise_level = 50 # 高斯的像素范围是0-255,裂变的像素范围是0-1。
def add_noise(x, noise_level):
if noise_type == 'gauss':
noisy = x + torch.normal(0, noise_level / 255, x.shape)
noisy = torch.clamp(noisy, 0, 1)
elif noise_type == 'poiss':
noisy = torch.poisson(noise_level * x) / noise_level
return noisy
noisy_img = add_noise(clean_img, noise_level)
# 展示添加噪声后的图像
show_noisy_img = noisy_img.squeeze(0).permute((1, 2, 0)).numpy()
plt.imshow(show_noisy_img)
plt.axis('off')
plt.title(f'{noise_type.capitalize()} Noise ({noise_level})')
plt.show()
可以看到,我们成功地为图像添加了高斯噪声。接下来,我们将图像张量PyTorch 变量 device 指定的设备中,并定义一个2层的CNN网络,用于降噪:
device = 'cuda'
clean_img = clean_img.to(device)
noisy_img = noisy_img.to(device)
class network(nn.Module):
def __init__(self, n_chan, chan_embed=48):
super(network, self).__init__()
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv1 = nn.Conv2d(n_chan, chan_embed, 3, padding=1)
self.conv2 = nn.Conv2d(chan_embed, chan_embed, 3, padding=1)
self.conv3 = nn.Conv2d(chan_embed, n_chan, 1)
def forward(self, x):
x = self.act(self.conv1(x))
x = self.act(self.conv2(x))
x = self.conv3(x)
return x
n_chan = clean_img.shape[1]
model = network(n_chan)
model = model.to(device)
这是该方法的神经网络模型,其中 chan_embed 是嵌入通道数,默认为48。该网络采用了 LeakyReLU 激活函数和卷积操作。
接下来介绍图像对下采样器,它通过对非重叠斑块中的对角线像素进行平均,输出两个空间分辨率为一半的下采样图像,如下图所示。它是通过将图像与两个固定的内核进行卷积实现。其中卷积的跨度为2,并分别应用于每个图像通道。
接下来,我们定义一个能够对输入图像进行下采样的函数:
def pair_downsampler(img):
# img has shape B C H W
c = img.shape[1]
filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
filter1 = filter1.repeat(c, 1, 1, 1)
filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
filter2 = filter2.repeat(c, 1, 1, 1)
output1 = F.conv2d(img, filter1, stride=2, groups=c)
output2 = F.conv2d(img, filter2, stride=2, groups=c)
return output1, output2
显示噪声图像和其相应的降采样对,请注意降采样的图像的空间分辨率是一半:
img1, img2 = pair_downsampler(noisy_img)
img0 = noisy_img.cpu().squeeze(0).permute(1,2,0)
img1 = img1.cpu().squeeze(0).permute(1,2,0)
img2 = img2.cpu().squeeze(0).permute(1,2,0)
fig, ax = plt.subplots(1, 3,figsize=(15, 15))
ax[0].imshow(img0)
ax[0].set_title('Noisy Img')
ax[1].imshow(img1)
ax[1].set_title('First downsampled')
ax[2].imshow(img2)
ax[2].set_title('Second downsampled')
接下来,定义一个损失函数,它对输入图像和网络输出进行比较,并计算它们之间的均方误差。这个均方误差表示了噪声被降低的程度:
def mse(gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
loss = torch.nn.MSELoss()
return loss(gt, pred)
def loss_func(noisy_img):
noisy1, noisy2 = pair_downsampler(noisy_img)
pred1 = noisy1 - model(noisy1)
pred2 = noisy2 - model(noisy2)
loss_res = 1 / 2 * (mse(noisy1, pred2) + mse(noisy2, pred1))
noisy_denoised = noisy_img - model(noisy_img)
denoised1, denoised2 = pair_downsampler(noisy_denoised)
loss_cons = 1 / 2 * (mse(pred1, denoised1) + mse(pred2, denoised2))
loss = loss_res + loss_cons
return loss
现在,开始训练神经网络,使用Adam优化器并迭代2000次:
def train(model, optimizer, noisy_img):
loss = loss_func(noisy_img)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def test(model, noisy_img, clean_img):
with torch.no_grad():
pred = torch.clamp(noisy_img - model(noisy_img), 0, 1)
MSE = mse(clean_img, pred).item()
PSNR = 10 * np.log10(1 / MSE)
return PSNR
def denoise(model, noisy_img):
with torch.no_grad():
pred = torch.clamp(noisy_img - model(noisy_img), 0, 1)
return pred
max_epoch = 2000 # training epochs
lr = 0.001 # learning rate
step_size = 1500 # number of epochs at which learning rate decays
gamma = 0.5 # factor by which learning rate decays
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
#开始去噪
for epoch in range(max_epoch):
# 初始化进度条
progress_bar = "[" + "-" * 20 + "]"
# 训练
train(model, optimizer, noisy_img)
# 更新进度条
progress_bar = list(progress_bar)
progress_bar[(epoch * 20) // max_epoch] = "#"
progress_bar = "".join(progress_bar)
# 打印进度条和当前时间点
print(f"\r{progress_bar} Epoch: {epoch + 1}/{max_epoch}", end="")
训练完成后,我们可以将网络应用于原始噪声图像并显示输出结果:
# 去噪后图像的PSNR
PSNR = test(model, noisy_img, clean_img)
print(PSNR)
#显示原图、加入噪音和去噪的图像
denoised_img = denoise(model, noisy_img)
denoised = denoised_img.cpu().squeeze(0).permute(1,2,0)
clean = clean_img.cpu().squeeze(0).permute(1,2,0)
noisy = noisy_img.cpu().squeeze(0).permute(1,2,0)
fig, ax = plt.subplots(1, 3,figsize=(15, 15))
ax[0].imshow(clean)
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].set_title('Ground Truth')
ax[1].imshow(noisy)
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[1].set_title('Noisy Img')
noisy_psnr = 10*np.log10(1/mse(noisy_img,clean_img).item())
ax[1].set(xlabel= str(round(noisy_psnr,2)) + ' dB')
ax[2].imshow(denoised)
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[2].set_title('Denoised Img')
ax[2].set(xlabel= str(round(PSNR,2)) + ' dB')
fig.savefig('output.png', dpi=300, bbox_inches='tight')
可以看到,神经网络成功地降低了图像中的噪声。