Python图像处理【22】基于卷积神经网络的图像去雾

news2024/10/7 4:33:45

基于卷积神经网络的图像去雾

    • 0. 前言
    • 1. 渐进特征融合网络
    • 2. 图像去雾
      • 2.1 网络构建
      • 2.2 模型测试
    • 小结
    • 系列链接

0. 前言

单图像去雾 (dehazing) 是一个具有挑战性的图像恢复问题。为了解决这个问题,大多数算法都采用经典的大气散射模型,该模型是一种基于单一散射和均匀大气介质假设的简化物理模型,但现实环境中的雾霾表述更加复杂。

1. 渐进特征融合网络

在本节中,我们将学习如何使用输入自适应端到端深度学习预训练去雾模型,即渐进特征融合网络 (Progressive Feature Fusion Network, PFFNet),并通过使用 Pytorch 来执行模糊图像的去雾操作。渐进特征融合所采用的 U-Net 架构编码器 - 解码器网络,可直接学习从模糊图像到清晰图像的高度非线性转换函数。深度神经网架构如下图所示:

PFFNet
从以上体系结构图可以看出:

  • 编码器由五个卷积层组成,每个卷积层之后都有非线性 ReLU 激活函数;第一层用于从原始模糊图像中相对较大的局部感受野上的提取特征,然后,依次执行四次下采样卷积操作,以获取图像金字塔
  • 特征转换模块由基于残差的模块组成,深层网络可以表示非常复杂的特征,也可以学习到许多不同尺度的特征,但同时,在使用反向传播进行训练时,经常会遇到消失的梯度问题,而残差网络就是为了解决这一问题而被提出的,可以用于训练更深的网络
  • 解码器由四个反卷积层和一个卷积层组成,与编码器相反,解码器的反卷积层顺序堆叠以恢复图像结构细节

2. 图像去雾

2.1 网络构建

(1) 首先下载预训练网络模型,并导入所需的库,模块和函数:

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage, Normalize, Resize
#from torchviz import make_dot
import matplotlib.pylab as plt 

(2) 定义与深神经网络中不同层相对应的 ConvLayerUpsampleConvLayer 类,所有网络层都继承自 Pytorchnn.module 类;每个层都需要实现自己的 init() (用于初始化参数/成员变量/层)和 forward() 方法(定义前向传播过程中的计算):

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
      super(UpsampleConvLayer, self).__init__()
      reflection_padding = kernel_size // 2
      self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
      self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

(3) 接下来,我们用两个 ConvLayer 类实例定义类 ResidualBlock,在 ConvLayer 类实例之间使用 PReLU 激活函数,该类同样继承自 nn.module,并定义 forward() 方法用于前向传播:

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out) * 0.1
        out = torch.add(out, residual)
        return out 

(4) 定义继承自 nn.conv2d 类的 MeanShift 类,通过将 requires_grad 的参数设置为 False,冻结 MeanShift 层:

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, sign):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.bias.data = float(sign) * torch.Tensor(rgb_mean) * rgb_range

        # Freeze the MeanShift layer
        for params in self.parameters():
            params.requires_grad = False

(5) 最后,根据所定义的神经网络层定义深度神经网络类 Net,该类同样需要定义 init() 方法。网络使用了五个 ConvLayer,然后使用四个 UPSampleconvLayer,最后通过 ConvLayer 层后输出,网络使用 LeakyReLU 作为激活函数。
同样,需要定义向前传播方法 forward(),并在每个激活函数后使用双线性上采样:

class Net(nn.Module):
    def __init__(self, res_blocks=18):
        super(Net, self).__init__()

        rgb_mean = (0.5204, 0.5167, 0.5129)
        self.sub_mean = MeanShift(1., rgb_mean, -1)
        self.add_mean = MeanShift(1., rgb_mean, 1)

        self.conv_input = ConvLayer(3, 16, kernel_size=11, stride=1)
        self.conv2x = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.conv4x = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.conv8x = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.conv16x = ConvLayer(128, 256, kernel_size=3, stride=2)

        self.dehaze = nn.Sequential()
        for i in range(1, res_blocks):
            self.dehaze.add_module('res%d' % i, ResidualBlock(256))

        self.convd16x = UpsampleConvLayer(256, 128, kernel_size=3, stride=2)
        self.convd8x = UpsampleConvLayer(128, 64, kernel_size=3, stride=2)
        self.convd4x = UpsampleConvLayer(64, 32, kernel_size=3, stride=2)
        self.convd2x = UpsampleConvLayer(32, 16, kernel_size=3, stride=2)

        self.conv_output = ConvLayer(16, 3, kernel_size=3, stride=1)
()
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.relu(self.conv_input(x))
        res2x = self.relu(self.conv2x(x))
        res4x = self.relu(self.conv4x(res2x))

        res8x = self.relu(self.conv8x(res4x))
        res16x = self.relu(self.conv16x(res8x))

        res_dehaze = res16x
        res16x = self.dehaze(res16x)
        res16x = torch.add(res_dehaze, res16x)

        res16x = self.relu(self.convd16x(res16x))
        res16x = F.upsample(res16x, res8x.size()[2:], mode='bilinear')
        res8x = torch.add(res16x, res8x)

        res8x = self.relu(self.convd8x(res8x))
        res8x = F.upsample(res8x, res4x.size()[2:], mode='bilinear')
        res4x = torch.add(res8x, res4x)

        res4x = self.relu(self.convd4x(res4x))
        res4x = F.upsample(res4x, res2x.size()[2:], mode='bilinear')
        res2x = torch.add(res4x, res2x)

        res2x = self.relu(self.convd2x(res2x))
        res2x = F.upsample(res2x, x.size()[2:], mode='bilinear')
        x = torch.add(res2x, x)

        x = self.conv_output(x)

        return x

(6) 定义预训练模型参数位置以及模型使用的残差块数量:

rb = 13
checkpoint = "I-HAZE_O-HAZE.pth"

(7) 实例化 Net() 类并使用 load_state_dict() 方法从检查点加载预训练权重。由于我们不需要训练模型,因此使用测试模式:

net = Net(rb)
net.load_state_dict(torch.load(checkpoint)['state_dict'])
net.eval()

2.2 模型测试

(1) 接下来,使用 open() 函数读取输入图像:

im_path = "pic.png"
im = Image.open(im_path)
h, w = im.size
print(h, w)

(2) 使用 torchvision.transforms 模块中的 ToTensor() 将图像转换为张量对象以输入网络,然后使用输入图像在模型上运行正向传递过程计算输出,最后将输出转换为图像:

imt = ToTensor()(im)
imt = Variable(imt).view(1, -1, w, h)
#im = im.cuda()
with torch.no_grad():
    imt = net(imt)
out = torch.clamp(imt, 0., 1.)
out = out.cpu()
out = out.data[0]
out = ToPILImage()(out)

def plot_image(image, title=None, sz=10):
    plt.imshow(image)
    plt.title(title, size=sz)
    plt.axis('off')
plt.figure(figsize=(20,10))
plt.subplot(121), plot_image(im, 'hazed input')
plt.subplot(122), plot_image(out, 'de-hazed output')
plt.tight_layout()
plt.show() 

去雾结果

小结

图像去雾已成为计算机视觉的重要研究方向,在雾、霾等恶劣天气下拍摄的的图像通常由于大气散射的作用,图像质量严重下降使颜色偏灰白色,对比度降低,物体特征难以辨认,还会影响图像的分析与处理。因此,需要使用图像去雾技术来增强或修复图像,以改善视觉效果并便于图像的后续处理。在本节中,我们学习了一种基于卷积神经网络的图像去雾模型,通过使用训练后的模型可以显著改善图像视觉效果。

系列链接

Python图像处理【1】图像与视频处理基础
Python图像处理【2】探索Python图像处理库
Python图像处理【3】Python图像处理库应用
Python图像处理【4】图像线性变换
Python图像处理【5】图像扭曲/逆扭曲
Python图像处理【6】通过哈希查找重复和类似的图像
Python图像处理【7】采样、卷积与离散傅里叶变换
Python图像处理【8】使用低通滤波器模糊图像
Python图像处理【9】使用高通滤波器执行边缘检测
Python图像处理【10】基于离散余弦变换的图像压缩
Python图像处理【11】利用反卷积执行图像去模糊
Python图像处理【12】基于小波变换执行图像去噪
Python图像处理【13】使用PIL执行图像降噪
Python图像处理【14】基于非线性滤波器的图像去噪
Python图像处理【15】基于非锐化掩码锐化图像
Python图像处理【16】OpenCV直方图均衡化
Python图像处理【17】指纹增强和细节提取
Python图像处理【18】边缘检测详解
Python图像处理【19】基于霍夫变换的目标检测
Python图像处理【20】图像金字塔
Python图像处理【21】基于卷积神经网络增强微光图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1507624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ECharts饼图图例消失踩的坑

在使用Echarts的饼图时,当时做法是在图例数小于8时显示全部的图例,在大于8的时候显示前8个图例。于是用了两种不同的方式处理。导致出现切换时间后图例不显示的情况。 错误过程: 在进行图例生成时采用了两种不同的方式: ①如果…

Redis底层源码分析系列(前提准备)

文章目录 一、 面试题二、 源码分析1. 源码导入2. 源码核心部分 一、 面试题 1. redis跳跃列表了解吗?这个数据结构有什么缺点? 2. redis项目里面怎么用? redis的数据结构都了解哪些? 3. redis的zset底层实现? redi…

深入理解Servlet

目录: ServletWeb开发历史Servlet简介Servlet技术特点Servlet在应用程序中的位置Tomcat运行过程Servlet继承结构Servlet生命周期Servlet处理请求的原理Servlet的作用HttpServletRequest对象HttpServletResponse对象ServletContext对象ServletConfig对象Cookie对象与…

Constrained Iterative LQR 自动驾驶中使用的经典控制算法

Motion planning 运动规划在自动驾驶领域是一个比较有挑战的部分。它既要接受来自上层的行为理解和决策的输出,也要考虑一个包含道路结构和感知所检测到的所有障碍物状态的动态世界模型。最终生成一个满足安全性和可行性约束并且具有理想驾驶体验的轨迹。 通常,motion plann…

微信小程序开发系列(二十八)·小程序API如何发送网络请求以及网络请求失败后的解决方法

目录 1. 小程序API介绍 2. 网络请求 2.1 网络请求失败解决方法 2.2 如何跳过域名校验 1. 小程序API介绍 小程序开发框架提供丰富的微信原生API,可以方便的调起微信提供的能力,例如:获取用户信息、微信登录、微信支付等,小…

“antd“: Unknown word.cSpell

你遇到的问题是 VS Code 的 Code Spell Checker 插件在检查拼写时,将 "antd" 标记为未知单词。"antd" 是 Ant Design 的缩写,是一个流行的 React UI 库,不是一个英语单词,所以 Spell Checker 会将其标记为错误…

Microsoft SQL Server 编写汉字转拼音函数

目录 应用场景 举例 函数实现 小结 应用场景 在搜索应用中,我们一般会提供一个搜索框,输入关健字,点击查询按钮以获取结果数据。大部分情况我们会提供模糊查询的形式以在一个或多个字段进行搜索以获取结果。这样可以简化用户的操作&…

游戏资讯网站系统aspnet+sqlserver

aspnet游戏资讯网站系统本网站采用三层架构编写 有增删查改全部功能 使用了objectDataSource 新技术:采用bootstrap前端框架 dntb控件 随着游戏行业的快速发展,越来越多的玩家需要一个了解全面游戏资讯信息的平台。 充分做了可行性分析后,我…

从零学习Linux操作系统 第三十五部分 Ansible中的角色

一、理解roles在企业中的定位及写法 #ansible 角色简介# Ansible roles 是为了层次化,结构化的组织Playbookroles就是通过分别将变量、文件、任务、模块及处理器放置于单独的目录中,并可以便捷地include它们roles一般用于基于主机构建服务的场景中&…

如何将视频内容转换为文字文稿?这三款工具助您实现视频转写!

在日常生活中,有时我们需要将视频中的内容转换为文字文稿以便于搜索、编辑或分享。但选择合适的视频转文字软件可能让人感到困惑。今天我将为您推荐三款优秀的视频转文字工具,它们操作简单、准确高效,能够帮助您快速完成视频内容转写的工作。…

MySQL三种日志

一、undo log(回滚日志) 1.作用: (1)保证了事物的原子性 (2)通过read view和undo log实现mvcc多版本并发控制 2.在事务提交前,记录更新前的数据到undo log里,回滚的时候读…

企业智能化转型的关键步骤与陷阱

目录 前言1 转型的关键步骤1.1 深度学习AI技术课程的重要性1.2 激发创意,开展多样化项目的战略意义1.3 招募机器学习专业人才的加速转型1.4 引入具备领导力的AI领导1.5 建立与AI领导的紧密沟通机制 2 智能化转型的陷阱2.1 谨慎期待AI解决所有问题的智慧2.2 综合考虑…

Docker进阶:深入理解 Dockerfile

Docker进阶:深入理解 Dockerfile 一、Dockerfile 概述二、为什么要学习Dockerfile三、Dockerfile 编写规则四、Dockerfile 中常用的指令1、FROM2、LABEL3、RUN4、CMD5、ENTRYPOINT6、COPY7、ADD8、WORKDIR9、 ENV10、EXPOSE11、VOLUME12、USER13、注释14、ONBUILD 命…

Vue+SpringBoot打造独居老人物资配送系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 查询社区4.2 新增物资4.3 查询物资4.4 查询物资配送4.5 新增物资配送 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的独居老人物资配送系统,包含了社区档案、…

几何变换 - 图像的缩放、翻转、仿射变换、透视等

1、前言 图像的几何变换是指改变图像的几何结构,大小、形状等等,让图像呈现出具备缩放、翻转、映射和透视的效果 图像的几何变换都比较复杂,计算也很复杂。 例如仿射变换,像素点的位置和灰度值都需要变换。 数字图像处理中利用后向传播的方法,将像素点变换后的位置通过…

统信OUS安装node, npm,vue (亲测有效)

统信OUS安装node, npm,vue (1)查看自己的系统 uname -a(2)进入nodejs官网下载相应版本 (3)找个位置解压,然后复制文件目录 保存好地址,等下要用到 (4)编辑环…

PostgreSQL索引篇 | GiST索引

PostgreSQL版本为8.4.1 (本文为《PostgreSQL数据库内核分析》一书的总结笔记,需要电子版的可私信我) 索引篇: PostgreSQL索引篇 | BTreePostgreSQL索引篇 | GIN索引PostgreSQL索引篇 | Hash索引PostgreSQL索引篇 | TSearch2 全文…

Java高级编程—注解

文章目录 1.注解的概述2.常见的Annotation示例2.1 生成文档相关的注解2.2 在编译时进行格式检查的注解2.3 跟踪代码依赖性,实现替代配置文件功能的注解 3.自定义Annotation4.JDK中的元注解4.1 Retention4.2 Target4.3 Documented & Inherited 5. JDK8中注解的新…

什么是VR虚拟现实创作工具|元宇宙文化旅游|VR设备在线购买

VR虚拟现实创作工具是用于创建、编辑和制作虚拟现实内容的软件或硬件工具。这些工具提供了创作者在虚拟现实环境中进行创作的功能和自由度,使他们能够构建令人惊叹的虚拟世界和交互体验。 以下是一些常见的VR虚拟现实创作工具: 虚拟现实建模工具&#x…

基于STC系列单片机实现PNP型三极管S8550驱动共阳数码管或NPN型三极管S8050驱动共阴数码管功能

Digitron.c #include "Digitron.h" //#include "Key.h" #define uchar unsigned char//自定义无符号字符型为uchar #define uint unsigned int//自定义无符号整数型为uint //uchar code DigitronBitCodeArray[] {0x01,0x02,0x04,0x08,0x10,0x20,0x40,0x8…