【低照度图像增强系列(7)】RDDNet算法详解与代码实现(同济大学|ICME)

news2024/10/5 14:18:50

前言

☀️ 在低照度场景下进行目标检测任务,常存在图像RGB特征信息少提取特征困难目标识别和定位精度低等问题,给检测带来一定的难度。

     🌻使用图像增强模块对原始图像进行画质提升,恢复各类图像信息,再使用目标检测网络对增强图像进行特定目标检测,有效提高检测的精确度。

      ⭐本专栏会介绍传统方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度图像增强算法。

👑完整代码已打包上传至资源→低照度图像增强代码汇总

目录

前言

🚀一、RDDNet介绍 

☀️1.1 RDDNet简介   

研究背景 

算法框架 

损失函数

🚀二、RDDNet核心代码

 ☀️2.1 网络模型—RRDNet.py

 ☀️2.2 损失函数—loss_functions.py

(1)重构损失——reconstruction_loss

(2)光照损失——illumination_smooth_loss

(3)反射损失——reflectance_smooth_loss

(4)噪声损失——noise_loss

  ☀️2.3 Retinex操作—pipline.py

🚀三、RDDNet代码复现

☀️3.1 环境配置

☀️3.2 运行过程

☀️3.3 运行效果

 

🚀一、RDDNet介绍 

学习资料:

  • 论文题目:《ZERO-SHOT RESTORATION OF UNDEREXPOSED IMAGES VIA ROBUST RETINEX DECOMPOS》(通过鲁棒性 Retinex 分解对曝光不足的图像进行零样本恢复)
  • 论文讲解:ICME| RRDNet《ZERO-SHOT RESTORATION OF UNDEREXPOSED IMAGES VIA ROBUST RETINEX DECOMPOS》论文超详细解读(翻译+精读)
  • 原文地址:Zero-Shot Restoration of Underexposed Images via Robust Retinex Decomposition | IEEE Conference Publication | IEEE Xplore
  • 源码地址:代码export.arxiv.org/pdf/2109.05838v2.pdf

☀️1.1 RDDNet简介   

RRDNet同济大学在2020年提出来的一种新的三分支全卷积神经网络,认为图像由三部分构成:光照分量反射分量噪声分量。在没有pair对的情况下实现低光图像增强,通过对loss进行迭代来有效估计出噪声和恢复光照。 

研究背景 

  • 曝光不足的图像由于能见度差和黑暗中的潜在噪声,通常会出现严重的质量下降。
  • 现有的图像增强方法忽略了噪声,因此使用带噪声分量的Retinex模型作为基础。
  • 基于学习(数据驱动)的方法限制了模型的泛化能力,因此提出zero-shot的学习模式。

算法框架 

  1. 通过三分支网络把输入图像分解为反射图、光照图和噪声图三个分量。
  2. 通过Gamma变换调整光照图,再计算得到无噪声的反射图。
  3. 结合光照图和反射图,重构得到最终结果。 

损失函数

1. Retinex重构损失,取最大通道值作为初始光照图,用来约束光照图。在光照图的基础上约束反射图和噪声。

2. 纹理增强损失,通过平滑光照图可以帮助增强反射图的纹理。具体损失公式是带有权重的总变分损失,权重的设计规则是,梯度大的地方权重小,即权重与梯度成负相关即可,这里是将梯度经过高斯滤波放在分母。

3. 光照指导的噪声损失,根据噪声随着光照的变大而变大的假设,可以使用光照图来做权重指导,其次考虑两点:

(1)假定噪声范围限定

(2)通过平滑反射图来得到噪声,本身并没有直接得到噪声的损失,只是通过对反射图做总变分约束来去噪


🚀二、RDDNet核心代码

 代码框架如图所示:

(图片来源:【代码笔记】RRDNet 网络-CSDN博客 谢谢大佬!@chaiky) 

 ☀️2.1 网络模型—RRDNet.py

import torch
import torch.nn as nn

class RRDNet(nn.Module):
    def __init__(self):
        super(RRDNet, self).__init__()

 #----------- 1.illumination(光照估计)---------------------------#
        self.illumination_net = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),

        )

 #----------- 2.reflectance(反射率估计)---------------------------#
        self.reflectance_net = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, 1, 1)
        )

 #----------- 3.noise(噪声估计)---------------------------#
        self.noise_net = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, 1, 1)
        )

    def forward(self, input):
        illumination = torch.sigmoid(self.illumination_net(input))
        reflectance = torch.sigmoid(self.reflectance_net(input))
        noise = torch.tanh(self.noise_net(input))

        return illumination, reflectance, noise

  我们可以对照上图左边的结构来理解代码。

  • illumination_net:  主要是负责对输入图像进行处理以获取光照信息,包括一系列卷积层和ReLU激活函数,最终输出一个通道数为1的图像,表示光照强度

  • reflectance_net:  主要是负责提取输入图像的反射率信息,同样包括一系列卷积层和ReLU激活函数,最终输出一个通道数为3的图像,表示反射率在RGB通道上的分布。

  • noise_net:  主要是则用于估计输入图像的噪声信息,同样由一系列卷积层和ReLU激活函数组成,最终输出一个通道数为3的图像,表示噪声在RGB通道上的分布。

 最后,illumination_netreflectance_net的输出经过sigmoid函数处理,而noise_net的输出则经过tanh函数处理。


 ☀️2.2 损失函数—loss_functions.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import conf

 #----------- 1.reconstruction_loss:计算重构损失---------------------------#
def reconstruction_loss(image, illumination, reflectance, noise):
    reconstructed_image = illumination*reflectance+noise
    return torch.norm(image-reconstructed_image, 1)

 #----------- 2.gradient: 计算输入图像在水平和垂直方向上的梯度--------------------#
def gradient(img):
    height = img.size(2)
    width = img.size(3)
    gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()
    gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()
    gradient_h = F.pad(gradient_h, [0, 0, 1, 1], 'replicate')
    gradient_w = F.pad(gradient_w, [1, 1, 0, 0], 'replicate')
    gradient2_h = (img[:,:,4:,:]-img[:,:,:height-4,:]).abs()
    gradient2_w = (img[:, :, :, 4:] - img[:, :, :, :width-4]).abs()
    gradient2_h = F.pad(gradient2_h, [0, 0, 2, 2], 'replicate')
    gradient2_w = F.pad(gradient2_w, [2, 2, 0, 0], 'replicate')
    return gradient_h*gradient2_h, gradient_w*gradient2_w

 #----------- 3.normalize01: 将输入图像进行归一化到0到1的范围内---------------------#
def normalize01(img):
    minv = img.min()
    maxv = img.max()
    return (img-minv)/(maxv-minv)

 #----------- 4.gaussianblur3: 3通道的高斯模糊---------------------------#
def gaussianblur3(input):
    slice1 = F.conv2d(input[:,0,:,:].unsqueeze(1), weight=conf.gaussian_kernel, padding=conf.g_padding)
    slice2 = F.conv2d(input[:,1,:,:].unsqueeze(1), weight=conf.gaussian_kernel, padding=conf.g_padding)
    slice3 = F.conv2d(input[:,2,:,:].unsqueeze(1), weight=conf.gaussian_kernel, padding=conf.g_padding)
    x = torch.cat([slice1,slice2, slice3], dim=1)
    return x

 #----------- 5.illumination_smooth_loss: 计算光照平滑损失---------------------------#
def illumination_smooth_loss(image, illumination):
    gray_tensor = 0.299*image[0,0,:,:] + 0.587*image[0,1,:,:] + 0.114*image[0,2,:,:]
    max_rgb, _ = torch.max(image, 1)
    max_rgb = max_rgb.unsqueeze(1)
    gradient_gray_h, gradient_gray_w = gradient(gray_tensor.unsqueeze(0).unsqueeze(0))
    gradient_illu_h, gradient_illu_w = gradient(illumination)
    weight_h = 1/(F.conv2d(gradient_gray_h, weight=conf.gaussian_kernel, padding=conf.g_padding)+0.0001)
    weight_w = 1/(F.conv2d(gradient_gray_w, weight=conf.gaussian_kernel, padding=conf.g_padding)+0.0001)
    weight_h.detach()
    weight_w.detach()
    loss_h = weight_h * gradient_illu_h
    loss_w = weight_w * gradient_illu_w
    max_rgb.detach()
    return loss_h.sum() + loss_w.sum() + torch.norm(illumination-max_rgb, 1)

 #----------- 6.reflectance_smooth_loss:计算反射率平滑损失---------------------------#
def reflectance_smooth_loss(image, illumination, reflectance):
    gray_tensor = 0.299*image[0,0,:,:] + 0.587*image[0,1,:,:] + 0.114*image[0,2,:,:]
    gradient_gray_h, gradient_gray_w = gradient(gray_tensor.unsqueeze(0).unsqueeze(0))
    gradient_reflect_h, gradient_reflect_w = gradient(reflectance)
    weight = 1/(illumination*gradient_gray_h*gradient_gray_w+0.0001)
    weight = normalize01(weight)
    weight.detach()
    loss_h = weight * gradient_reflect_h
    loss_w = weight * gradient_reflect_w
    refrence_reflect = image/illumination
    refrence_reflect.detach()
    return loss_h.sum() + loss_w.sum() + conf.reffac*torch.norm(refrence_reflect - reflectance, 1)

 #----------- 7.noise_loss: 计算噪声损失---------------------------#
def noise_loss(image, illumination, reflectance, noise):
    weight_illu = illumination
    weight_illu.detach()
    loss = weight_illu*noise
    return torch.norm(loss, 2)
(1)重构损失——reconstruction_loss

图像的分解组件必须满足Robust Retinex的公式,将RGB三个通道中最大强度值S的初始值,在此基础上约束反射图和噪声。

(2)光照损失——illumination_smooth_loss

通过平滑的光照图可以增强暗区域的纹理细节,公式中x和y是水平和垂直方向,Wx和Wy是确保图像平滑的权重参数。

权重与梯度呈反比,梯度大的地方权重小,梯度小的地方权重大,因此将高斯滤波G放在分母,这里公式中的I是输入图像转换成的灰度图,Wy的计算方式和Wx的相同。

(3)反射损失——reflectance_smooth_loss

通过平滑反射图来得到噪声,本身并没有直接得到噪声的损失,只是通过对反射图做总变分约束来去噪。

(4)噪声损失——noise_loss

为了增加图像的清晰度增加了图像的对比度,与此同时,图像的噪声也被放大,出于以下两点限制噪声:

  1. 噪声的范围需要被限制。
  2. 噪声可以平滑的反射图限制。


  ☀️2.3 Retinex操作—pipline.py

import os
import numpy as np
import cv2
import torch
import torch.optim as optim
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import torch.nn.init as init

from model.RRDNet import RRDNet
from loss.loss_functions import reconstruction_loss, illumination_smooth_loss, reflectance_smooth_loss, noise_loss, normalize01
import conf

 #----------- retinex图像增强---------------------------#
def pipline_retinex(net, img):
    img_tensor = transforms.ToTensor()(img)  # [c, h, w] #将输入图像转换为张量,并调整形状
    img_tensor = img_tensor.to(conf.device)
    img_tensor = img_tensor.unsqueeze(0)     # [1, c, h, w]

    optimizer = optim.Adam(net.parameters(), lr=conf.lr)

    # iterations:迭代优化过程
    for i in range(conf.iterations+1):
        # forward:通过网络前向传播得到光照、反射率和噪声图像。
        illumination, reflectance, noise = net(img_tensor)  # [1, c, h, w]
        # loss computing:计算总损失,并进行反向传播优化网络参数。
        loss_recons = reconstruction_loss(img_tensor, illumination, reflectance, noise)  # 重构损失
        loss_illu = illumination_smooth_loss(img_tensor, illumination) # 光照损失
        loss_reflect = reflectance_smooth_loss(img_tensor, illumination, reflectance) #反射损失
        loss_noise = noise_loss(img_tensor, illumination, reflectance, noise) # 噪声损失

        loss = loss_recons + conf.illu_factor*loss_illu + conf.reflect_factor*loss_reflect + conf.noise_factor*loss_noise

        # backward
        net.zero_grad()
        loss.backward()
        optimizer.step()

        # log:每隔 100 次迭代打印日志,显示重建损失、光照损失、反射率损失和噪声损失的数值。
        if i%100 == 0:
            print("iter:", i, '  reconstruction loss:', float(loss_recons.data), '  illumination loss:', float(loss_illu.data), '  reflectance loss:', float(loss_reflect.data), '  noise loss:', float(loss_noise.data))


    # adjustment:对增强后的图像进行调整
    adjust_illu = torch.pow(illumination, conf.gamma)
    res_image = adjust_illu*((img_tensor-noise)/illumination)# 对增强后的图像进行调整
    res_image = torch.clamp(res_image, min=0, max=1)# 对调整后的图像进行限幅操作,确保像素值在 0 到 1 之间。

    if conf.device != 'cpu':
        res_image = res_image.cpu()
        illumination = illumination.cpu()
        adjust_illu = adjust_illu.cpu()
        reflectance = reflectance.cpu()
        noise = noise.cpu()
    
    # 将处理后的张量转换为 PIL 图像
    res_img = transforms.ToPILImage()(res_image.squeeze(0))
    illum_img = transforms.ToPILImage()(illumination.squeeze(0))
    adjust_illu_img = transforms.ToPILImage()(adjust_illu.squeeze(0))
    reflect_img = transforms.ToPILImage()(reflectance.squeeze(0))
    noise_img = transforms.ToPILImage()(normalize01(noise.squeeze(0)))

    return res_img, illum_img, adjust_illu_img, reflect_img, noise_img


if __name__ == '__main__':

    # Init Model
    net = RRDNet()
    net = net.to(conf.device)

    # Test
    img = Image.open(conf.test_image_path)

    res_img, illum_img, adjust_illu_img, reflect_img, noise_img = pipline_retinex(net, img)
    res_img.save('./test/result.jpg')
    illum_img.save('./test/illumination.jpg')
    adjust_illu_img.save('./test/adjust_illumination.jpg')
    reflect_img.save('./test/reflectance.jpg')
    noise_img.save('./test/noise_map.jpg')

这段代码基本都注释了,就不再详细讲解了~


🚀三、RDDNet代码复现

☀️3.1 环境配置

  • Python 3
  • PyTorch >= 0.4.1
  • PIL >= 6.1.0
  • Opencv-python>=3.4

☀️3.2 运行过程

这个也是运行比较简单,配好环境就行 。不再过多叙述~


☀️3.3 运行效果

没错,你怎么知道我去看邓紫棋演唱会啦~ 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1688795.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

qt改变样式表 label

border:1px solid black; background-color:rgb(238,234,235); border-color:rgb(0,112,249);

Zabbix实现7x24小时架构监控

上篇:https://blog.csdn.net/Lzcsfg/article/details/138774511 文章目录 Zabbix功能介绍Zabbix平台选择安装Zabbix监控端部署MySQL数据库Zabbix参数介绍登录Zabbix WEBWEB界面概览修改WEB界面语言添加被控主机导入监控模板主机绑定模板查看主机状态查看监控数据解…

一文了解基于ITIL的运维管理体系框架

本文来自腾讯蓝鲸智云社区用户:CanWay ITIL(Information Technology Infrastructure Library)是全球最广泛使用的 IT 服务管理方法,旨在帮助组织充分利用其技术基础设施和云服务来实现增长和转型。优化IT运维,作为企业…

U-Boot menu菜单分析

文章目录 前言目标环境背景U-Boot如何自动调起菜单U-Boot添加自定义命令实践 前言 在某个厂家的开发板中,在进入它的U-Boot后,会自动弹出一个菜单页面,输入对应的选项就会执行对应的功能。如SD卡镜像更新、显示设置等: 目标 本…

Linux进程--函数 system 和 popen 的区别

system() 和 popen() 是 C 语言中用于执行外部命令的两个函数,它们的功能类似,但在使用方式和特性上有一些区别。 system() system() 函数允许您在程序中执行外部命令,并等待该命令执行完成后继续执行程序。其基本语法如下: in…

本地idea连接虚拟机linux中的docker进行打包镜像上传--maven的dockerfile-maven-plugin插件

项目名必须是英文,-,.,_,这些数字,idea需要管理员运行,因为idea控制台mvn命令需要管理员权限才能运行(maven需配置环境变量)改linux中的Docker服务文件,使用2375 进行非加密通信,然后加载重启 2.1 #修改Docker服务文件 vi /lib/systemd/system/docker.service ​ # 通常使…

深度学习基于Tensorflow卷积神经网络VGG16的CT影像识别分类

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 随着医疗技术的快速发展,CT(Computed Tomography)影像已成为医生…

练习题(2024/5/22)

1N 皇后 II n 皇后问题 研究的是如何将 n 个皇后放置在 n n 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数 n ,返回 n 皇后问题 不同的解决方案的数量。 示例 1: 输入:n 4 输出:2 解释:如上…

WAF绕过(下)

过流量检测 这里的流量检测就是在网络层的waf拦截到我们向webshell传输的数据包,以及webshell返回的数据 包,检测其中是否包含敏感信息的一种检测方式。如果是大马的情况下,可以在大马中添加多处判断代码,因此在执行大马提供的功…

设计模式12——外观模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用,主要是下面的UML图可以起到大作用,在你学习过一遍以后可能会遗忘,忘记了不要紧,只要看一眼UML图就能想起来了。同时也请大家多多指教。 外观模式(Facade&a…

【Linux】高效文本处理命令

目录 一.sort命令(排序) 1.语法格式 2.常用选项 3.相关示例 3.1. 3.2. 二.unip命令(去重) 1.语法格式 2.常用选项 3.相关示例 3.1. 3.2. 三.tr命令(替换) 1.语法格式 2.常用选项 3.相关示例…

如何在 Ubuntu 24.04 (桌面版) 上配置静态IP地址 ?

如果你想在你的 Ubuntu 24.04 桌面有一个持久的 IP 地址,那么你必须配置一个静态 IP 地址。当我们安装 Ubuntu 时,默认情况下 DHCP 是启用的,如果网络上可用,它会尝试从 DHCP 服务器获取 IP 地址。 在本文中,我们将向…

MySQL主从复制(二):高可用

正常情况下, 只要主库执行更新生成的所有binlog, 都可以传到备库并被正确地执行, 备库就能达到跟主库一致的状态, 这就是最终一致性。 但是, MySQL要提供高可用能力, 只有最终一致性是不够的。 双M结构的…

2024年甘肃特岗教师招聘报名流程,速速查收哦!

2024年甘肃特岗教师招聘报名流程,速速查收哦!

WXML模板语法-事件绑定

一、 1.事件 事件是渲染层到逻辑层的通讯方式,通过事件可以将用户在渲染层产生的行为,反馈到逻辑层进行业务的处理 2.小程序中常用的事件 3.事件对象的属性列表 当事件回调触发的时候,会收到一个事件对象event,其属性为&#x…

Mysql之基本架构

1.Mysql简介 mysql是一种关系型数据库,由表结构来存储数据与数据之间的关系,同时为sql(结构化查询语句)来进行数据操作。 sql语句进行操作又分为几个重要的操作类型 DQL: Data Query Language 数据查询语句 DML: Data Manipulation Language 添加、删…

axios案例应用

1、Spring概述 Spring 是分层的 Java SE/EE 应用 full-stack 轻量级开源框架,以 IoC(Inverse Of Control: 反转控制)和 AOP(Aspect Oriented Programming:面向切面编程)为内核,提供了展现层 Spring MVC 和持久层。Spring JDBC 以及业务层事务管理等众多…

C++进阶:C++11(列表初始化、右值引用与移动构造移动赋值、可变参数模版...Args、lambda表达式、function包装器)

C进阶:C11(列表初始化、右值引用与移动构造移动赋值、可变参数模版…Args、lambda表达式、function包装器) 今天接着进行语法方面知识点的讲解 文章目录 1.统一的列表初始化1.1{}初始化1.2 initializer_listpair的补充 2.声明相关关键字2.1a…

STM32——DAC篇(基于f103)

技术笔记! 一、DAC简介(了解) 1.1 DAC概念 传感器信号采集改变电信号,通过ADC转换成单片机可以处理的数字信号,处理后,通过DAC转换成电信号,进而实现对系统的控制。 1.2 DAC的特性参数 1.3…

amis-editor 低代码可视化编辑器开发 和 使用说明

1.amis-editor可视化编辑器 React版本(推荐): GitHub - aisuda/amis-editor-demo: amis 可视化编辑器示例 https://aisuda.github.io/amis-editor-demo 建议使用react版本,好维护,升级版本更新package.json中对应版本…