目录
一 SCI
1 SCI网络结构
核心代码(model.py)
2 SCI损失函数
核心代码(loss.py)
3 实验
二 SCI效果
1 下载代码
2 运行
一 SCI
💜论文题目:Toward Fast, Flexible, and Robust Low-Light Image Enhancement
💚论文地址:https://arxiv.org/pdf/2204.10137
💙代码地址:https://github.com/vis-opt-group/SCI
【摘要】现有的低照度图像增强技术大多不仅难以兼顾视觉质量和计算效率,而且在未知复杂场景下往往失效。在本文中,我们开发了一种新的自校准照明( Self-Calibrated Illumination,SCI )学习框架,用于实际低照度场景中快速、灵活和鲁棒的亮化图像。具体来说,我们建立了一个带有权重共享的级联光照学习过程来处理这个任务。考虑到级联模式的计算负担,我们构造了自校准模块,实现了每个阶段结果之间的收敛,产生了仅使用单个基本块进行推理(但在以前的工作中尚未得到利用)的增益,极大地减少了计算开销。然后,我们定义了无监督训练损失,以提高模型的能力,使其能够适应一般的场景。进一步,我们对挖掘SCI固有属性(现有工作中的欠缺),包括操作不敏感适应性(在不同的设置下获得稳定的性能)和模型无关通用性(可以应用于现有的基于光照的工作中,以提高性能)进行了全面探索。最后,大量的实验和消融研究充分表明了我们在质量和效率上的优越性。在低照度人脸检测和夜间语义分割上的应用充分显示了SCI的潜在实用价值。
见图1。最近最先进的方法和我们的方法的比较。
Kin D [ 34 ]是具有代表性的成对监督方法。
EnGAN [ 11 ]考虑了非成对监督学习。
Zero DCE [ 7 ]和RUAS [ 14 ]引入了无监督学习。
我们的方法(只包含3个大小为3 × 3的卷积)也属于无监督学习。如放大区域显示的那样,这些比较方法出现了不正确的曝光,颜色失真和结构不足以降低视觉质量。相比之下,我们的结果呈现出生动的颜色和清晰的轮廓。进一步,我们报告了( b )中的计算效率( SIZE , FLOPs和TIME)和( c )中的增强( PSNR , SSIM和EME)、检测( mAP )和分割( mIoU )三种任务中5种度量指标的数值得分,可以很容易地观察到我们的方法明显优于其他方法。
更具体地说,论文的主要贡献可以归结为:
① 开发了一个权重共享的光照学习自校准模块,以保证每个阶段的结果之间的收敛性,提高曝光稳定性,并大幅降低计算负担。据我们所知,利用学习过程加速低照度图像增强算法是第一项工作。
② 在自校准模块的作用下,我们定义了无监督的训练损失来约束每个阶段的输出,从而赋予模型对不同场景的适应能力。属性分析表明,SCI具有操作不敏感的自适应性和模型无关的一般性,这是现有工作中没有发现的。
③ 进行了大量的实验来说明我们相对于其他先进方法的优越性。在暗人脸检测和夜间语义分割上的应用进一步展示了我们的实用价值。简而言之,在基于网络的低照度图像增强领域,SCI重新定义了视觉质量、计算效率和下游任务性能的峰值点。
1 SCI网络结构
见图2。SCI的整个框架。在训练阶段,SCI由光照估计和自校准模块组成。在原始低照度输入中加入自校准模块映射,作为下一阶段照度估计的输入。注意这两个模块分别是整个训练过程中的共享参数。在测试阶段,只使用了单一的光照估计模块。
❤️核心代码(model.py)
主要由EnhanceNetwork 、CalibrateNetwork、Network、Finetunemodel四个类组成。
① EnhanceNetwork: 对输入图像进行增强。
🦋🦋🦋通过多次堆叠卷积块,来学习图像的特征。增强后的图像通过与输入相加并进行截断,以确保像素值在合理范围内。
首先,通过__init__初始化超参数和网络层。
然后,将输入图像通过3*3的卷积层,得到特征 fea,对特征fea多次应用相同的卷积块进行叠加,通过输出卷积层获得最终的特征 fea。
接着,将生成的特征与输入图像相加,得到增强后的图像,通过 clamp 函数将图像像素值限制在 0.0001 和 1 之间。
最后,返回增强后的图像 illu。
② CalibrateNetwork:定义了一个校准网络。
🦋🦋🦋在前向传播时,输入经过一系列卷积操作后,对于最终的特征 fea再与原始输入相减,得到最终的增益调整结果delta。
③ Network:组合了上述图像增强网络 (EnhanceNetwork) 和校准网络(CalibrateNetwork),并多次执行这两个操作。
首先,初始化网络结构,并创建EnhanceNetwork、CalibrateNetwork以及loss损失函数的实例。
接着,定义权重初始化的方法。
然后,通过多次迭代,每次迭代中进行下述的步骤:
◆ 将当前输入保存到列表中。
◆ 使用图像增强网络 EnhanceNetwork 处理当前输入,得到增强后的图像。
◆ 计算增强前后的比例,并将比例值限制在 [0, 1] 范围内。
◆ 使用校准网络 CalibrateNetwork 对比例进行校准,得到校准值。
◆ 将原始输入与校准值相加,得到下一阶段的输入。
◆ 将当前阶段的增强图像、比例、输入和校准值的绝对值保存到对应的列表中。
◆ 返回四个列表,分别包含不同阶段的增强图像、比例、输入和校准值。
最后,计算损失。
④ Finetunemodel:进行模型的微调。
import torch
import torch.nn as nn
from loss import LossFunction
class EnhanceNetwork(nn.Module):
def __init__(self, layers, channels):
super(EnhanceNetwork, self).__init__()
kernel_size = 3
dilation = 1
padding = int((kernel_size - 1) / 2) * dilation
self.in_conv = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.ReLU()
)
self.conv = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.blocks = nn.ModuleList()
for i in range(layers):
self.blocks.append(self.conv)
self.out_conv = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
nn.Sigmoid()
)
def forward(self, input):
fea = self.in_conv(input)
for conv in self.blocks:
fea = fea + conv(fea)
fea = self.out_conv(fea)
illu = fea + input
illu = torch.clamp(illu, 0.0001, 1)
return illu
class CalibrateNetwork(nn.Module):
def __init__(self, layers, channels):
super(CalibrateNetwork, self).__init__()
kernel_size = 3
dilation = 1
padding = int((kernel_size - 1) / 2) * dilation
self.layers = layers
self.in_conv = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.convs = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU(),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.blocks = nn.ModuleList()
for i in range(layers):
self.blocks.append(self.convs)
self.out_conv = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
nn.Sigmoid()
)
def forward(self, input):
fea = self.in_conv(input)
for conv in self.blocks:
fea = fea + conv(fea)
fea = self.out_conv(fea)
delta = input - fea
return delta
class Network(nn.Module):
def __init__(self, stage=3):
super(Network, self).__init__()
self.stage = stage
self.enhance = EnhanceNetwork(layers=1, channels=3)
self.calibrate = CalibrateNetwork(layers=3, channels=16)
self._criterion = LossFunction()
def weights_init(self, m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
if isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1., 0.02)
def forward(self, input):
ilist, rlist, inlist, attlist = [], [], [], []
input_op = input
for i in range(self.stage):
inlist.append(input_op)
i = self.enhance(input_op)
r = input / i
r = torch.clamp(r, 0, 1)
att = self.calibrate(r)
input_op = input + att
ilist.append(i)
rlist.append(r)
attlist.append(torch.abs(att))
return ilist, rlist, inlist, attlist
def _loss(self, input):
i_list, en_list, in_list, _ = self(input)
loss = 0
for i in range(self.stage):
loss += self._criterion(in_list[i], i_list[i])
return loss
class Finetunemodel(nn.Module):
def __init__(self, weights):
super(Finetunemodel, self).__init__()
self.enhance = EnhanceNetwork(layers=1, channels=3)
self._criterion = LossFunction()
base_weights = torch.load(weights)
pretrained_dict = base_weights
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)
def weights_init(self, m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
if isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1., 0.02)
def forward(self, input):
i = self.enhance(input)
r = input / i
r = torch.clamp(r, 0, 1)
return i, r
def _loss(self, input):
i, r = self(input)
loss = self._criterion(input, i)
return loss
2 SCI损失函数
❗❗❗总损失函数:Ltotal = αLf + βLs
Lf :表示保真度;Ls :表示平滑损失。
🌸保真度损失是为了保证估计的照度与每级输入之间的像素级一致性,表示如公式(4)所示。
🌸光照的平滑特性在这个任务[ 7、34]中是一个广泛的共识。这里我们采用一个具有空间变化l1范数的光滑项[ 4 ],表示如公式(5)所示。
❤️核心代码(loss.py)
SCI使用的是无监督损失训练,由fifidelity loss和smoothing loss的线性组合构成。
◆ Fidelity Loss
🌸采用均方误差损失函数 nn.MSELoss 计算输入图像 input 与增强后的图像 illu 之间的均方误差。
正则化项基于像素梯度和其指数权重的计算。
◆ Smooth Loss
🌸采用 SmoothLoss 类的实例 self.smooth_loss 计算输入图像 input 与增强后的图像 illu 之间的光滑损失。
通过 YCbCr 色彩空间的梯度计算来衡量图像的光滑性。
import torch
import torch.nn as nn
class LossFunction(nn.Module):
def __init__(self):
super(LossFunction, self).__init__()
self.l2_loss = nn.MSELoss()
self.smooth_loss = SmoothLoss()
def forward(self, input, illu):
Fidelity_Loss = self.l2_loss(illu, input)
Smooth_Loss = self.smooth_loss(input, illu)
return 1.5*Fidelity_Loss + Smooth_Loss
class SmoothLoss(nn.Module):
def __init__(self):
super(SmoothLoss, self).__init__()
self.sigma = 10
def rgb2yCbCr(self, input_im):
im_flat = input_im.contiguous().view(-1, 3).float()
mat = torch.Tensor([[0.257, -0.148, 0.439], [0.564, -0.291, -0.368], [0.098, 0.439, -0.071]]).cuda()
bias = torch.Tensor([16.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0]).cuda()
temp = im_flat.mm(mat) + bias
out = temp.view(input_im.shape[0], 3, input_im.shape[2], input_im.shape[3])
return out
# output: output input:input
def forward(self, input, output):
self.output = output
self.input = self.rgb2yCbCr(input)
sigma_color = -1.0 / (2 * self.sigma * self.sigma)
w1 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :] - self.input[:, :, :-1, :], 2), dim=1,
keepdim=True) * sigma_color)
w2 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :] - self.input[:, :, 1:, :], 2), dim=1,
keepdim=True) * sigma_color)
w3 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 1:] - self.input[:, :, :, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w4 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-1] - self.input[:, :, :, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w5 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-1] - self.input[:, :, 1:, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w6 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 1:] - self.input[:, :, :-1, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w7 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-1] - self.input[:, :, :-1, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w8 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 1:] - self.input[:, :, 1:, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w9 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :] - self.input[:, :, :-2, :], 2), dim=1,
keepdim=True) * sigma_color)
w10 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :] - self.input[:, :, 2:, :], 2), dim=1,
keepdim=True) * sigma_color)
w11 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 2:] - self.input[:, :, :, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w12 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-2] - self.input[:, :, :, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w13 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-1] - self.input[:, :, 2:, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w14 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 1:] - self.input[:, :, :-2, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w15 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-1] - self.input[:, :, :-2, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w16 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 1:] - self.input[:, :, 2:, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w17 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-2] - self.input[:, :, 1:, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w18 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 2:] - self.input[:, :, :-1, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w19 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-2] - self.input[:, :, :-1, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w20 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 2:] - self.input[:, :, 1:, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w21 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-2] - self.input[:, :, 2:, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w22 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 2:] - self.input[:, :, :-2, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w23 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-2] - self.input[:, :, :-2, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w24 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 2:] - self.input[:, :, 2:, :-2], 2), dim=1,
keepdim=True) * sigma_color)
p = 1.0
pixel_grad1 = w1 * torch.norm((self.output[:, :, 1:, :] - self.output[:, :, :-1, :]), p, dim=1, keepdim=True)
pixel_grad2 = w2 * torch.norm((self.output[:, :, :-1, :] - self.output[:, :, 1:, :]), p, dim=1, keepdim=True)
pixel_grad3 = w3 * torch.norm((self.output[:, :, :, 1:] - self.output[:, :, :, :-1]), p, dim=1, keepdim=True)
pixel_grad4 = w4 * torch.norm((self.output[:, :, :, :-1] - self.output[:, :, :, 1:]), p, dim=1, keepdim=True)
pixel_grad5 = w5 * torch.norm((self.output[:, :, :-1, :-1] - self.output[:, :, 1:, 1:]), p, dim=1, keepdim=True)
pixel_grad6 = w6 * torch.norm((self.output[:, :, 1:, 1:] - self.output[:, :, :-1, :-1]), p, dim=1, keepdim=True)
pixel_grad7 = w7 * torch.norm((self.output[:, :, 1:, :-1] - self.output[:, :, :-1, 1:]), p, dim=1, keepdim=True)
pixel_grad8 = w8 * torch.norm((self.output[:, :, :-1, 1:] - self.output[:, :, 1:, :-1]), p, dim=1, keepdim=True)
pixel_grad9 = w9 * torch.norm((self.output[:, :, 2:, :] - self.output[:, :, :-2, :]), p, dim=1, keepdim=True)
pixel_grad10 = w10 * torch.norm((self.output[:, :, :-2, :] - self.output[:, :, 2:, :]), p, dim=1, keepdim=True)
pixel_grad11 = w11 * torch.norm((self.output[:, :, :, 2:] - self.output[:, :, :, :-2]), p, dim=1, keepdim=True)
pixel_grad12 = w12 * torch.norm((self.output[:, :, :, :-2] - self.output[:, :, :, 2:]), p, dim=1, keepdim=True)
pixel_grad13 = w13 * torch.norm((self.output[:, :, :-2, :-1] - self.output[:, :, 2:, 1:]), p, dim=1, keepdim=True)
pixel_grad14 = w14 * torch.norm((self.output[:, :, 2:, 1:] - self.output[:, :, :-2, :-1]), p, dim=1, keepdim=True)
pixel_grad15 = w15 * torch.norm((self.output[:, :, 2:, :-1] - self.output[:, :, :-2, 1:]), p, dim=1, keepdim=True)
pixel_grad16 = w16 * torch.norm((self.output[:, :, :-2, 1:] - self.output[:, :, 2:, :-1]), p, dim=1, keepdim=True)
pixel_grad17 = w17 * torch.norm((self.output[:, :, :-1, :-2] - self.output[:, :, 1:, 2:]), p, dim=1, keepdim=True)
pixel_grad18 = w18 * torch.norm((self.output[:, :, 1:, 2:] - self.output[:, :, :-1, :-2]), p, dim=1, keepdim=True)
pixel_grad19 = w19 * torch.norm((self.output[:, :, 1:, :-2] - self.output[:, :, :-1, 2:]), p, dim=1, keepdim=True)
pixel_grad20 = w20 * torch.norm((self.output[:, :, :-1, 2:] - self.output[:, :, 1:, :-2]), p, dim=1, keepdim=True)
pixel_grad21 = w21 * torch.norm((self.output[:, :, :-2, :-2] - self.output[:, :, 2:, 2:]), p, dim=1, keepdim=True)
pixel_grad22 = w22 * torch.norm((self.output[:, :, 2:, 2:] - self.output[:, :, :-2, :-2]), p, dim=1, keepdim=True)
pixel_grad23 = w23 * torch.norm((self.output[:, :, 2:, :-2] - self.output[:, :, :-2, 2:]), p, dim=1, keepdim=True)
pixel_grad24 = w24 * torch.norm((self.output[:, :, :-2, 2:] - self.output[:, :, 2:, :-2]), p, dim=1, keepdim=True)
ReguTerm1 = torch.mean(pixel_grad1) \
+ torch.mean(pixel_grad2) \
+ torch.mean(pixel_grad3) \
+ torch.mean(pixel_grad4) \
+ torch.mean(pixel_grad5) \
+ torch.mean(pixel_grad6) \
+ torch.mean(pixel_grad7) \
+ torch.mean(pixel_grad8) \
+ torch.mean(pixel_grad9) \
+ torch.mean(pixel_grad10) \
+ torch.mean(pixel_grad11) \
+ torch.mean(pixel_grad12) \
+ torch.mean(pixel_grad13) \
+ torch.mean(pixel_grad14) \
+ torch.mean(pixel_grad15) \
+ torch.mean(pixel_grad16) \
+ torch.mean(pixel_grad17) \
+ torch.mean(pixel_grad18) \
+ torch.mean(pixel_grad19) \
+ torch.mean(pixel_grad20) \
+ torch.mean(pixel_grad21) \
+ torch.mean(pixel_grad22) \
+ torch.mean(pixel_grad23) \
+ torch.mean(pixel_grad24)
total_term = ReguTerm1
return total_term
3 实验
见图7。在LSRW数据集上对当前最先进的低照度图像增强方法进行了视觉比较。
见图8。在一些具有挑战性的实例上进行视觉比较。更多的结果可以在补充材料中找到。
二 SCI效果
Requirements:python3.7 pytorch==1.8.0 cuda11.1
1 下载代码
git clone https://github.com/vis-opt-group/SCI.git
2 运行
python3 test.py
原图:
效果图:
至此,本文分享的内容就结束啦💕💕💕💕💕💕。