【NeurIPS 2023】PromptIR: Prompting for All-in-One Blind Image Restoration

news2025/2/25 10:13:45

PromptIR: Prompting for All-in-One Blind Image Restoration, NeurIPS 2023

论文:https://arxiv.org/abs/2306.13090

代码:https://github.com/va1shn9v/promptir

解读:即插即用系列 | PromptIR:MBZUAI提出一种基于Prompt的全能图像恢复网络 - 知乎 (zhihu.com)

摘要

图像恢复是从其受损版本中恢复高质量清晰图像的过程。deep-learning方法显著提升了图像恢复性能,然而,它们在不同类型和级别的退化上的泛化能力有限。这限制了它们在实际应用中的使用,因为需要针对每种具体的退化进行单独训练模型,并了解输入图像的退化类型才能应用相应的模型。本文介绍了一种基于提示的学习方法,称为PromptIR,用于全能图像恢复,可以有效地从各种类型和级别的退化中恢复图像。具体而言,本文方法使用提示来编码退化特定信息,并动态引导恢复网络。 PromptIR提供了一个通用且高效的插件模块,只需少量轻量级提示即可用于恢复各种类型和级别的受损图像,无需事先了解图像中存在的损坏信息。

动机

图像恢复过程中,会由于各种客观因素或限制(相机设备、环境条件)出现各种退化现象。deep-learning方法虽然有用,但在特定的退化类型和程度之外缺乏泛化性。因此,迫切需要开发一种能够有效恢复各种类型和程度的退化图像的一体化方法。

AirNet,采用对比学习解决一体化恢复任务,但需要训练额外的编码器,会增加训练负担。

为此,本文提出了一种基于提示学习的方法来执行一体化图像恢复。该方法利用提示(一组可调参数),用于编码关于各种图像退化类型的重要区分信息。通过将提示与主恢复网络的特征表示相互作用,动态地增强表示,以获得具有退化特定知识的适应性,这种适应性使网络能够通过动态调整其行为有效地恢复图像。

该图显示了在PromptIR和AirNet中使用的退化嵌入的tSNE图。不同的颜色表示不同的退化类型。每个任务的嵌入更好地聚集在一起,显示了提示标记学习具有区分性的退化上下文的有效性,从而有助于恢复过程。

 

所提的Prompt 方法

本文PromptIR,提出了一个即插即用的提示模块,它隐式地预测与退化条件相关的提示,以引导具有未知退化的输入图像的恢复过程。来自提示的指导被注入到网络的多个解码阶段,其中包含少量可学习参数。

 

贡献

  • 提出了一种基于提示的一体化恢复框架PromptIR,它仅依赖于输入图像来恢复清晰图像,而不需要任何关于图像中存在的退化的先验知识。
  • 本文prompt block 是一个可轻松集成到任何现有恢复网络中的插件模块。它由提示生成模块(PGM)和提示交互模块(PIM)组成。提示块的目标是生成与输入条件相关的提示(通过PGM),这些提示具有有用的上下文信息,以指导恢复网络(通过PIM)有效地消除输入图像中的破坏。
  • 本文实验证明了PromptIR的动态适应行为,在包括图像降噪、去雨和去雾在内的各种图像恢复任务上实现了最先进的性能。

方法 PromptIR

“一体式”图像恢复的目标是,学习单个模型M,以从退化的图像中恢复干净图像,该图像已使用退化方式D退化,而没有关于D的先验信息。虽然该模型对退化方式“不可见”,可以通过提供关于退化类型的隐含上下文信息来增强其在恢复干净图像方面的性能。基于提示学习的图像恢复框架PromptIR,用于在恢复干净图像的同时用退化类型的相关知识补充模型。关键元素是 提示块prompt block。

PromptIR使用提示块来生成可学习的提示参数,并在恢复过程中利用这些提示来指导模型。框架通过逐级编码器-解码器将特征逐步转换为深层特征,并在解码器中引入提示块来辅助恢复过程。提示块在解码器的每个级别中连接,隐式地为输入特征提供关于退化类型的信息,以实现引导恢复。总体来说,PromptIR框架通过逐级编码和解码以及引入提示块的方式实现图像恢复任务。 

Prompt Block

提示块 prompt block,由两个模块组成:提示生成模块(PGM)和提示交互模块(PIM)。

  • 提示生成模块使用输入特征Fl和提示组件生成与输入条件相关的提示P。
  • 提示交互模块通过Transformer块使用生成的提示动态调整输入特征。提示与解码器特征在多个级别交互,以丰富特定于退化的上下文信息。 

给N个prompt-components P_c \in R^{N*\hat{H}*\hat{W}*\hat{C}}和 input feature F_1\in R^{\hat{H}*\hat{W}*\hat{C}}, prompt block 可表示为:

 

 其中,PGM表示提示生成模块,PIM表示提示交互模块。

Prompt Generation Module (PGM)

提示组件 P_c是一组可学习的参数,与输入特征交互,嵌入了退化信息。一种直接的特征-提示交互方法是直接使用学习到的提示来校准特征,可能会产生次优结果,因为它对输入内容是无知的。本文提出了提示生成模块(PGM),它从输入特征中动态预测基于注意力的权重,并将这些权重应用于提示组件,生成与输入条件相关的提示P。此外,PGM创建了一个共享空间,促进了提示组件之间的相关知识共享。PGM公式表达为:

Prompt Interaction Module (PIM)

PIM的主要目标是实现输入特征F_1和提示P之间的交互,以实现有指导的恢复过程。

在PIM中,沿着通道维度将生成的提示与输入特征进行拼接。接下来将拼接后的表示通过一个Transformer块进行处理,该块利用提示中编码的退化信息来转换输入特征。

​​本文的主要贡献是提示块,它是一个插件模块,与具体的架构无关。PromptIR框架中,使用了现有的Transformer块。Transformer块由两个顺序连接的子模块组成:多转置卷积头转置注意力(MDTA)和门控转置卷积前馈网络(GDFN)。MDTA在通道而不是空间维度上应用自注意操作,并具有线性复杂度。GDFN的目标是以可控的方式转换特征,即抑制信息较少的特征,只允许有用的特征在网络中传播。PIM的整体过程为:

实验

主要实验与可视化样例 

全能恢复设置下的比较:

去雾、去雨、去噪实验比较: 

 

去雾、去雨、去噪可视化比较:

消融实验 

关键代码

PGM

# https://github.com/va1shn9v/PromptIR/blob/main/net/model.py

##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlock,self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        self.linear_layer = nn.Linear(lin_dim,prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)
        

    def forward(self,x):
        B,C,H,W = x.shape
        emb = x.mean(dim=(-2,-1))
        prompt_weights = F.softmax(self.linear_layer(emb),dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)
        prompt = torch.sum(prompt,dim=1)
        prompt = F.interpolate(prompt,(H,W),mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt

PromptIR

# https://github.com/va1shn9v/PromptIR/blob/main/net/model.py

class PromptIR(nn.Module):
    def __init__(self, 
        inp_channels=3, 
        out_channels=3, 
        dim = 48,
        num_blocks = [4,6,6,8], 
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        decoder = False,
    ):

        super(PromptIR, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        
        
        self.decoder = decoder
        
        if self.decoder:
            self.prompt1 = PromptGenBlock(prompt_dim=64,prompt_len=5,prompt_size = 64,lin_dim = 96)
            self.prompt2 = PromptGenBlock(prompt_dim=128,prompt_len=5,prompt_size = 32,lin_dim = 192)
            self.prompt3 = PromptGenBlock(prompt_dim=320,prompt_len=5,prompt_size = 16,lin_dim = 384)
        
        
        self.chnl_reduce1 = nn.Conv2d(64,64,kernel_size=1,bias=bias)
        self.chnl_reduce2 = nn.Conv2d(128,128,kernel_size=1,bias=bias)
        self.chnl_reduce3 = nn.Conv2d(320,256,kernel_size=1,bias=bias)



        self.reduce_noise_channel_1 = nn.Conv2d(dim + 64,dim,kernel_size=1,bias=bias)
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2

        self.reduce_noise_channel_2 = nn.Conv2d(int(dim*2**1) + 128,int(dim*2**1),kernel_size=1,bias=bias)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3

        self.reduce_noise_channel_3 = nn.Conv2d(int(dim*2**2) + 256,int(dim*2**2),kernel_size=1,bias=bias)
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        
        self.up4_3 = Upsample(int(dim*2**2)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**1)+192, int(dim*2**2), kernel_size=1, bias=bias)
        self.noise_level3 = TransformerBlock(dim=int(dim*2**2) + 512, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level3 = nn.Conv2d(int(dim*2**2)+512,int(dim*2**2),kernel_size=1,bias=bias)


        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])


        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.noise_level2 = TransformerBlock(dim=int(dim*2**1) + 224, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level2 = nn.Conv2d(int(dim*2**1)+224,int(dim*2**2),kernel_size=1,bias=bias)


        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.noise_level1 = TransformerBlock(dim=int(dim*2**1)+64, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level1 = nn.Conv2d(int(dim*2**1)+64,int(dim*2**1),kernel_size=1,bias=bias)


        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
                    
        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img,noise_emb = None):

        inp_enc_level1 = self.patch_embed(inp_img)

        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        
        inp_enc_level2 = self.down1_2(out_enc_level1)

        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)

        out_enc_level3 = self.encoder_level3(inp_enc_level3) 

        inp_enc_level4 = self.down3_4(out_enc_level3)        
        latent = self.latent(inp_enc_level4)
        if self.decoder:
            dec3_param = self.prompt3(latent)

            latent = torch.cat([latent, dec3_param], 1)
            latent = self.noise_level3(latent)
            latent = self.reduce_noise_level3(latent)
                        
        inp_dec_level3 = self.up4_3(latent)

        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)

        out_dec_level3 = self.decoder_level3(inp_dec_level3) 
        if self.decoder:
            dec2_param = self.prompt2(out_dec_level3)
            out_dec_level3 = torch.cat([out_dec_level3, dec2_param], 1)
            out_dec_level3 = self.noise_level2(out_dec_level3)
            out_dec_level3 = self.reduce_noise_level2(out_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)

        out_dec_level2 = self.decoder_level2(inp_dec_level2)
        if self.decoder:
           
            dec1_param = self.prompt1(out_dec_level2)
            out_dec_level2 = torch.cat([out_dec_level2, dec1_param], 1)
            out_dec_level2 = self.noise_level1(out_dec_level2)
            out_dec_level2 = self.reduce_noise_level1(out_dec_level2)
        
        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)


        out_dec_level1 = self.output(out_dec_level1) + inp_img


        return out_dec_level1

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278903.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

node.js express路由和中间件

目录 路由 解释 使用方式 中间件 解释 使用方式 中间件类型 路由注册和中间件注册 代码 app全局路由接口请求以及代码解析 示例1 示例2 示例3 示例4 中间件req继承 嵌套子路由 解释 代码 示例1 路由 解释 在 Express 中,路由(Route&…

Spring | Spring的基本应用

目录: 1.什么是Spring?2.Spring框架的优点3.Spring的体系结构 (重点★★★) :3.1 Core Container (核心容器) ★★★Beans模块 (★★★) : BeanFactoryCore核心模块 (★★★) : IOCContext上下文模块 (★★★) : ApplicationContextContext-support模块 (★★★)SpE…

Linux中文件的打包压缩、解压,下载到本地——zip,tar指令等

目录 1 .zip后缀名: 1.1 zip指令 1.2 unzip指令 2 .tar后缀名 3. sz 指令 4. rz 指令 5. scp指令 1 .zip后缀名: 1.1 zip指令 语法:zip [namefile.zip] [namefile]... 功能:将目录或者文件压缩成zip格式 常用选项&#xff1a…

《YOLOv8原创自研》专栏介绍 CSDN独家改进创新实战专栏目录

YOLOv8原创自研 https://blog.csdn.net/m0_63774211/category_12511737.html?spm1001.2014.3001.5482 💡💡💡全网独家首发创新(原创),适合paper !!! 💡&a…

QT 中 QProgressDialog 进度条窗口 备查

基础API //两个构造函数 QProgressDialog::QProgressDialog(QWidget *parent nullptr, Qt::WindowFlags f Qt::WindowFlags());QProgressDialog::QProgressDialog(const QString &labelText, const QString &cancelButtonText, int minimum, int maximum, QWidget *…

java源码-类与对象

1、类与对象的初步认知 在了解类和对象之前我们先了解一下什么是面向过程和面向对象。 1)面向过程编程: C语言就是面向过程编程的,关注的是过程,分析出求解问题的步骤,通过函数调用逐步解决问题。 2)面向对…

【数据结构】最短路径(Dijskra算法)

一.引例 计算机网络传输的问题: 怎样找到一种最经济的方式,从一台计算机向网上所有其他计算机发送一条消息。 抽象为: 给定带权有向图G(V,E)和源点v,求从v到G中其余各顶点的最短路径。 即&…

23.Python 图形化界面编程

目录 1.认识GUI和使用tkinter2.使用组件2.1 标签2.2 按钮2.3 文本框2.4 单选按钮和复选按钮2.5 菜单和消息2.6 列表框2.7 滚动条2.8 框架2.9 画布 3. 组件布局4.事件处理 1.认识GUI和使用tkinter 人机交互是从人努力适应计算机,到计算机不断适应人的发展过程&#…

Redis中的数据结构

文章目录 第1关:Redis中的数据结构 第1关:Redis中的数据结构 这是上篇文章的第一关,只不过本篇是代码按行做的,方便一下大家使用。 代码如下: redis-cliset hello redislpush educoder-list hellorpush educoder-lis…

cmake和vscode 下的cmake的使用详解(三)

第七讲:【实战】使用 VSCode 进行完整项目开发 案例:士兵突击 需求: 1. 士兵 许三多 有一把 枪 ,叫做 AK47 2. 士兵 可以 开火 3. 士兵 可以 给枪装填子弹 4. 枪 能够 发射 子弹 5. 枪 能够 装填子弹 ——…

初识RocketMQ

1、简介 RocketMQ 是阿里巴巴在 2012 年开源的消息队列产品,用 Java 语言实现,在设计时参考了 Kafka,并做出了自己的一些改进,后来捐赠给 Apache 软件基金会,2017 正式毕业,成为 Apache 的顶级项目。Rocket…

canvas基础:绘制贝塞尔曲线

canvas实例应用100 专栏提供canvas的基础知识,高级动画,相关应用扩展等信息。 canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重要的帮助。 文章目录 bez…

pygame实现贪吃蛇小游戏

import pygame import random# 游戏初始化 pygame.init()# 游戏窗口设置 win_width, win_height 800, 600 window pygame.display.set_mode((win_width, win_height)) pygame.display.set_caption("Snake Game")# 颜色设置 WHITE (255, 255, 255) BLACK (0, 0, 0…

docker容器内部文件挂载主机

docker images执行该命令可以发现一个centos镜像 docker run --namemycentos -itd --privilegedtrue --restartalways -p 88:80 -v C:\Users\Administrator\Desktop\dockerTest:/bin/gh:ro centosdocker run 命令用于在 Docker 上创建和运行容器。 --namemycentos 指定容器…

SHAP(四):NHANES I 生存模型

SHAP(四):NHANES I 生存模型 这是一个 Cox 比例风险模型,基于来自 NHANES I 的数据以及来自 NHANES I 流行病学随访研究。 它旨在说明 SHAP 值如何能够以传统上仅由线性模型提供的清晰度解释 XGBoost 模型。 我们在数据中看到有趣…

JDK17的安装与配置

JDK17的安装与配置 下载地址安装步骤配置环境变量验证安装是否成功 下载地址 此jdk17安装的系统是win10系统 https://www.oracle.com/java/technologies/downloads/ 这里选择JDK17进行下载 下载完成之后,显示如下图: 安装步骤 自定义的安装路径&…

【从删库到跑路 | MySQL总结篇】事务详细介绍

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【MySQL学习专栏】🎈 本专栏旨在分享学习MySQL的一点学习心得,欢迎大家在评论区讨论💌 目录 一、事务…

多项式拟合求解

目录 简介 基本原理 例1 例2 例3 参考资料 简介 多项式拟合可以用最小二乘求解,不管是一元高阶函数,还是多元多项式函数,还是二者的混合,都可以通过统一的方法求解。当然除了最小二乘法,还是其他方法可以求解&…

分享一个国内可用的免费GPT4-AI提问AI绘画网站工具

一、前言 ChatGPT GPT4.0,Midjourney绘画,相信对大家应该不感到陌生吧?简单来说,GPT-4技术比之前的GPT-3.5相对来说更加智能,会根据用户的要求生成多种内容甚至也可以和用户进行创作交流。 然而,GPT-4对普…

【尾递归】

尾递归 如果函数在返回前才进行递归调用,则该函数可以被编译器或解释器优化,使其在空间效率上与迭代相当。这种情况被称为「尾递归 tail recursion」。 普通递归:当函数返回到上一层级的函数后,需要继续执行代码,因此…