Haar小波下采样模块

news2024/11/15 11:02:42

论文原址:Haar wavelet downsampling: A simple but effective downsampling module for semantic segmentation - ScienceDirect

原文代码:HWD/HWD.py at main · apple1986/HWD (github.com)

介绍 

深度卷积神经网络 (DCNN) 通常采用标准的下采样操作,例如最大池化、平均池化和跨步卷积,这可能会导致信息丢失。丢失的信息,如边界和纹理,对于语义分割可能是必不可少的。为了缓解这个问题,一般有下面四种方法:

  1. 通过跳过连接到解码器子网(如U-Net、LCU-Net、CENet、LinkNet和RefineNet )。
  2. 提取具有空间金字塔池化或扩展卷积的多尺度特征图到融合模块中(如DeepLab、PSPNet、PCPLP-Net、BiSenet和ICNet)。
  3. 向编码器提供多模态图像(如DiSegNet、MMADT、CANet和CCFFNet)。
  4. 增加先验信息。轮廓增强关注模块,旨在从CT图像中提取边界和形状线索,以细化分割区域。

这些方法的主要目的是通过基于多尺度、先验指导、多模态等各种策略提供更多的学习信息或特征,帮助下采样特征与分割标签之间建立良好的关系。

因此,是否可以设计一个保留信息的下采样模块,使DCNNs中尽可能多地保留信息进行语义分割?这就是作者的想法。 

下采样模块

最大池化与平均池化

池化过程类似于卷积过程。在这个示意图中,我们看到对一个 4x4 的特征图邻域进行操作,使用了一个 2x2 的滤波器,步长为2进行扫描。这个过程被称为最大池化(Max Pooling),其中选择邻域内的最大值并输出到下一层。

常用的 max pooling 参数是 S=2、f=2,其效果是将特征图的高度和宽度减半,而通道数保持不变。

如上图所示,描述的是对一个 4x4 的特征图邻域内的数值进行操作。使用了一个 2x2 的滤波器,步长为2进行扫描,计算邻域内数值的平均值并将其输出到下一层。这种操作被称为平均池化(Mean Pooling)。

"""
Copyright (c) 2023, Auorui.
All rights reserved.

The Torch implementation of average pooling and maximum pooling has been compared with the official Torch implementation
"""
import torch
import torch.nn as nn

__all__ = ["MaxPool2d", "AvgPool2d"]

class MaxPool2d(nn.Module):
    """
    池化层计算公式:
        output_size = [(input_size−kernel_size) // stride + 1]
    """
    def __init__(self, kernel_size, stride):
        super(MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def max_pool2d(self, input_tensor, kernel_size, stride):
        batch_size, channels, height, width = input_tensor.size()
        output_height = (height - kernel_size) // stride + 1
        output_width = (width - kernel_size) // stride + 1
        output_tensor = torch.zeros(batch_size, channels, output_height, output_width)

        for i in range(output_height):
            for j in range(output_width):
                # 获取输入张量中与池化窗口对应的部分
                window = input_tensor[:, :,
                         i * stride: i * stride + kernel_size, j * stride: j * stride + kernel_size]
                output_tensor[:, :, i, j] = torch.max(window.reshape(batch_size, channels, -1), dim=2)[0]
        return output_tensor

    def forward(self, input_tensor):
        return self.max_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)


class AvgPool2d(nn.Module):
    """
    池化层计算公式:
        output_size = [(input_size−kernel_size) // stride + 1]
    """
    def __init__(self, kernel_size, stride):
        super(AvgPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def avg_pool2d(self, input_tensor, kernel_size, stride):
        batch_size, channels, height, width = input_tensor.size()
        output_height = (height - kernel_size) // stride + 1
        output_width = (width - kernel_size) // stride + 1
        output_tensor = torch.zeros(batch_size, channels, output_height, output_width)

        for i in range(output_height):
            for j in range(output_width):
                # 获取输入张量中与池化窗口对应的部分
                window = input_tensor[:, :,
                         i * stride: i * stride + kernel_size, j * stride:j * stride + kernel_size]
                output_tensor[:, :, i, j] = torch.mean(window.reshape(batch_size, channels, -1), dim=2)
        return output_tensor

    def forward(self, input_tensor):
        return self.avg_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)


if __name__=="__main__":
    # input_data = torch.rand((1, 3, 3, 3))
    input_data = torch.Tensor([[[[0.3939, 0.8964, 0.3681],
                               [0.5134, 0.3780, 0.0047],
                               [0.0681, 0.0989, 0.5962]],
                              [[0.7954, 0.4811, 0.3329],
                               [0.8804, 0.3986, 0.3561],
                               [0.2797, 0.3672, 0.6508]],
                              [[0.6309, 0.1340, 0.0564],
                               [0.3101, 0.9927, 0.5554],
                               [0.0947, 0.2305, 0.8299]]]])

    print(input_data.shape)

    kernel_size = 3
    stride = 1
    MaxPool2d1 = nn.MaxPool2d(kernel_size, stride)
    output_data_with_torch_max = MaxPool2d1(input_data)
    AvgPool2d1 = nn.AvgPool2d(kernel_size, stride)
    output_data_with_torch_avg = AvgPool2d1(input_data)
    AvgPool2d2 = AvgPool2d(kernel_size, stride)
    output_data_with_torch_Avg = AvgPool2d2(input_data)
    MaxPool2d2 = MaxPool2d(kernel_size, stride)
    output_data_with_torch_Max = MaxPool2d2(input_data)
    # output_data_with_max = max_pool2d(input_data, kernel_size, stride)
    # output_data_with_avg = avg_pool2d(input_data, kernel_size, stride)

    print("\ntorch.nn pooling Output:")
    print(output_data_with_torch_max,"\n",output_data_with_torch_max.size())
    print(output_data_with_torch_avg,"\n",output_data_with_torch_avg.size())
    print("\npooling Output:")
    print(output_data_with_torch_Max,"\n",output_data_with_torch_Max.size())
    print(output_data_with_torch_Avg,"\n",output_data_with_torch_Avg.size())
    # 直接使用bool方法判断会因为浮点数的原因出现偏差
    print(torch.allclose(output_data_with_torch_max,output_data_with_torch_Max))
    print(torch.allclose(output_data_with_torch_avg,output_data_with_torch_Avg))
    # tensor([[[[0.8964]],       # output_data_with_max
    #          [[0.8804]],
    #          [[0.9927]]]])
    # tensor([[[[0.3686]],       # output_data_with_avg
    #           [[0.5047]],
    #           [[0.4261]]]])

在这里,简单地与PyTorch官方的实现进行了比对,成功的进行复现。

跨步卷积

import torch
import torch.nn as nn

class StridedConvolution(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):
        super(StridedConvolution, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.is_relu = is_relu

    def forward(self, x):
        x = self.conv(x)
        if self.is_relu:
            x = self.relu(x)
        return x

if __name__ == '__main__':
    input_data = torch.rand((1, 3, 64, 64))
    strided_conv = StridedConvolution(3, 64)
    output_data = strided_conv(input_data)
    print("Input shape:", input_data.shape)
    print("Output shape:", output_data.shape)

对输入进行跨步卷积,并根据 is_relu 参数选择是否添加ReLU激活函数。在构建卷积神经网络时经常被用于下采样步骤,以减小特征图的尺寸。

Haar小波下采样

这一部分就直接参考的作者的代码,与池化不同的是,这里它是要指定输入输出几个通道。

"""
Haar Wavelet-based Downsampling (HWD)

Original address of the paper: https://www.sciencedirect.com/science/article/abs/pii/S0031320323005174
Code reference: https://github.com/apple1986/HWD/tree/main
"""
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward

class HWDownsampling(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(HWDownsampling, self).__init__()
        self.wt = DWTForward(J=1, wave='haar', mode='zero')
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        yL, yH = self.wt(x)
        y_HL = yH[0][:, :, 0, ::]
        y_LH = yH[0][:, :, 1, ::]
        y_HH = yH[0][:, :, 2, ::]
        x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
        x = self.conv_bn_relu(x)
        return x


if __name__ == '__main__':
    downsampling_layer = HWDownsampling(3, 64)
    input_data = torch.rand((1, 3, 64, 64))
    output_data = downsampling_layer(input_data)
    print("Input shape:", input_data.shape)
    print("Output shape:", output_data.shape)

Haar小波变换是一种基于小波的信号处理方法,它将信号分解成低频和细节高频两个部分。在图像处理中,Haar小波通常用于图像压缩和特征提取,代码中使用的DWTForward模块中离散小波变换,通过选择 yH 中的不同方向上的高频分量,构建了新的特征图。将原始低频分量 yL 与新构建的高频分量拼接在一起。最后通过一个包含卷积、批归一化和ReLU激活函数的序列处理最终的特征图。

实验验证

这是作者论文中做的实验,这样看起来,似乎HWD在细节上确实是比池化和跨步卷积效果要好。

这里因为我也用我自己的数据进行了实验:

最大池化效果

平均池化效果

跨步卷积效果 

HDW效果

从肉眼上来看,HDW的效果确实要比其他的效果要好一些。

下面是我做实验的代码,感兴趣的可以在自己的数据上面进行实验,我觉得用于交通和医学上应该会有比较好的效果。

import cv2
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
from pytorch_wavelets import DWTForward

class StridedConvolution(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):
        super(StridedConvolution, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.is_relu = is_relu

    def forward(self, x):
        x = self.conv(x)
        if self.is_relu:
            x = self.relu(x)
        return x

class HWDownsampling(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(HWDownsampling, self).__init__()
        self.wt = DWTForward(J=1, wave='haar', mode='zero')
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        yL, yH = self.wt(x)
        y_HL = yH[0][:, :, 0, ::]
        y_LH = yH[0][:, :, 1, ::]
        y_HH = yH[0][:, :, 2, ::]
        x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
        x = self.conv_bn_relu(x)
        return x

class DeeperCNN(nn.Module):
    def __init__(self):
        super(DeeperCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # self.pool1 = HWDownsampling(16, 16)
        self.pool1 = StridedConvolution(16, 16, is_relu=True)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(32)
        # self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # self.pool2 = HWDownsampling(32, 32)
        self.pool2 = StridedConvolution(32, 32, is_relu=True)

        self.conv6 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.pool1(self.relu(self.batch_norm1(self.conv1(x))))
        print(x.shape)
        x = self.pool2(self.relu(self.batch_norm2(self.conv2(x))))
        print(x.shape)
        x = self.conv6(x)
        return x

image_path = r'D:\PythonProject\Crack_classification_training_script\data\base\val\crack\2416.png'
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

transform = transforms.Compose([transforms.ToTensor()])
input_image = transform(image).unsqueeze(0)
import numpy as np
model = DeeperCNN()
output = model(input_image)
print("Output shape:", output.shape)

input_image = input_image.squeeze(0).permute(1, 2, 0).numpy()
output_image = output.squeeze(0).permute(1, 2, 0).detach().numpy()
output_image = output_image / output_image.max()
output_image = np.clip(output_image, 0, 1)

plt.subplot(1, 2, 1)
plt.imshow(input_image)
plt.title('Input Image')

plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.title('Output Image')

plt.show()

总结 

在论文当中,作者也做了大量的消融实验去证实这个下采样模块的有效性,建议大家去看看原著作,或许会有更多的收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1403917.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

性能优化-HVX 指令介绍

「发表于知乎专栏《移动端算法优化》」 本文主要介绍了 HVX 指令相关的知识,包括 HVX 寄存器相关内容,指令的背景依赖,部分常用 intrinsic HVX 指令。具体指令的详细内容及使用还需阅读 HVX 的指令文档,以及细致的实践操作。 &…

MATLAB - 激光雷达 - 相机联合标定(Lidar-Camera Calibration)

系列文章目录 前言 一、 激光雷达 - 相机标定建立了三维激光雷达点和二维相机数据之间的对应关系,从而将激光雷达和相机输出融合在一起。 激光雷达传感器和相机被广泛用于自动驾驶、机器人和导航等应用中的三维场景重建。激光雷达传感器捕捉环境的三维结构信息&am…

惊了!竟然有上千款小游戏源码,可直接打包H5\微信\抖音,赶紧收藏!

很多人还不知道 Cocos Store 资源商城,它是国内最大的小游戏资源平台。上面有大量免费游戏源码可以下载,比如下图的《赛博朋克》,项目中包含大量模型、贴图,还有游戏源代码,通过Cocos引擎可以直接在浏览器上玩。 本文就…

编程语言MoonBit新增矩阵函数的语法糖

MoonBit更新 1. 新增矩阵函数的语法糖 新增矩阵函数的语法糖,用于方便地定义局部函数和具有模式匹配的匿名函数: fn init {fn boolean_or { // 带有模式匹配的局部函数true, _ > true_, true > true_, _ > false}fn apply(f, x) {f(x)}le…

vConsole 与 Vue中未定义变量而引发的Maximum call stack size exceeded异常问题

一、问题描述 前段时间有个前端小伙伴反馈在打包发布正式环境后调用VantUI的<van-popup>组件显示时&#xff0c;显示空白&#xff0c;并且在控制台看到一个Maximum call stacksize exceeded&#xff08;超出最大调用堆栈大小&#xff09;,而本地开发环境正常&#xff1a…

NOC总线(2)

1. NoC的路由 在NoC交换信息时&#xff0c;需要确定从源节点到目标节点所经过的路径&#xff0c;这时就需要路由算法来确定该路径。路由算法分为静态路由算法和动态路由算法两种。 静态路由算法对于两节点之间的路径是固定的&#xff0c;结构简单&#xff0c;便于硬件实…

mysql 导入数据 1273 - Unknown collation: ‘utf8mb4_0900_ai_ci‘

前言: mysql 导入数据 遇到这个错误 1273 - Unknown collation: utf8mb4_0900_ai_ci 具体原因没有深究 但应该是设计数据库的 字符集类型会出现这个问题 例如: char varchar text..... utf8mb4 类型可以存储表情 在现在这个时代会用很多 以后会用的更多 所以不建议改…

基于LLaMA Factory,单卡3小时训练专属大模型 Agent

大家好&#xff0c;今天给大家带来一篇 Agent 微调实战文章 Agent&#xff08;智能体&#xff09;是当今 LLM&#xff08;大模型&#xff09;应用的热门话题 [1]&#xff0c;通过任务分解&#xff08;task planning&#xff09;、工具调用&#xff08;tool using&#xff09;和…

多维时序 | Matlab实现CNN-GRU-Mutilhead-Attention卷积门控循环单元融合多头注意力机制多变量时间序列预测

多维时序 | Matlab实现CNN-GRU-Mutilhead-Attention卷积门控循环单元融合多头注意力机制多变量时间序列预测 目录 多维时序 | Matlab实现CNN-GRU-Mutilhead-Attention卷积门控循环单元融合多头注意力机制多变量时间序列预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍…

从CNN ,LSTM 到Transformer的综述

前情提要&#xff1a;文本大量参照了以下的博客&#xff0c;本文创作的初衷是为了分享博主自己的学习和理解。对于刚开始接触NLP的同学来说&#xff0c;可以结合唐宇迪老师的B站视频【【NLP精华版教程】强推&#xff01;不愧是的最完整的NLP教程和学习路线图从原理构成开始学&a…

k8s--helm

什么是helm&#xff1f;在没有这个helm之前&#xff0c;deployment service ingress helm的作用 通过打包的方式&#xff0c;把deployment service ingress等打包在一块&#xff0c;一键式的部署服务&#xff0c;类似yum安装 官方提供的一个类似与安装仓库额功能&#xff0c;…

详解APQC流程分级分类框架PCF13个高阶分类和5级业务流程

一&#xff1a;什么是APQC 美国生产力与质量中心(American Productivity and Quality Center&#xff0c;简称为APQC)&#xff0c;创立于1977年是一个会员制的非营利机构&#xff0c;使命是“发现有效的改进方法&#xff0c;广泛地传播其发现成果&#xff0c;实现个人之间及其…

MySQL函数—字符串函数

MySQL函数—字符串函数 函数功能CONCAT(s1,s2,...sn)字符串拼接&#xff0c;将s1,s2,...sn拼接成一个字符串LOWER(str)将字符串全部转为小写UPPER(str)将字符串全部转为大写LPAD(str,n,pad)左填充&#xff0c;用字符串pad对str左边进行填充&#xff0c;达到n个字符串长度RPAD(s…

Leetcode—19.删除链表的倒数第 N 个结点【中等】

2023每日刷题&#xff08;七十五&#xff09; Leetcode—19.删除链表的倒数第 N 个结点 算法思想 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int…

EHS管理系统为何需要物联网的加持?

EHS是Environment、Health、Safety的缩写&#xff0c;是从欧美企业引进的管理体系&#xff0c;在国外也被称为HSE。EHS是指健康、安全与环境一体化的管理。 而在国内&#xff0c;整个EHS市场一共被分成三类&#xff1b; 一类是EHS管培体系&#xff0c;由专门的EHS机构去为公司…

使用AFPN渐近特征金字塔网络优化YOLOv8改进小目标检测效果(不适合新手)

目录 简单概述 算法概述 优化效果 参考文献 文献地址&#xff1a;paper 废话少说&#xff0c;上demo源码链接&#xff1a; 简单概述 AFPN的核心思想&#xff1a;AFPN主要通过引入渐近的特征融合策略&#xff0c;逐步整合底层、高层和顶层的特征到目标检测过程中。这种融合…

架构篇08:架构设计三原则

文章目录 合适原则简单原则演化原则小结 成为架构师是每个程序员的梦想&#xff0c;但并不意味着把编程做好就能够自然而然地成为一个架构师&#xff0c;优秀程序员和架构师之间还有一个明显的鸿沟需要跨越&#xff0c;这个鸿沟就是“不确定性”。 对于编程来说&#xff0c;本…

高效构建Java应用:Maven的使用总结

一、Maven简介和快速入门 1.1 Maven介绍 Maven-Introduction Maven 是一款为 Java 项目构建管理、依赖管理的工具&#xff08;软件&#xff09;&#xff0c;使用 Maven 可以自动化构建、测试、打包和发布项目&#xff0c;大大提高了开发效率和质量。 总结&#xff1a;Maven…

用ChatGPT教学、科研!大学与OpenAI合作

亚利桑那州立大学&#xff08;简称“ASU”&#xff09;在官网宣布与OpenAI达成技术合作。从2024年2月份开始&#xff0c;为所有学生提供ChatGPT企业版访问权限&#xff0c;主要用于学习、课程作业和学术研究等。 为了帮助学生更好地学习ChatGPT和大语言模型产品&#xff0c;AS…

mysql生成最近24小时整点时间临时表

文章目录 生成最近24小时整点生成最近30天生成12个月 生成最近24小时整点 SELECT-- 每向下推1行, i比上次减去1b.*, i.*,DATE_FORMAT( DATE_SUB( NOW(), INTERVAL ( -( i : i - 1 ) ) HOUR ), %Y-%m-%d %H:00 ) AS time FROM-- 目的是生成12行数据( SELECTa FROM( SELECT 1 A…