基于深度学习神经网络的AI图像PSD去雾系统源码

news2024/11/26 16:53:18

第一步:PSD介绍

        以往的研究主要集中在具有合成模糊图像的训练模型上,当模型用于真实世界的模糊图像时,会导致性能下降。

        为了解决上述问题,提高去雾的泛化性能,作者提出了一种Principled Synthetic-to-real Dehazing (PSD)框架。

        本文提出的PSD适用于将现有的去雾模型推广到实际领域,包括两个阶段:有监督的预训练无监督的微调

        预训练阶段,作者将选定的去雾模型主干修改为一个基于物理模型的网络,并用合成数据训练该网络。利用设计良好的主干,我们可以得到一个预先训练的模型,在合成域上具有良好的去雾性能。

        微调阶段,作者利用真实的模糊图像以无监督的方式训练模型。

本文的贡献:

  1. 作者将真实世界的去雾任务重新定义为一个合成到真实的泛化框架:首先一个在合成配对数据上预先训练的去雾模型主干,真实的模糊图像随后将被利用以一种无监督的方式微调模型。PSD易于使用,可以以大多数深度去雾模型为骨干。
  2. 由于没有清晰的真实图像作为监督,作者利用几个流行的、有充分根据的物理先验来指导微调。作者将它们合并成一个预先的损失committee,作为具体任务的代理指导,这一部分也是PSD的核心。
  3. 性能达到SOTA

第二步:PSD网络结构

        首先对两个框架大的方向做一个整体概述:

        Pre-training

        首先采用目前性能最好的框架之一作为网络的主干

        然后我们将主干修改为一个基于物理的网络,根据一个单一的雾图同时生成干净的图像 J,传输图 t 和大气光 A,为了共同优化这三个分量,作者加入了一个重建损失,它引导网络输出服从物理散射模型。

        在这个阶段,只使用标记的合成数据进行训练,最终得到一个在合成域上预训练的模型。

        Fine-tuning

        作者利用未标记的真实数据将预训练模型从合成域推广到真实域。受去雾强物理背景的启发,作者认为一个高质量的无雾图像应该遵循一些特定的统计规则,这些规则可以从图像先验中推导出来。此外,单一先验提供的物理知识并不总是可靠的,所以作者的目标是找到多个先验的组合,希望它们能够相互补充。

        基于上述,作者设计了一个先验损失committee来作为任务特定的代理指导,用于训练未标记的真实数据

        此外,作者应用了一种learning without forgetting (LwF)的方法,该方法通过将原始任务的训练数据(即合成的模糊图像)通过网络运转到同真实的模糊数据一起,从而强行使得模型记忆合成领域的知识。

第三步:模型代码展示

import torch
import torch.nn as nn
import torch.nn.functional as F

class BlockUNet1(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False, relu=False, drop=False, bn=True):
        super(BlockUNet1, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        self.dropout = nn.Dropout2d(0.5)
        self.batch = nn.InstanceNorm2d(out_channels)

        self.upsample = upsample
        self.relu = relu
        self.drop = drop
        self.bn = bn

    def forward(self, x):
        if self.relu == True:
            y = F.relu(x)
        elif self.relu == False:
            y = F.leaky_relu(x, 0.2)
        if self.upsample == True:
            y = self.deconv(y)
            if self.bn == True:
                y = self.batch(y)
            if self.drop == True:
                y = self.dropout(y)

        elif self.upsample == False:
            y = self.conv(y)
            if self.bn == True:
                y = self.batch(y)
            if self.drop == True:
                y = self.dropout(y)

        return y

class G2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(G2, self).__init__()

        self.conv = nn.Conv2d(in_channels, 8, 4, 2, 1, bias=False)
        self.layer1 = BlockUNet1(8, 16)
        self.layer2 = BlockUNet1(16, 32)
        self.layer3 = BlockUNet1(32, 64)
        self.layer4 = BlockUNet1(64, 64)
        self.layer5 = BlockUNet1(64, 64)
        self.layer6 = BlockUNet1(64, 64)
        self.layer7 = BlockUNet1(64, 64)
        self.dlayer7 = BlockUNet1(64, 64, True, True, True, False)
        self.dlayer6 = BlockUNet1(128, 64, True, True, True)
        self.dlayer5 = BlockUNet1(128, 64, True, True, True)
        self.dlayer4 = BlockUNet1(128, 64, True, True)
        self.dlayer3 = BlockUNet1(128, 32, True, True)
        self.dlayer2 = BlockUNet1(64, 16, True, True)
        self.dlayer1 = BlockUNet1(32, 8, True, True)
        self.relu = nn.ReLU()
        self.dconv = nn.ConvTranspose2d(16, out_channels, 4, 2, 1, bias=False)
        self.lrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        y1 = self.conv(x)
        y2 = self.layer1(y1)
        y3 = self.layer2(y2)
        y4 = self.layer3(y3)
        y5 = self.layer4(y4)
        y6 = self.layer5(y5)
        y7 = self.layer6(y6)
        y8 = self.layer7(y7)

        dy8 = self.dlayer7(y8)
        concat7 = torch.cat([dy8, y7], 1)
        dy7 = self.dlayer6(concat7)
        concat6 = torch.cat([dy7, y6], 1)
        dy6 = self.dlayer5(concat6)
        concat5 = torch.cat([dy6, y5], 1)
        dy5 = self.dlayer4(concat5)
        concat4 = torch.cat([dy5, y4], 1)
        dy4 = self.dlayer3(concat4)
        concat3 = torch.cat([dy4, y3], 1)
        dy3 = self.dlayer2(concat3)
        concat2 = torch.cat([dy3, y2], 1)
        dy2 = self.dlayer1(concat2)
        concat1 = torch.cat([dy2, y1], 1)
        out = self.relu(concat1)
        out = self.dconv(out)
        out = self.lrelu(out)

        return F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias)
    
class PALayer(nn.Module):
    def __init__(self, channel):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
                nn.Sigmoid()
        )
    def forward(self, x):
        y = self.pa(x)
        return x * y

class CALayer(nn.Module):
    def __init__(self, channel):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y

class Block(nn.Module):
    def __init__(self, conv, dim, kernel_size,):
        super(Block, self).__init__()
        self.conv1=conv(dim, dim, kernel_size, bias=True)
        self.act1=nn.ReLU(inplace=True)
        self.conv2=conv(dim,dim,kernel_size,bias=True)
        self.calayer=CALayer(dim)
        self.palayer=PALayer(dim)
    def forward(self, x):
        res=self.act1(self.conv1(x))
        res=res+x 
        res=self.conv2(res)
        res=self.calayer(res)
        res=self.palayer(res)
        res += x 
        return res
class Group(nn.Module):
    def __init__(self, conv, dim, kernel_size, blocks):
        super(Group, self).__init__()
        modules = [ Block(conv, dim, kernel_size)  for _ in range(blocks)]
        modules.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*modules)
    def forward(self, x):
        res = self.gp(x)
        res += x
        return res

class FFANet(nn.Module):
    def __init__(self,gps,blocks,conv=default_conv):
        super(FFANet, self).__init__()
        self.gps=gps
        self.dim=64
        kernel_size=3
        pre_process = [conv(3, self.dim, kernel_size)]
        assert self.gps==3
        self.g1= Group(conv, self.dim, kernel_size,blocks=blocks)
        self.g2= Group(conv, self.dim, kernel_size,blocks=blocks)
        self.g3= Group(conv, self.dim, kernel_size,blocks=blocks)
        self.ca=nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim*self.gps,self.dim//16,1,padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim//16, self.dim*self.gps, 1, padding=0, bias=True),
            nn.Sigmoid()
            ])
        self.palayer=PALayer(self.dim)

        self.conv_J_1 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
        self.conv_J_2 = nn.Conv2d(64, 3, 3, 1, 1, bias=False)
        self.conv_T_1 = nn.Conv2d(64, 16, 3, 1, 1, bias=False)
        self.conv_T_2 = nn.Conv2d(16, 1, 3, 1, 1, bias=False)
        
        post_precess = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)]

        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_precess)
        self.ANet = G2(3, 3)

    def forward(self, x1, x2=0, Val=False):
        x = self.pre(x1)
        res1=self.g1(x)
        res2=self.g2(res1)
        res3=self.g3(res2)
        w=self.ca(torch.cat([res1,res2,res3],dim=1))
        w=w.view(-1,self.gps,self.dim)[:,:,:,None,None]
        out=w[:,0,::]*res1+w[:,1,::]*res2+w[:,2,::]*res3
        out=self.palayer(out)
        
        out_J = self.conv_J_1(out)
        out_J = self.conv_J_2(out_J)
        out_J = out_J + x1
        out_T = self.conv_T_1(out)
        out_T = self.conv_T_2(out_T)
        
        if Val == False:
            out_A = self.ANet(x1)
        else:
            out_A = self.ANet(x2)
            
        out_I = out_T * out_J + (1 - out_T) * out_A
        #x=self.post(out)
        return out, out_J, out_T, out_A, out_I
    
if __name__ == "__main__":
    net=FFA(gps=3,blocks=19)
    print(net)

第四步:运行

第五步:整个工程的内容

代码的下载路径(新窗口打开链接)基于深度学习神经网络的AI图像PSD去雾系统源码

​​

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1669152.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件体系结构风格

目录 一、定义 二、.经典软件体系结构风格: 1.管道和过滤器 2.数据抽象和面向对象系统 3.基于事件系统(隐式调用) 4.分层系统 5.仓库 6.C2风格 7.C/S 8.三层C/S 9.B/S 题: 一、定义 软件体系机构风格是描述某一特定应用…

【CSP CCF记录】202104-1 灰度直方图

题目 过程 #include<bits/stdc.h> using namespace std; int n,m,L; int A[502][502]; int main() {cin>>n>>m>>L;int h[L]{0};for(int i0;i<n;i){for(int j0;j<m;j){cin>>A[i][j];h[A[i][j]];}}for(int i0;i<L;i)cout<<h[i]<…

第四届上海理工大学程序设计全国挑战赛 J.上学 题解 DFS 容斥

上学 题目描述 usst 小学里有 n 名学生&#xff0c;他们分别居住在 n 个地点&#xff0c;第 i 名学生居住在第 i 个地点&#xff0c;这些地点由 n−1 条双向道路连接&#xff0c;保证任意两个地点之间可以通过若干条双向道路抵达。学校则位于另外的第 0 个地点&#xff0c;第…

SeetaFace6人脸活体检测C++代码实现Demo

SeetaFace6包含人脸识别的基本能力&#xff1a;人脸检测、关键点定位、人脸识别&#xff0c;同时增加了活体检测、质量评估、年龄性别估计&#xff0c;并且顺应实际应用需求&#xff0c;开放口罩检测以及口罩佩戴场景下的人脸识别模型。 官网地址&#xff1a;https://github.co…

Rerank进一步提升RAG效果

RAG & Rerank 目前大模型应用中&#xff0c;RAG&#xff08;Retrieval Augmented Generation&#xff0c;检索增强生成&#xff09;是一种在对话&#xff08;QA&#xff09;场景下最主要的应用形式&#xff0c;它主要解决大模型的知识存储和更新问题。 简述RAG without R…

【前端系列】什么是yarn

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Python 全栈系列245 nginx 前端web页面透传

说明 过去的几年&#xff0c;我已经构造了很多组件&#xff0c;从图的角度来看&#xff0c;完成了很多点。这些点的单点测试看起来都不错&#xff0c;但是因为没有连起来&#xff0c;所以无法体现系统价值。好比发动机的马力虽然大&#xff0c;但是没有传动轴&#xff0c;那就…

重学JavaScript核心知识点(二)—— 详解Js中的模块化

详解Js中的模块化 1. 模块化的背景2. 来看一个例子3. 优雅的做法 —— 创建模块对象4. 模块与类&#xff08;class&#xff09;5. 合并模块6. 动态加载模块 1. 模块化的背景 JavaScript 在诞生之初是体积很小的&#xff0c;早期&#xff0c;它们大多被用来执行独立的脚本任务&…

Shell编程之循环语甸与函数

for 遍历循环 1&#xff09;for 变量 in 取值列表 for i in $(seq 1 10) do 命令序列 .... done 2&#xff09;for ((变量初始值; 变量范围; 变量的迭代方式)) for ((i1; i<10; i)) do 命令序列 .... done IFS for循环取值列表分隔符 set | grep IFS …

---随笔--Java实现TCP通信(双端通信接收与发送)

---随笔--Java实现TCP通信&#xff08;双端通信接收与发送&#xff09; 引言1. 什么是TCP通信2. 服务器与客户端核心代码2.1 服务器ServerSocket端核心代码2.2 用户Socket端核心代码2.3 小贴士之关于try-with-resources自动关闭资源的使用 3. 具体服务器端实现4. 具体客户端实现…

WordPress 、Typecho 站点的 MySQL/MariaDB 数据库优化

今天明月给大家分享一下 WordPress 、Typecho 站点的 MySQL/MariaDB 数据库优化&#xff0c;无论你的站点采用是 WordPress 还是 Typecho&#xff0c;都要用到 MySQL/MariaDB 数据库&#xff0c;我们以 MySQL 为主&#xff08;MariaDB 其实跟 MySQL 基本没啥大的区别&#xff0…

STC8增强型单片机开发【LED呼吸灯(PWM)⭐⭐】

目录 一、引言 二、硬件准备 三、PWM技术概述 四、电路设计 五、代码编写 EAXSFR&#xff1a; 六、编译与下载 七、测试与调试 八、总结 一、引言 在嵌入式系统开发中&#xff0c;LED呼吸灯是一种常见的示例项目&#xff0c;它不仅能够展示PWM&#xff08;脉冲宽度调制…

解决 Content type ‘application/json;charset=UTF-8‘ not supported

文章目录 问题描述原因分析解决方案参考资料 问题描述 我项目前端采用vue-elementUi-admin框架进行开发&#xff0c;后端使用SpringBoot&#xff0c;但在前后端登录接口交互时&#xff0c;前端报了如下错误 完整报错信息如下 前端登录接口JS代码如下 export function login(…

Busybox 在 Docker 中的部署和启动

可以使用 docker pull 指令下载 busybox:latest 镜像&#xff1a; PS C:\Users\yhu> docker pull busybox:latest latest: Pulling from library/busybox ec562eabd705: Pull complete Digest: sha256:5eef5ed34e1e1ff0a4ae850395cbf665c4de6b4b83a32a0bc7bcb998e24e7bbb St…

网安面经之文件包含漏洞

一、文件包含漏洞 1、文件包含漏洞原理&#xff1f;危害&#xff1f;修复&#xff1f; 原理&#xff1a;开发⼈员⼀般希望代码更灵活&#xff0c;所以将被包含的⽂件设置为变量&#xff0c;⽤来进⾏动态调⽤&#xff0c;但是由于⽂件包含函数加载的参数没有经过过滤或者严格的…

Selenium 自动化 —— 一篇文章彻底搞懂XPath

更多关于Selenium的知识请访问“兰亭序咖啡”的专栏&#xff1a;专栏《Selenium 从入门到精通》 文章目录 前言 一、什么是xpath&#xff1f; 二、XPath 节点 三. 节点的关系 1. 父&#xff08;Parent&#xff09; 2. 子&#xff08;Children&#xff09; 3. 同胞&#xff08;S…

LLM量化

Efficient Finetuning prefix tuning 针对每种任务&#xff0c;学习prefix vector 启发于prompting&#xff0c;调整上下文内容让模型去输出自己想要的内容 核心就是找到一个上下文去引导模型解决NLP生成任务 传统情况下&#xff0c;我们为了得到想要的结果&#xff0c;会…

1.9. 离散时间鞅-组合方法分析随机游动(不用鞅方法)

组合方法分析随机游动-不用鞅方法 组合方法分析随机游动(不用鞅方法)1. 反射原理 → \rightarrow →首击时分布2. 反正弦定律(末击时分布)组合方法分析随机游动(不用鞅方法) 本节将不使用鞅更深入地研究简单随机游动的属性,有益于过渡到马氏链的研究. 1. 反射原理

20240512,函数对象,常用算法:遍历,查找

函数对象 函数对象基本使用 重载 函数调用操作符 的类&#xff0c;其对象被称为函数对象&#xff1b;函数对象使用重载的&#xff08;&#xff09;时&#xff0c;行为类似函数调用&#xff0c;也叫仿函数 函数对象&#xff08;仿函数&#xff09;本质是一个类&#xff0c;不是…

【GlobalMapper精品教程】080:WGS84转UTM投影

参考阅读:ArcGIS实验教程——实验十:矢量数据投影变换 文章目录 一、加载实验数据二、设置输出坐标系三、数据导出一、加载实验数据 打开配套案例数据包中的data080.rar中的矢量数据,如下所示: 查看源坐标系:双击图层的,图层投影选项卡,数据的已有坐标系为WGS84地理坐标…