基于深度学习的图像去雨去雾

news2024/11/16 22:36:16

基于深度学习的图像去雨去雾


文末附有源码下载地址
b站视频地址: https://www.bilibili.com/video/BV1Jr421p7cT/

基于深度学习的图像去雨去雾,使用的网络为unet,
网络代码:

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
import torch.nn.functional as F
from torchstat import stat

class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.resnet = models.resnet18(pretrained=False)
        # self.resnet = create_feature_extractor(self.resnet, {'relu': 'feat320', 'layer1': 'feat160', 'layer2': 'feat80',
        #                                                'layer3': 'feat40'})

    def forward(self,x):
        for name,m in self.resnet._modules.items():

            x=m(x)
            if name=='relu':
                x1=x
            elif name=='layer1':
                x2=x
            elif name=='layer2':
                x3=x
            elif name=='layer3':
                x4=x
                break
        # x=self.resnet(x)
        return x1,x2,x3,x4
class Linears(nn.Module):
    def __init__(self,a,b):
        super(Linears, self).__init__()
        self.linear1=nn.Linear(a,b)
        self.relu1=nn.LeakyReLU()
        self.linear2 = nn.Linear(b, a)
        self.sigmoid=nn.Sigmoid()
    def forward(self,x):
        x=self.linear1(x)
        x=self.relu1(x)
        x=self.linear2(x)
        x=self.sigmoid(x)
        return x
class DenseNetBlock(nn.Module):
    def __init__(self,inplanes=1,planes=1,stride=1):
        super(DenseNetBlock,self).__init__()
        self.conv1=nn.Conv2d(inplanes,planes,3,stride,1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1=nn.LeakyReLU()

        self.conv2 = nn.Conv2d(inplanes, planes, 3,stride,1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.LeakyReLU()

        self.conv3 = nn.Conv2d(inplanes, planes, 3,stride,1)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu3 = nn.LeakyReLU()
    def forward(self,x):
        ins=x
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x=x+ins

        x2=self.conv3(x)
        x2 = self.bn3(x2)
        x2=self.relu3(x2)

        out=ins+x+x2
        return out
class SEnet(nn.Module):
    def __init__(self,chs,reduction=4):
        super(SEnet,self).__init__()
        self.average_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Sequential(
            # First reduce dimension, then raise dimension.
            # Add nonlinear processing to fit the correlation between channels
            nn.Linear(chs, chs // reduction),
            nn.LeakyReLU(inplace=True),
            nn.Linear(chs // reduction, chs)
        )
        self.activation = nn.Sigmoid()
    def forward(self,x):
        ins=x
        batch_size, chs, h, w = x.shape
        x=self.average_pooling(x)
        x = x.view(batch_size, chs)
        x=self.fc(x)
        x = x.view(batch_size,chs,1,1)
        return x*ins
class UAFM(nn.Module):
    def __init__(self):
        super(UAFM, self).__init__()
        # self.meanPool_C=torch.max()

        self.attention=nn.Sequential(
            nn.Conv2d(4, 8, 3, 1,1),
            nn.LeakyReLU(),
            nn.Conv2d(8, 1, 1, 1),
            nn.Sigmoid()
        )


    def forward(self,x1,x2):
        x1_mean_pool=torch.mean(x1,dim=1)
        x1_max_pool,_=torch.max(x1,dim=1)
        x2_mean_pool = torch.mean(x2, dim=1)
        x2_max_pool,_ = torch.max(x2, dim=1)

        x1_mean_pool=torch.unsqueeze(x1_mean_pool,dim=1)
        x1_max_pool=torch.unsqueeze(x1_max_pool,dim=1)
        x2_mean_pool=torch.unsqueeze(x2_mean_pool,dim=1)
        x2_max_pool=torch.unsqueeze(x2_max_pool,dim=1)

        cat=torch.cat((x1_mean_pool,x1_max_pool,x2_mean_pool,x2_max_pool),dim=1)
        a=self.attention(cat)
        out=x1*a+x2*(1-a)
        return out

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.resnet18=Resnet18()
        self.SENet=SEnet(chs=256)
        self.UAFM=UAFM()
        self.DenseNet1=DenseNetBlock(inplanes=256,planes=256)
        self.transConv1=nn.ConvTranspose2d(256,128,3,2,1,output_padding=1)

        self.DenseNet2 = DenseNetBlock(inplanes=128, planes=128)
        self.transConv2 = nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding=1)

        self.DenseNet3 = DenseNetBlock(inplanes=64, planes=64)
        self.transConv3 = nn.ConvTranspose2d(64, 64, 3, 2, 1, output_padding=1)

        self.transConv4 = nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding=1)
        self.DenseNet4=DenseNetBlock(inplanes=32,planes=32)
        self.out=nn.Sequential(
            nn.Conv2d(32,3,1,1),
            nn.Sigmoid()
        )

    def forward(self,x):
        """
        下采样部分
        """
        x1,x2,x3,x4=self.resnet18(x)
        # feat320=features['feat320']
        # feat160=features['feat160']
        # feat80=features['feat80']
        # feat40=features['feat40']
        feat320=x1
        feat160=x2
        feat80=x3
        feat40=x4
        """
        上采样部分
        """
        x=self.SENet(feat40)
        x=self.DenseNet1(x)
        x=self.transConv1(x)
        x=self.UAFM(x,feat80)

        x=self.DenseNet2(x)
        x=self.transConv2(x)
        x=self.UAFM(x,feat160)

        x = self.DenseNet3(x)
        x = self.transConv3(x)
        x = self.UAFM(x, feat320)

        x=self.transConv4(x)
        x=self.DenseNet4(x)
        out=self.out(x)

        # out=torch.concat((out,out,out),dim=1)*255.

        return out

    def freeze_backbone(self):
        for param in self.resnet18.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        for param in self.resnet18.parameters():
            param.requires_grad = True


if __name__ == '__main__':

    net=Net()
    print(net)
    # stat(net,(3,640,640))

    summary(net,input_size=(3,512,512),device='cpu')

    aa=torch.ones((6,3,512,512))
    out=net(aa)
    print(out.shape)
    # ii=torch.zeros((1,3,640,640))
    # outs=net(ii)
    # print(outs.shape)






主题界面显示及代码:
在这里插入图片描述

from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from untitled import Ui_Form
import sys
import cv2 as cv
from PyQt5.QtCore import QCoreApplication
import numpy as np
from PyQt5 import QtCore,QtGui
from PIL import Image
from predict import *

class My(QMainWindow,Ui_Form):
    def __init__(self):
        super(My,self).__init__()
        self.setupUi(self)
        self.setWindowTitle('图像去雨去雾')
        self.setIcon()
        self.pushButton.clicked.connect(self.pic)
        self.pushButton_2.clicked.connect(self.pre)
        self.pushButton_3.clicked.connect(self.pre2)
    def setIcon(self):
       palette1 = QPalette()
       # palette1.setColor(self.backgroundRole(), QColor(192,253,123))   # 设置背景颜色
       palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('back.png')))  # 设置背景图片
       self.setPalette(palette1)
    def pre(self):
        out=pre(self.img,0)
        out=self.cv_qt(out)
        self.label_2.setPixmap(QPixmap.fromImage(out).scaled(self.label.width(),self.label.height(),QtCore.Qt.KeepAspectRatio))
    def pre2(self):
        out=pre(self.img,1)
        out=self.cv_qt(out)
        self.label_2.setPixmap(QPixmap.fromImage(out).scaled(self.label.width(),self.label.height(),QtCore.Qt.KeepAspectRatio))

    def pic(self):
        imgName, imgType = QFileDialog.getOpenFileName(self,
                                                       "打开图片",
                                                       "",
                                                       " *.png;;*.jpg;;*.jpeg;;*.bmp;;All Files (*)")
        #KeepAspectRatio
        png = QtGui.QPixmap(imgName).scaled(self.label.width(),self.label.height(),QtCore.Qt.KeepAspectRatio)  # 适应设计label时的大小
        self.label.setPixmap(png)

        self.img=Image.open(imgName)
        self.img=np.array(self.img)
    def cv_qt(self, src):
        #src必须为bgr格式图像
        #src必须为bgr格式图像
        #src必须为bgr格式图像
        if len(src.shape)==2:
            src=np.expand_dims(src,axis=-1)
            src=np.tile(src,(1,1,3))
            h, w, d = src.shape
        else:h, w, d = src.shape



        bytesperline = d * w
        # self.src=cv.cvtColor(self.src,cv.COLOR_BGR2RGB)
        qt_image = QImage(src.data, w, h, bytesperline, QImage.Format_RGB888).rgbSwapped()
        return qt_image

if __name__ == '__main__':
    QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app=QApplication(sys.argv)
    my=My()
    my.show()
    sys.exit(app.exec_())

项目结构:
在这里插入图片描述
直接运行main.py即可弹出交互界面。
项目下载地址:下载地址-列表第19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1515569.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

人工智能迷惑行为大赏——需求与科技的较量

目录 前言 一、 机器行为学 二、人工智能迷惑行为的现象 三、产生迷惑行为的技术原因 四、社会影响分析 五、解决措施 总结 前言 随着ChatGPT热度的攀升,越来越多的公司也相继推出了自己的AI大模型,如文心一言、通义千问等。各大应用也开始内置…

Netty架构详解

文章目录 概述整体结构Netty的核心组件逻辑架构BootStrap & ServerBootStrapChannelPipelineFuture、回调和 ChannelHandler选择器、事件和 EventLoopChannelHandler的各种ChannelInitializer类图 Protocol Support 协议支持层Transport Service 传输服务层Core 核心层模块…

多维时序 | Matlab实现VMD-CNN-GRU变分模态分解结合卷积神经网络门控循环单元多变量时间序列预测

多维时序 | Matlab实现VMD-CNN-GRU变分模态分解结合卷积神经网络门控循环单元多变量时间序列预测 目录 多维时序 | Matlab实现VMD-CNN-GRU变分模态分解结合卷积神经网络门控循环单元多变量时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现VMD-CN…

软件测试面试都问了什么?中级软件测试岗面试(4面)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 一面(…

Excel判断CD两列在EF两列的列表中是否存在

需求 需要将CD两列的ID和NAME组合起来,查询EF两列的ID和NAME组合起来的列表中是否存在? 比如,判断第二行的“123456ABC”在EF的第二行到第四行中是否存在,若存在则显示Y,不存在则显示N 实现的计算公式 IF(ISNUMBER…

全视智慧机构养老解决方案,以科技守护长者安全

2024年2月28日凌晨1时许,在上海浦东大道的一家养护院四楼杂物间内发生了一起火灾事故。尽管火势不大,过火面积仅为2平方米,但这场小火却造成了1人死亡和3人受伤的悲剧。这一事件再次提醒我们,养老院作为老年人聚集的场所&#xff…

阿里云免费证书改为3个月,应对方法很简单

情商高点的说法是 Google 积极推进90天免费证书,各服务商积极响应。 情商低点的话,就是钱的问题。 现在基本各大服务商都在2024年停止签发1年期的免费SSL证书产品,有效期都缩短至3个月。 目前腾讯云倒还是一年期。 如果是一年期的话&#x…

关于微服务跨数据库联合查询的一些解决思路

微服务架构的一个非常明显的特征就是一个服务所拥有的数据只能通过这个服务的API来访问。通过这种方式来解耦,这样就会带来查询问题。以前通过join就可以满足要求,现在如果需要跨多个服务集成查询就会非常麻烦。 解决思路 下面提供几个思路仅供参考 表…

在centOS服务器安装docker,并使用docker配置nacos

遇到安装慢的情况可以优先选择阿里镜像 安装docker 更新yum版本 yum update安装所需软件包 yum install -y yum-utils device-mapper-persistent-data lvm2添加Docker仓库 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.rep…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的火焰与烟雾检测系统详解(深度学习模型+UI界面升级版+训练数据集)

摘要:本研究详细介绍了一种集成了最新YOLOv8算法的火焰与烟雾检测系统,并与YOLOv7、YOLOv6、YOLOv5等早期算法进行性能评估对比。该系统能够在包括图像、视频文件、实时视频流及批量文件中准确识别火焰与烟雾。文章深入探讨了YOLOv8算法的原理&#xff0…

二、TensorFlow结构分析(5)案例

案例: minimize(error) 代码: def linear_regression():# 自实现线性回归# 1)准备数据X tf.random.normal(shape[100,1])y_true tf.matmul(X,[[0.8]]) 0.7# 2)构造模型# 定义模型参数 用 变量weights tf.Variable(initial_v…

不想多花10万块,别买理想MEGA

文 | AUTO芯球 作者 | 雷歌 理想MEGA注定要凉凉! 这个口号喊震天响的MPV,直呼要做“50 万以上销量第一,不分能源形式、不分车身形态。” 50万以上?这个限定词真够高,但上面也不是没有狠角色。 比如腾势D9&#xf…

前端自动刷新Token与超时安全退出攻略

一、token的作用 因为http请求是无状态的,是一次性的,请求之间没有任何关系,服务端无法知道请求者的身份,所以需要鉴权,来验证当前用户是否有访问系统的权限。 以oauth2.0授权码模式为例: 每次请求资源服…

机器学习-0X-神经网络

总结 本系列是机器学习课程的系列课程,主要介绍机器学习中神经网络算法。 本门课程的目标 完成一个特定行业的算法应用全过程: 懂业务会选择合适的算法数据处理算法训练算法调优算法融合 算法评估持续调优工程化接口实现 参考 机器学习定义 关于机…

财富池指标公式--通达信短线快攻指标公式

今日分享的通达信短线快攻指标公式是一个分享短线买卖点的指标公式。 具体信号说明: 当指标中出现蓝色的哭脸的图标时,可开始关注该个股,当出现红色向上的箭头时,后市上涨的概率较大,是参考买入的信号。 当只指标中出…

python考点2

只考列表字典 注意1,5,7.10

Manning技术出版公司

Manning 是一家美国的技术出版公司,专门出版与计算机科学、信息技术和编程相关的图书和教育资料。该公司成立于 1990 年代初期,是技术图书领域的知名品牌之一。 Manning 公司的中文翻译名字可以是 “曼宁”。 最近发现好多国外的翻译技术图书是这家出版…

Unix环境高级编程-学习-05-TCP/IP协议与套接字

目录 一、概念 二、TCP/IP参考模型 三、客户端和服务端使用TCP通信过程 1、同一以太网下 四、函数介绍 1、socket (1)声明 (2)作用 (3)参数 (4)返回值 (5&…

先初始化读取数据,然后才填充(低级错误,引以为戒)

本来是先初始化,然后读取数据。 结果上下两句写反了,一直报错。断点打了两个小时,才发现

2024年信息技术与计算机工程国际学术会议(ICITCEI 2024)

2024年信息技术与计算机工程国际学术会议(ICITCEI 2024) 2024 International Conference on Information Technology and Computer Engineering ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 大会主题: 信息系统和技术…