AIRNet模型使用与代码分析(All-In-One Image Restoration Network)

news2024/12/23 3:24:52

AIRNet提出了一种较为简易的pipeline,以单一网络结构应对多种任务需求(不同类型,不同程度)。但在效果上看,ALL-In-One是不如One-By-One的,且本文方法的亮点是batch内选择patch进行对比学习。在与sota对比上,仅是Denoise任务精度占优,在Derain与Dehaze任务上,效果不如One-By-One的MPRNet方法。本博客对AIRNet的关键结构实现,loss实现,data_patch实现进行深入分析,并对模型进行推理使用。

其论文的详细可以阅读:https://blog.csdn.net/a486259/article/details/139559389?spm=1001.2014.3001.5501

项目地址:https://blog.csdn.net/a486259/article/details/139559389?spm=1001.2014.3001.5501

项目依赖:torch、mmcv-full
安装mmcv-full时,需要注意torch所对应的cuda版本,要与系统中的cuda版本一致。

1、模型结构

AirNet的网络结构如下所示,输入图像x交由CBDE提取到嵌入空间z,z与x输入到DGRN模块的DGG block中逐步优化,最终输出预测结果。
在这里插入图片描述
模型代码在net\model.py

from torch import nn

from net.encoder import CBDE
from net.DGRN import DGRN


class AirNet(nn.Module):
    def __init__(self, opt):
        super(AirNet, self).__init__()
        # Encoder
        self.E = CBDE(opt)  #编码特征值

        # Restorer
        self.R = DGRN(opt) #特征解码


    def forward(self, x_query, x_key):
        if self.training:
            fea, logits, labels, inter = self.E(x_query, x_key)

            restored = self.R(x_query, inter)

            return restored, logits, labels
        else:
            fea, inter = self.E(x_query, x_query)

            restored = self.R(x_query, inter)

            return restored

1.1 CBDE模块

CBDE模块的功能是在模块内进行对比学习,核心是MoCo. Moco论文地址:https://arxiv.org/pdf/1911.05722

class CBDE(nn.Module):
    def __init__(self, opt):
        super(CBDE, self).__init__()

        dim = 256

        # Encoder
        self.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim)

    def forward(self, x_query, x_key):
        if self.training:
            # degradation-aware represenetion learning
            fea, logits, labels, inter = self.E(x_query, x_key)

            return fea, logits, labels, inter
        else:
            # degradation-aware represenetion learning
            fea, inter = self.E(x_query, x_query)
            return fea, inter

ResEncoder所对应的网络结构如下所示
在这里插入图片描述

在AIRNet中的CBDE模块里的MoCo模块的关键代码如下,其在内部自行完成了正负样本的分配,最终输出logits, labels用于计算对比损失的loss。但其所优化的模块实际上是ResEncoder。MoCo模块只是在训练阶段起作用,在推理阶段是不起作用的。

class MoCo(nn.Module):
    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """
        if self.training:
            # compute query features
            embedding, q, inter = self.encoder_q(im_q)  # queries: NxC
            q = nn.functional.normalize(q, dim=1)

            # compute key features
            with torch.no_grad():  # no gradient to keys
                self._momentum_update_key_encoder()  # update the key encoder

                _, k, _ = self.encoder_k(im_k)  # keys: NxC
                k = nn.functional.normalize(k, dim=1)
            # compute logits
            # Einstein sum is more intuitive
            # positive logits: Nx1
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
            # negative logits: NxK
            l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

            # logits: Nx(1+K)
            logits = torch.cat([l_pos, l_neg], dim=1)

            # apply temperature
            logits /= self.T

            # labels: positive key indicators
            labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

            # dequeue and enqueue
            self._dequeue_and_enqueue(k)

            return embedding, logits, labels, inter
        else:
            embedding, _, inter = self.encoder_q(im_q)

            return embedding, inter

1.2 DGRN模块

DGRN模块的实现代码如下所示,可以看到核心是DGG模块,其不断迭代优化输入图像。

class DGRN(nn.Module):
    def __init__(self, opt, conv=default_conv):
        super(DGRN, self).__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3

        # head module
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # body
        modules_body = [
            DGG(default_conv, n_feats, kernel_size, n_blocks) \
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail
        modules_tail = [conv(n_feats, 3, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, inter):
        # head
        x = self.head(x)

        # body
        res = x
        for i in range(self.n_groups):
            res = self.body[i](res, inter)
        res = self.body[-1](res)
        res = res + x

        # tail
        x = self.tail(res)

        return x

在这里插入图片描述
DGG模块的结构示意如下所示
在这里插入图片描述
DGG代码实现如下所示,DGG模块内嵌DGB模块,DGB模块内嵌DGM模块,DGM模块内嵌SFT_layer模块与DCN_layer(可变性卷积)
在这里插入图片描述

2、loss实现

AIRNet中提到的loss如下所示,其中Lrec是L1 loss,Lcl是Moco模块实现的对比损失。
在这里插入图片描述
AIRNet的loss实现代码在train.py中,CE loss是针对CBDE(Moco模块)的输出进行计算,l1 loss是针对修复图像与清晰图片。

    # Network Construction
    net = AirNet(opt).cuda()
    net.train()

    # Optimizer and Loss
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    CE = nn.CrossEntropyLoss().cuda()
    l1 = nn.L1Loss().cuda()

    # Start training
    print('Start training...')
    for epoch in range(opt.epochs):
        for ([clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2) in tqdm(trainloader):
            degrad_patch_1, degrad_patch_2 = degrad_patch_1.cuda(), degrad_patch_2.cuda()
            clean_patch_1, clean_patch_2 = clean_patch_1.cuda(), clean_patch_2.cuda()

            optimizer.zero_grad()

            if epoch < opt.epochs_encoder:
                _, output, target, _ = net.E(x_query=degrad_patch_1, x_key=degrad_patch_2)
                contrast_loss = CE(output, target)
                loss = contrast_loss
            else:
                restored, output, target = net(x_query=degrad_patch_1, x_key=degrad_patch_2)
                contrast_loss = CE(output, target)
                l1_loss = l1(restored, clean_patch_1)
                loss = l1_loss + 0.1 * contrast_loss

            # backward
            loss.backward()
            optimizer.step()

这里可以看出来,AIRNet首先是训练CBDE模块,最后才训练CBDE模块+DGRN模块。

3、TrainDataset

TrainDataset的实现代码在utils\dataset_utils.py中,首先找到__getitem__函数进行分析。以下代码为关键部分,删除了大部分在逻辑上重复的部分。TrainDataset一共支持5种数据类型,‘denoise_15’: 0, ‘denoise_25’: 1, ‘denoise_50’: 2,是不需要图像对的(在代码里面自动对图像添加噪声);‘derain’: 3, ‘dehaze’: 4是需要图像对进行训练的。

class TrainDataset(Dataset):
    def __init__(self, args):
        super(TrainDataset, self).__init__()
        self.args = args
        self.rs_ids = []
        self.hazy_ids = []
        self.D = Degradation(args)
        self.de_temp = 0
        self.de_type = self.args.de_type

        self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}

        self._init_ids()

        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])
        self.toTensor = ToTensor()
        
    def __getitem__(self, _):
        de_id = self.de_dict[self.de_type[self.de_temp]]

        if de_id < 3:
            if de_id == 0:
                clean_id = self.s15_ids[self.s15_counter]
                self.s15_counter = (self.s15_counter + 1) % self.num_clean
                if self.s15_counter == 0:
                    random.shuffle(self.s15_ids)

            # clean_id = random.randint(0, len(self.clean_ids) - 1)
            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
            clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)
            clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)

            # clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
            clean_name = clean_id.split("/")[-1].split('.')[0]

            clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)
            degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)

        clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)
        degrad_patch_1, degrad_patch_2 = self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2)

        self.de_temp = (self.de_temp + 1) % len(self.de_type)
        if self.de_temp == 0:
            random.shuffle(self.de_type)

        return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2

可以看出TrainDataset返回的数据有:degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2。

3.1 clean_patch分析

通过以下代码可以看出 clean_patch_1, clean_patch_2是来自于同一个图片,然后基于crop_transform变化,变成了2个对象

            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
            clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)
            # clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
            clean_name = clean_id.split("/")[-1].split('.')[0]

            clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)

crop_transform的定义如下,可见是随机进行crop

crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])

random_augmentation的实现代码如下,可以看到只是随机对图像进行翻转或旋转,其目的是尽可能使随机crop得到clean_patch_1, clean_patch_2差异更大,避免裁剪出高度相似的patch。

def random_augmentation(*args):
    out = []
    flag_aug = random.randint(1, 7)
    for data in args:
        out.append(data_augmentation(data, flag_aug).copy())
    return out
def data_augmentation(image, mode):
    if mode == 0:
        # original
        out = image.numpy()
    elif mode == 1:
        # flip up and down
        out = np.flipud(image)
    elif mode == 2:
        # rotate counterwise 90 degree
        out = np.rot90(image)
    elif mode == 3:
        # rotate 90 degree and flip up and down
        out = np.rot90(image)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degree
        out = np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degree and flip
        out = np.rot90(image, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degree
        out = np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degree and flip
        out = np.rot90(image, k=3)
        out = np.flipud(out)
    else:
        raise Exception('Invalid choice of image transformation')
    return out
    

3.2 degrad_patch分析

degrad_patch来自于clean_patch,可以看到是通过D.degrade进行转换的。

degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)

D.degrade相关的代码如下,可以看到只是对图像添加噪声。难怪AIRNet在图像去噪上效果最好。

class Degradation(object):
    def __init__(self, args):
        super(Degradation, self).__init__()
        self.args = args
        self.toTensor = ToTensor()
        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])

    def _add_gaussian_noise(self, clean_patch, sigma):
        # noise = torch.randn(*(clean_patch.shape))
        # clean_patch = self.toTensor(clean_patch)
        noise = np.random.randn(*clean_patch.shape)
        noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
        # noisy_patch = torch.clamp(clean_patch + noise * sigma, 0, 255).type(torch.int32)
        return noisy_patch, clean_patch

    def _degrade_by_type(self, clean_patch, degrade_type):
        if degrade_type == 0:
            # denoise sigma=15
            degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15)
        elif degrade_type == 1:
            # denoise sigma=25
            degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25)
        elif degrade_type == 2:
            # denoise sigma=50
            degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50)

        return degraded_patch, clean_patch

    def degrade(self, clean_patch_1, clean_patch_2, degrade_type=None):
        if degrade_type == None:
            degrade_type = random.randint(0, 3)
        else:
            degrade_type = degrade_type

        degrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type)
        degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type)
        return degrad_patch_1, degrad_patch_2

4、推理演示

项目中默认包含了All.pth,要单独任务的模型可以到预训练模型下载地址: Google Drive and Baidu Netdisk (password: cr7d). 下载模型放到 ckpt/ 目录下

打开demo.py,将 subprocess.check_output(['mkdir', '-p', opt.output_path]) 替换为os.makedirs(opt.output_path,exist_ok=True),避免在window上报错,具体修改如下所示
在这里插入图片描述

demo.py默认从test\demo目录下读取图片进行测试,可见原始图像如下
在这里插入图片描述
代码运行后的输出结果默认保存在 output\demo目录下,可见对于去雨,去雾,去噪声效果都比较好。
在这里插入图片描述
模型推理时间如下所示,可以看到对一张320, 480的图片,要0.54s
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1806789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电影制作中的版本控制:Perforce Helix Core帮助某电影短片避免灾难性文件损坏,简化艺术资产管理

Zubaida Nila是来自马来西亚的一名视觉特效师和虚拟制作研究员&#xff0c;她参加了Epic Games的一个为期六周的虚拟培训和指导项目——女性创作者计划。该计划提供了虚幻引擎工作流程的实践经验以及其他课程。Zubaida希望从中获得更多关于虚幻引擎的灯光、后期处理和特效技能方…

csrf与xss差别 别在弄乱了 直接靶场实操pikachu的csrf题 token绕过可以吗???

我们现在来说说这2个之间的关系&#xff0c;因为昨天的我也没有弄清楚这2者的关系&#xff0c;总感觉迷迷糊糊的。 xss这个漏洞是大家并不怎么陌生&#xff0c;导致xss漏洞的产生是服务器没有对用户提交数据过滤不严格&#xff0c;导致浏览器把用户输入的当作js代码返回客户端…

玉米粒计数检测数据集VOC+YOLO格式107张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;107 标注数量(xml文件个数)&#xff1a;107 标注数量(txt文件个数)&#xff1a;107 标注类别…

群体优化算法----树蛙优化算法介绍以及应用于资源分配示例

介绍 树蛙优化算法&#xff08;Tree Frog Optimization Algorithm, TFO&#xff09;是一种基于群体智能的优化算法&#xff0c;模拟了树蛙在自然环境中的跳跃和觅食行为。该算法通过模拟树蛙在树枝间的跳跃来寻找最优解&#xff0c;属于近年来发展起来的自然启发式算法的一种 …

c# iText使用

引入包 用nuget安装itext和itext.bouncy-castle-adapter包&#xff1a; 创建pdf string path "a.pdf"; PdfWriter writer new PdfWriter(path); PdfDocument pdfDoc new PdfDocument(writer); var docnew Document(pdfDoc); Paragraph p new Paragraph(&quo…

基于I2C协议的OLED显示(利用U82G库)

目录 一、实验目的 二、 U8g2下载 三、利用stm32f103的GPIO管脚、VCC和GND连接 OLED屏的I2C接口&#xff0c;采用cubemx设计一个HAL库程序框架&#xff0c;然后下载U82G源码&#xff0c;针对stm32f103和 0.96寸的I2C接口OLED屏&#xff0c;进行代码裁剪&#xff0c;然后移植到…

Fences 5 激活码 - 电脑桌面整理软件

提起桌面整理&#xff0c;经典老牌工具 Fences 必有一席之地&#xff0c;Stardock 发布了最新的 Fences 5 版本。 可以将文件和图标归类放入各个栅栏分区&#xff0c;并支持文件夹展开至桌面、分区置顶、淡化隐藏图标等功能&#xff0c;能让你的桌面焕然一新&#xff0c;不再混…

电阻十大品牌供应商

选型时选择热门的电阻品牌&#xff0c;主要是产品丰富&#xff0c;需求基本都能满足。 所所有的电路中&#xff0c;基本没有不用电阻的&#xff0c;电阻的选型需要参考阻值、精度、封装、温度范围&#xff0c;贴片/插件等参数&#xff0c;优秀的供应商如下&#xff1a; 十大电…

Cweek4+5

C语言学习 十.指针详解 6.有关函数指针的代码 代码1&#xff1a;(*(void (*)())0)(); void(*)()是函数指针类型&#xff0c;0是一个函数的地址 (void(*)())是强制转换 总的是调用0地址处的函数&#xff0c;传入参数为空 代码2&#xff1a;void (*signal(int, void(*)(int))…

系统思考—心智模式

凯恩斯说&#xff1a;“介绍新观念倒不是很难&#xff0c;难的是清除那些旧观念。”在过去的任何一年&#xff0c;如果你一次都没有推翻过自己最中意的想法&#xff0c;那么你这一年就算浪费了。旧观念像是根深蒂固的杂草&#xff0c;即使在新知识的光照下&#xff0c;也需要时…

Docker Desktop - WSL distro terminated abruptly

打开 PowerShell 或以管理员身份运行的命令提示符。运行以下命令以列出已安装的 WSL 分发&#xff1a; wsl --list 运行以下命令以注销 Docker 相关的分发 wsl --unregister <distro_name> 将<distro_name>替换为实际的 Docker 相关分发的名称。将<distro_…

模型 利特尔法则

说明&#xff1a;系列文章 分享 模型&#xff0c;了解更多&#x1f449; 模型_思维模型目录。揭示流量、存量、时间的数学关系。 1 利特尔法则的应用 1.1 银行服务系统的优化 一家银行希望优化其服务系统以减少客户的等待时间并提高服务效率。银行决定使用利特尔法则来分析和…

string经典题目(C++)

文章目录 前言一、最长回文子串1.题目解析2.算法原理3.代码编写 二、字符串相乘1.题目解析2.算法原理3.代码编写 总结 前言 一、最长回文子串 1.题目解析 给你一个字符串 s&#xff0c;找到 s 中最长的回文子串。 示例 1&#xff1a; 输入&#xff1a;s “babad” 输出&am…

人工智能系统越来越擅长欺骗我们?

人工智能系统越来越擅长欺骗我们&#xff1f; 一波人工智能系统以他们没有被明确训练过的方式“欺骗”人类&#xff0c;通过为他们的行为提供不真实的解释&#xff0c;或者向人类用户隐瞒真相并误导他们以达到战略目的。 发表在《模式》(Patterns)杂志上的一篇综述论文总结了之…

红黑树的介绍与实现

前言 前面我们介绍了AVL树&#xff0c;AVL树是一棵非常自律的树&#xff0c;有着严格的高度可控制&#xff01;但是正它的自律给他带来了另一个问题&#xff0c;即虽然他的查找效率很高&#xff0c;但是插入和删除由于旋转而导致效率没有那么高。我们上一期的结尾说过经常修改…

Java SE(Java Platform, Standard Edition)

Java SE&#xff08;Java Platform, Standard Edition&#xff09; 是Java平台的一个版本&#xff0c;面向桌面应用程序、服务器和嵌入式环境。Java SE提供了开发和运行Java应用程序的基础API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&…

Docker之路(三)docker安装nginx实现对springboot项目的负载均衡

Docker之路&#xff08;三&#xff09;dockernginxspringboot负载均衡 前言&#xff1a;一、安装docker二、安装nginx三、准备好我们的springboot项目四、将springboot项目分别build成docker镜像五、配置nginx并且启动六、nginx的负载均衡策略七、nginx的常用属性八、总结 前言…

【leetcode--盛水最多的容器】

给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 写出来了一半&#xff0c;想到用双指针&am…

大数据数仓的数据回溯

在大数据领域&#xff0c;数据回溯是一项至关重要的任务&#xff0c;它涉及到对历史数据的重新处理以确保数据的准确性和一致性。 数据回溯的定义与重要性 数据回溯&#xff0c;也称为数据补全&#xff0c;是指在数据模型迭代或新模型上线后&#xff0c;对历史数据进行重新处理…

VisionPro的应用和入门教程

第1章 关于VisionPro 1.1 康耐视的核心技术 1. 先进的视觉系统 康耐视的视觉系统结合了高性能的图像传感器、复杂的算法和强大的计算能力&#xff0c;能够实时捕捉、分析和处理高分辨率图像。其视觉系统包括固定式和手持式两种&#xff0c;适用于各种工业环境。无论是精密电…