YOLOv10改进 | 特殊场景检测篇 | 单阶段盲真实图像去噪网络RIDNet辅助YOLOv10图像去噪(全网独家首发)

news2024/9/21 8:03:18

 一、本文介绍

本文给大家带来的改进机制是单阶段盲真实图像去噪网络RIDNet,RIDNet(Real Image Denoising with Feature Attention)是一个用于真实图像去噪的卷积神经网络(CNN),旨在解决现有去噪方法在处理真实噪声图像时性能受限的问题。通过单阶段结构和特征注意机制,RIDNet在多种数据集上展示了其优越性,下面的图片为其效果图片包括和其它图像图像网络的对比图。

欢迎大家订阅我的专栏一起学习YOLO!   

  专栏回顾:YOLOv10改进系列专栏——本专栏持续复习各种顶会内容——科研必备 


目录

 一、本文介绍

二、RIDNet 网络的原理和机制

三、核心代码 

四、手把手教你添加RIDNet 

4.1 修改一

4.2 修改二 

4.3 修改三 

4.4 修改四 

五、RIDNet 的yaml文件和运行记录

5.1 RIDNet 的yaml文件

5.2 训练代码 

5.3 RIDNet 的训练过程截图 

五、本文总结


二、RIDNet 网络的原理和机制

官方论文地址: 官方论文地址点击此处即可跳转

官方代码地址: 官方代码地址点击此处即可跳转


RIDNet(Real Image Denoising with Feature Attention)是一个用于真实图像去噪的卷积神经网络(CNN),旨在解决现有去噪方法在处理真实噪声图像时性能受限的问题。通过单阶段结构和特征注意机制,RIDNet在多种数据集上展示了其优越性。

网络架构
RIDNet由三个主要模块组成:

1. 特征提取模块(Feature Extraction Module):
   该模块包含一个卷积层,旨在从输入的噪声图像中提取初始特征。

2. 特征学习模块(Feature Learning Module):

  •    核心部分是增强注意模块(Enhanced Attention Module,EAM),使用残差在残差结构(Residual on Residual)和特征注意机制来增强特征学习能力。
  •    EAM包括两个主要部分:
  •    特征提取子模块:通过两个膨胀卷积层和一个合并卷积层提取和学习特征。
  •     特征注意子模块:使用全局平均池化和自门控机制生成特征注意力,调整每个通道的特征权重,以突出重要特征。

3. 重建模块(Reconstruction Module):
   包含一个卷积层,将学习到的特征重建为去噪后的图像。

3. 特征注意机制
特征注意机制通过以下步骤实现:
1. 全局平均池化(Global Average Pooling):
   将特征图从 `h × w × c` 缩小到 `1 × 1 × c`,捕捉全局上下文信息。
2. 自门控机制(Self-Gating Mechanism):
   使用两个卷积层和Sigmoid激活函数调整特征图的权重。
   第一个卷积层减少通道数,第二个卷积层恢复原始通道数,最终生成特征注意力。
3. 特征调整(Feature Rescaling):
   通过特征注意力调整每个通道的特征权重,增强重要特征。

跳跃连接
RIDNet使用了多种跳跃连接,以确保信息的高效流动和梯度的稳定传递:
1. 长跳跃连接(Long Skip Connection,LSC):
   连接输入和EAM模块的输出,增强信息流动。
2. 短跳跃连接(Short Skip Connection,SSC):
   连接EAM内部的不同层,确保特征图的有效传递。
3.局部连接(Local Connection,LC)**:
   在EAM内部实现局部特征学习。

结论:RIDNet在多个合成和真实噪声数据集上进行了广泛的实验,展示了其在定量指标(如PSNR)和视觉质量上的优越性。与现有最先进的算法相比,RIDNet在处理合成噪声和真实噪声图像时均表现出色。RIDNet通过引入特征注意机制和残差在残差结构,实现了对真实图像去噪的有效处理。其单阶段结构、跳跃连接和特征注意机制确保了高效的特征学习和信息传递,使其在多个数据集上均取得了优异的性能。


三、核心代码 

这个代码基础版本原先有1000+GFLOPs,我将其Block层数优化了一些,并将通道数减少了一部分将参数量降低到了20+。

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False


def init_weights(modules):
    pass



class Merge_Run(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1, dilation=1):
        super(Merge_Run, self).__init__()

        self.body1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True)
        )
        self.body2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, 2, 2),
            nn.ReLU(inplace=True)
        )

        self.body3 = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True)
        )
        init_weights(self.modules)

    def forward(self, x):
        out1 = self.body1(x)
        out2 = self.body2(x)
        c = torch.cat([out1, out2], dim=1)
        c_out = self.body3(c)
        out = c_out + x
        return out


class Merge_Run_dual(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1, dilation=1):
        super(Merge_Run_dual, self).__init__()

        self.body1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, ksize, stride, 2, 2),
            nn.ReLU(inplace=True)
        )
        self.body2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, 3, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, ksize, stride, 4, 4),
            nn.ReLU(inplace=True)
        )

        self.body3 = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True)
        )
        init_weights(self.modules)

    def forward(self, x):
        out1 = self.body1(x)
        out2 = self.body2(x)
        c = torch.cat([out1, out2], dim=1)
        c_out = self.body3(c)
        out = c_out + x
        return out


class BasicBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True)
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        return out


class BasicBlockSig(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlockSig, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.Sigmoid()
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        return out


class ResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out


class EResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out





class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.c1 = BasicBlock(channel, channel // reduction, 1, 1, 0)
        self.c2 = BasicBlockSig(channel // reduction, channel, 1, 1, 0)

    def forward(self, x):
        y = self.avg_pool(x)
        y1 = self.c1(y)
        y2 = self.c2(y1)
        return x * y2


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, group=1):
        super(Block, self).__init__()

        self.r1 = Merge_Run_dual(in_channels, out_channels)
        self.r2 = ResidualBlock(in_channels, out_channels)
        self.r3 = EResidualBlock(in_channels, out_channels)
        # self.g = ops.BasicBlock(in_channels, out_channels, 1, 1, 0)
        self.ca = CALayer(in_channels)

    def forward(self, x):
        r1 = self.r1(x)
        r2 = self.r2(r1)
        r3 = self.r3(r2)
        # g = self.g(r3)
        out = self.ca(r3)

        return out


class RIDNET(nn.Module):
    def __init__(self, args):
        super(RIDNET, self).__init__()

        n_feats = 16
        kernel_size = 3
        rgb_range = 255
        mean = (0.4488, 0.4371, 0.4040)
        std = (1.0, 1.0, 1.0)

        self.sub_mean = MeanShift(rgb_range, mean, std)
        self.add_mean = MeanShift(rgb_range, mean, std, 1)

        self.head = BasicBlock(3, n_feats, kernel_size, 1, 1)

        self.b1 = Block(n_feats, n_feats)
        self.b2 = Block(n_feats, n_feats)
        self.b3 = Block(n_feats, n_feats)
        self.b4 = Block(n_feats, n_feats)

        self.tail = nn.Conv2d(n_feats, 3, kernel_size, 1, 1, 1)

    def forward(self, x):
        s = self.sub_mean(x)
        h = self.head(s)

        # b1 = self.b1(h)
        # b2 = self.b2(b1)
        # b3 = self.b3(b2)
        b_out = self.b4(h)

        res = self.tail(b_out)
        out = self.add_mean(res)
        f_out = out + x

        return f_out



if __name__ == "__main__":

    # Generating Sample image
    image_size = (1, 3, 640, 640)
    image = torch.rand(*image_size)

    # Model
    model = RIDNET(3)

    out = model(image)
    print(out.size())


四、手把手教你添加RIDNet 

4.1 修改一

第一还是建立文件,我们找到如下ultralytics/nn文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二 

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。


4.3 修改三 

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!


4.4 修改四 

按照我的添加在parse_model里添加即可。

到此就修改完成了,大家可以复制下面的yaml文件运行。


五、RIDNet 的yaml文件和运行记录

5.1 RIDNet 的yaml文件

此版本训练信息:YOLOv10n-RIDNet summary: 541 layers, 2811717 parameters, 2811701 gradients, 28.2 GFLOPs

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, RIDNET, []] # 0-P1/2
  - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 10
  - [-1, 1, PSA, [1024]] # 11

# YOLOv10.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 7], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 5], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 20 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, v10Detect, [nc]] # Detect(P3, P4, P5)


5.2 训练代码 

大家可以创建一个py文件将我给的代码复制粘贴进去,配置好自己的文件路径即可运行。

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/yolov8-C2f-FasterBlock.yaml')
    # model.load('yolov8n.pt') # loading pretrain weights
    model.train(data=r'替换数据集yaml文件地址',
                # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                # resume='', # 如过想续训就设置last.pt的地址
                amp=False,  # 如果出现训练损失为Nan可以关闭amp
                project='runs/train',
                name='exp',
                )


5.3 RIDNet 的训练过程截图 


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv10改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

  专栏回顾:YOLOv10改进系列专栏——本专栏持续复习各种顶会内容——科研必备 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1923328.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

零信任作为解决方案,Hvv还能打进去么?

零信任平台由“中心组件服务”三大部分构成,以平台形式充分融合软件定义边界(SDP)、身份与访问管理(IAM)、微隔离 (MSG)的技术方案优势,通过关键技术的创新,实现最佳可信…

PointCloudLib LocalMaximum_DeleteMaxPoint C++版本

测试效果 简介 在点云库(Point Cloud Library,PCL)中,处理点云数据时,经常需要去除局部最大点(Local Maximum),这通常用于去除噪声、提取特定形状的特征或者简化点云数据。局部最大…

python制作甘特图的基本知识(附Demo)

目录 前言1. matplotlib2. plotly 前言 甘特图是一种常见的项目管理工具,用于表示项目任务的时间进度 直观地看到项目的各个任务在时间上的分布和进度 常用的绘制甘特图的工具是 matplotlib 和 plotly 主要以Demo的形式展示 1. matplotlib 功能强大的绘图库&a…

江苏职教高考 计算机 C语言 复习资料

江苏职教高考计算机专业考试内容为 文化课专业课 其中专业课包含: 计算机原理45分 计算机组维45分 计算机网络60分 C语言 6080分 电子电工90分 具体资料可查看链接 链接:https://pan.baidu.com/s/1OXD-zK4V3NsLLDMwfXcTlA?pwd2822 提取码&…

风华绝代林徽因

林徽因 ◉ 卡西莫多 风华绝代林家女,呕心依护古胜迹 短暂一生半百余,忆念至今两甲子 山河有恙因她彩,窈窕淑女众倾慕 不只因她姿颜色,更因巾帼硬风骨 2024年7月12日

Anaconda+Pycharm 项目运行保姆级教程(附带视频)

最近很多小白在问如何用anacondapycharm运行一个深度学习项目,进行代码复现呢?于是写下这篇文章希望能浅浅起到一个指导作用。 附视频讲解地址:AnacondaPycharm项目运行实例_哔哩哔哩_bilibili 一、项目运行前的准备(软件安装&…

【C++深入学习】类和对象(一)

欢迎来到HarperLee的学习笔记! 博主主页传送门:HarperLee博客主页! 欢迎各位大佬交流学习! 本篇本章正式进入C的类和对象部分,本部分知识分为三小节。复习: 结构体复习–内存对齐编译和链接函数栈桢的创建…

如何用码上飞解决企微上真实需求来接单赚米

在企微的工作台中有一个「需求模块」,所有的企微用户都可以在上面提出自己的需求。 例如张三说“在企微上我怎么样才可以把一个客户发的语音,转给另一个客户听?” 李四说“我需要一个能每天在工作群里定时发布信息并能自动修改日期的功能。…

240713-Xinference模型下载、管理及测试

Step 1. 安装Xinference 安装 — Xinference Step 2. 下载模型 方式1: UI界面下载 命令行启动Xinference: xinference-local --host 0.0.0.0 --port 9997在localhost:9997中,通过左侧的Launch Model 配置并下载模型如果模型已下载或配置好&#xff…

好莱坞级别AI视频工具Odyssey亮相!AI世界动态回顾

好莱坞级别的视觉AI:Odyssey 首先,我们要提到的就是Odyssey——一款新晋AI视频工具,它以其好莱坞级别的视觉AI能力引起了广泛关注。奥德赛展示的一些片段令人印象深刻,包括精美的无人机镜头、风景画面以及专业级的B-roll素材。虽…

2024年公司电脑屏幕监控软件推荐|6款好用的屏幕监控软件盘点!

在当今的商业环境中,确保员工的工作效率和数据安全是每个企业管理者的重要任务。屏幕监控软件通过实时监控和记录员工的电脑活动,帮助企业有效地管理和优化工作流程。 1.固信软件 固信软件https://www.gooxion.com/ 主要特点: 实时屏幕监控…

python web自动化测试

安装 seleinum # 设置镜像 pip install statsmodels -i https://pypi.tuna.tsinghua.edu.cn/simple/ pip install selenium安装chrome driver 1、查看chrome的版本 2、https://googlechromelabs.github.io/chrome-for-testing/#stable 上面下载对应版本的chrome driver 只要…

Python | Leetcode Python题解之第230题二叉搜索树中第K小的元素

题目: 题解: class AVL:"""平衡二叉搜索树(AVL树):允许重复值"""class Node:"""平衡二叉搜索树结点"""__slots__ ("val", "parent&quo…

快速理解java的volatile关键字

volatile有啥用呢? volatile是 JVM 提供的 轻量级 的同步机制(有三大特性) 保证可见性 不保证原子性 禁止指令重排 好汉你别激动,同步机制说白了就是 syhconized,这种机制你可以这样理解就是在一段时间内有序的发生操作。假如一…

创建SpringBoot聚合项目

创建SpringBoot聚合项目 需求&#xff1a;以仓库管理系统(warehouse management system)wms为例创建聚合项目 1、创建空项目文件夹 2、创建父工程 删掉src&#xff0c;配置夫工程的pom配置 <properties><maven.compiler.source>8</maven.compiler.source><…

【C++】C++入门实战教程(打造属于自己的C++知识库)

目录 目录 写在前面 1.C学习路线 2.本教程框架介绍 一.C基础部分 1.程序编码规范 2.程序运行与编译 3.关键字 4.常用数据类型 5.运算符相关 二.C进阶部分 1.面向对象编程 2.函数编程 3.模板编程 4.多线程与并发 5.STL介绍及使用 6.内存模型与优化 三.C实战部…

【vue教程】一. 环境搭建与代码规范配置

目录 引言Vue 框架概述起源与设计理念核心特性优势 Vue 开发环境搭建环境要求安装 Vue CLI创建 Vue 项目项目结构介绍运行与构建 组件实例基础模板响应式更新 代码规范为什么要使用代码规范在 Vue 项目中使用 ESLint 、PrettierESLint配置 ESLintrules 自定义错误级别 Prettier…

【Linux杂货铺】3.程序地址空间

1.程序地址空间的引入 fork(&#xff09;函数在调用的时候子如果是子进程则返回0&#xff0c;如果是父进程则返回子进程的pid&#xff0c;在代码中我们分别在子进程和父进程读取全局变量g_val的时候居然出现了俩个不同的值。如下&#xff1a; #include<stdio.h> #includ…

【Linux信号】阻塞信号、信号在内核中的表示、信号集操作函数、sigprocmask、sigpending

我们先来了解一下关于信号的一些常见概念&#xff1a; 实际执行 信号的处理动作 称为信号递达。 信号从产生到递达的之间的状态称为信号未决。 进程可以选择阻塞(Block)某个信号。 被阻塞的信号产生时是处于未决状态的&#xff0c;知道进程解除对该信号的阻塞&#xff0c;该…

01. 课程简介

1. 课程简介 本课程的核心内容可以分为三个部分&#xff0c;分别是需要理解记忆的计算机底层基础&#xff0c;后端通用组件以及需要不断编码练习的数据结构和算法。 计算机底层基础可以包含计算机网络、操作系统、编译原理、计算机组成原理&#xff0c;后两者在面试中出现的频…