【Block总结】WTConv,小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野

news2025/1/24 4:24:10

论文解读:Wavelet Convolutions for Large Receptive Fields

论文信息

  • 标题: Wavelet Convolutions for Large Receptive Fields
  • 作者: Shahaf E. Finder, Roy Amoyal, Eran Treister, Oren Freifeld
  • 提交日期: 2024年7月8日
  • arXiv链接: Wavelet Convolutions for Large Receptive Fields
  • Github: https://github.com/BGU-CS-VIL/WTConv

概述

论文《Wavelet Convolutions for Large Receptive Fields》提出了一种新型卷积层,称为WTConv(Wavelet Transform Convolution),旨在通过小波变换(Wavelet Transform)来扩展卷积神经网络(CNN)的感受野。该方法能够在不显著增加参数数量的情况下,获得接近全局的感受野,从而提高模型对低频信息的捕捉能力。
在这里插入图片描述

主要贡献

  1. 感受野扩展:传统的卷积神经网络通过增加卷积核的大小来扩展感受野,但这种方法在达到一定程度后会遇到参数过多的问题。WTConv通过小波变换实现了感受野的有效扩展,且参数数量仅以对数方式增长。

  2. 多频率响应:WTConv能够有效地响应不同频率的输入信号,增强了模型对形状的响应能力,而不仅仅是对纹理的响应。

  3. 架构兼容性:WTConv可以作为现有架构的替代层,适用于多种网络结构,如ConvNeXt和MobileNetV2,且在图像分类等下游任务中表现出色。
    在这里插入图片描述

WTConv如何在不增加参数的情况下扩展感受野

WTConv(Wavelet Transform Convolution)是一种新型卷积层,旨在通过小波变换(Wavelet Transform)有效扩展卷积神经网络(CNN)的感受野,而不显著增加模型的参数数量。这一方法的核心在于利用小波变换的特性,使得感受野的扩展与参数的增长呈对数关系。

  1. 小波变换的优势:小波变换能够将信号分解为不同频率的成分,这使得WTConv能够同时捕捉到低频和高频信息。通过这种方式,WTConv可以在保持较小卷积核的情况下,获得较大的感受野。

  2. 参数增长控制:传统的卷积层通过增加卷积核的大小来扩展感受野,但这会导致参数数量的急剧增加。WTConv的设计使得对于一个 k × k k \times k k×k 的感受野,所需的可训练参数数量仅以对数方式增长,这样可以有效避免过度参数化的问题[7][8]。

  3. 架构兼容性:WTConv可以作为现有网络架构的替代层,例如ConvNeXt和MobileNetV2,能够无缝集成到这些模型中,增强其对形状的响应能力,并提高对图像损坏的鲁棒性[5][10]。

实验结果

在多个图像分类任务中,WTConv表现出色,尤其是在处理复杂形状和纹理时,显示出更强的适应性和准确性,在图像分类任务中优于传统卷积层,尤其在处理图像损坏和复杂形状时表现出更强的鲁棒性。
。这表明WTConv不仅在理论上有效,而且在实际应用中也具有良好的性能。

通过这些机制,WTConv实现了感受野的有效扩展,同时保持了模型的参数效率,适应了现代深度学习对计算资源的需求。

代码:

import torch
import torch.nn as nn
import pywt
import pywt.data

import torch.nn.functional as F


def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
    w = pywt.Wavelet(wave)
    dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
    dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
    dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)

    dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)

    rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
    rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
    rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)

    rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)

    return dec_filters, rec_filters

def wavelet_transform(x, filters):
    b, c, h, w = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
    x = x.reshape(b, c, 4, h // 2, w // 2)
    return x


def inverse_wavelet_transform(x, filters):
    b, c, _, h_half, w_half = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    x = x.reshape(b, c * 4, h_half, w_half)
    x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
    return x



class WTConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
        super(WTConv2d, self).__init__()

        assert in_channels == out_channels

        self.in_channels = in_channels
        self.wt_levels = wt_levels
        self.stride = stride
        self.dilation = 1

        self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)

        self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
                                   groups=in_channels, bias=bias)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])

        self.wavelet_convs = nn.ModuleList(
            [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
                       groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
        )
        self.wavelet_scale = nn.ModuleList(
            [_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
        )

        if self.stride > 1:
            self.do_stride = nn.AvgPool2d(kernel_size=1, stride=stride)
        else:
            self.do_stride = None

    def forward(self, x):

        x_ll_in_levels = []
        x_h_in_levels = []
        shapes_in_levels = []

        curr_x_ll = x

        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                curr_x_ll = F.pad(curr_x_ll, curr_pads)

            curr_x =wavelet_transform(curr_x_ll, self.wt_filter)
            curr_x_ll = curr_x[:, :, 0, :, :]

            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
            curr_x_tag = curr_x_tag.reshape(shape_x)

            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])

        next_x_ll = 0

        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop()
            curr_x_h = x_h_in_levels.pop()
            curr_shape = shapes_in_levels.pop()

            curr_x_ll = curr_x_ll + next_x_ll

            curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
            next_x_ll = inverse_wavelet_transform(curr_x, self.iwt_filter)

            next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]

        x_tag = next_x_ll
        assert len(x_ll_in_levels) == 0

        x = self.base_scale(self.base_conv(x))
        x = x + x_tag

        if self.do_stride is not None:
            x = self.do_stride(x)

        return x


class _ScaleModule(nn.Module):
    def __init__(self, dims, init_scale=1.0, init_bias=0):
        super(_ScaleModule, self).__init__()
        self.dims = dims
        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
        self.bias = None

    def forward(self, x):
        return torch.mul(self.weight, x)

if __name__ == '__main__':
    # 创建一个随机输入张量,形状为 (batch_size,height×width,channels)
    input1 = torch.rand(1, 64,40, 40)


    # 实例化EFC模块
    block = WTConv2d(64,64,kernel_size=7)
    # 前向传播
    output = block(input1)

    # 打印输入和输出的形状
    print(input1.size())
    print(output.size())

输出结果:

torch.Size([1, 64, 40, 40])
torch.Size([1, 64, 40, 40])

应用案例

https://jingjing.blog.csdn.net/article/details/145248050?spm=1001.2014.3001.5502

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2281204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Couchbase UI: Indexes

在Couchbase中,索引的这些指标可以帮助你评估索引的性能和状态。下面是每个指标的详细解释,以及如何判断索引的有效性: 1. Index Name(索引名称) 描述:每个索引都有一个唯一的名称。这个名称通常会包括表…

(3)STM32 USB设备开发-USB存储设备

例程:STM32USBdevice: 基于STM32的USB设备例子程序 - Gitee.com 本篇为使用芯片内部flash作为USB存储设备的例程,没有知识,全是实操,按照步骤就能获得一个STM32的U盘。本例子是在野火F103MINI开发板上验证的,如果代码…

细说STM32F407单片机电源低功耗StopMode模式及应用示例

目录 一、停止模式基础知识 1、进入停止模式 2、停止模式的状态 3、退出停止模式 4、SysTick定时器的影响 二、停止模式应用示例 1、示例功能和CubeMX项目配置 (1)时钟 (2)RTC (3)ADC1 &#xf…

Blazor-Blazor WebAssmbly项目结构(上)

创建项目 今天我们来创建一个BlazorWebAssmbly项目,来看看项目结构是如何得,我们创建带模板得项目,会创建出一个demo,来看看项目结构。 创建的项目可以直接启动运行,首次启动会看见加载的过程,这个过程…

【2024年终总结】我与CSDN的一年

👉作者主页:心疼你的一切 👉作者简介:大家好,我是心疼你的一切。Unity3D领域新星创作者🏆,华为云享专家🏆 👉记得点赞 👍 收藏 ⭐爱你们,么么哒 文章目录 …

开篇:吴恩达《机器学习》课程及免费旁听方法

课程地址: Machine Learning | Coursera 共包含三个子课程 Supervised Machine Learning: Regression and Classification | Coursera Advanced Learning Algorithms | Coursera Unsupervised Learning, Recommenders, Reinforcement Learning | Coursera 免费…

推荐一个开源的轻量级任务调度器!TaskScheduler!

大家好,我是麦鸽。 这次推荐一款轻量级的嵌入式任务调度器,目前已经有1.4K的star,这个项目比较轻量化,只有5个源文件,可以作为学习的一个开源项目。 核心文件 项目概述: 这是一个轻量级的协作式多任务处理&…

暑期实习准备:C语言(持续更新)

1.局部变量和全局变量 局部变量的作用域是在变量所在的局部范围,全局变量的作用域是整个工程;局部变量的生命周期是作用域内,全局变量的生命周期是整个程序的生命周期,当两者命名冲突时,优先使用的是局部变量。 2.C语言…

Harmony Next 支持创建分身

应用分身能实现在一个设备上安装多个相同的应用,实现多个账号同时登录使用和运行并且互不影响。主要应用场景有社交账号双开、游戏大小号双开等,无需账号切换,从而省去频繁登录的繁琐。 Harmony Next 很容易就能让 App 支持创建分身。 官方文…

java ,springboot 对接支付宝支付,实现生成付款二维码,退款,查询订单状态等接口

查看文档 支付宝文档地址&#xff1a; 小程序文档 - 支付宝文档中心 使用沙箱环境 沙箱登录地址 登录 - 支付宝 点击查看 才能看钥匙截图写错了。。 问号可以看默认加密方式 点击沙箱帐号 这里我们就具备所有条件了 实战开始 pom文件增加依赖 <dependency> <gro…

深入内核讲明白Android Binder【三】

深入内核讲明白Android Binder【三】 前言一、服务的获取过程内核源码解析1. 客户端获取服务的用户态源码回顾2. 客户端获取服务的内核源码分析2.1 客户端向service_manager发送数据1. binder_ioctl2. binder_ioctl_write_read3. binder_thread_write4. binder_transaction4.1 …

chrome游览器JSON Formatter插件无效问题排查,FastJsonHttpMessageConverter导致Content-Type返回不正确

问题描述 chrome游览器又一款JSON插件叫JSON Formatter&#xff0c;游览器GET请求调用接口时&#xff0c;如果返回的数据是json格式&#xff0c;则会自动格式化展示&#xff0c;类似这样&#xff1a; 但是今天突然发现怎么也格式化不了&#xff0c;打开一个json文件倒是可以格…

canvas基础

今天我们简单的来认识学习一下canvas的基础概念和使用方法。 1. 认识canvas 1.1 什么是canvas 在网页开发中&#xff0c;canvas是html5中的一个元素&#xff0c;用于通过JavaScript绘制图形。它可以用来制作简单的图表、动画和游戏等。 1.2. 使用场景 游戏开发&#xff1a…

OneData体系架构详解

阿里巴巴的 OneData 体系架构方法论&#xff0c;主要分为三个阶段&#xff1a;业务板块、规范定义 和 模型设计。每个阶段的核心目标是确保数据的高效管理、共享与分析能力。 一. 业务板块&#xff08;Business Segment&#xff09; 业务板块是OneData体系架构中的第一步&…

【C++】哈希表的使用

unordered_map/unordered_set 这是C11才新增的两个容器 原本觉得avl树和红黑树效率已经够了。 后来探索和觉得哈希还是有必要加进来的。 JAVA里面是这样取名的&#xff1a; unordered_set unordered_map/set与map/set的功能基本一致&#xff0c;但细节上有所不同&#x…

微信小程序1.1 微信小程序介绍

1.1 微信小程序介绍 内容提要 什么是微信小程序 微信小程序的功能 微信小程序使用场景 微信小程序能取代App吗 微信小程序的发展历程 微信小程序带来的机会

前端Vue2项目使用md编辑器

项目中有一个需求&#xff0c;要在前端给用户展示内容&#xff0c;内容有 AI 生成的&#xff0c;返回来的是 md 格式&#xff0c;所以需要给用户展示 md 格式&#xff0c;并且管理端也可以编辑这个 md 格式的文档。 使用组件库 v-md-editor。 https://code-farmer-i.github.i…

26、正则表达式

目录 一. 匹配字符 .&#xff1a;匹配除换行符外的任意单个字符。 二. 位置锚点 ^&#xff1a;匹配输入字符串的开始位置。 $&#xff1a;匹配输入字符串的结束位置。 \b&#xff1a;匹配单词边界。 \B&#xff1a;匹配非单词边界。 三. 重复限定符 *&#xff1a;匹配…

K8S中Service详解(一)

Service介绍 在Kubernetes中&#xff0c;Service资源解决了Pod IP地址不固定的问题&#xff0c;提供了一种更稳定和可靠的服务访问方式。以下是Service的一些关键特性和工作原理&#xff1a; Service的稳定性&#xff1a;由于Pod可能会因为故障、重启或扩容而获得新的IP地址&a…

【真机调试】前端开发:移动端特殊手机型号有问题,如何在电脑上进行调试?

目录 前言一、怎么设置成开发者模式&#xff1f;二、真机调试基本步骤&#xff1f; &#x1f680;写在最后 前言 edge浏览器 edge://inspect/#devices 谷歌浏览器&#xff08;开tizi&#xff09; chrome://inspect 一、怎么设置成开发者模式&#xff1f; Android 设备 打开设…