小波卷积:为计算机视觉任务开辟新的参数效率之路

news2024/11/15 8:11:07

论文复述

这篇论文介绍了一种创新的卷积神经网络层——WTConv,它通过小波变换技术显著扩展了CNN的感受野,同时保持了参数效率。WTConv层能够实现对输入数据的多频率响应,增强了模型对形状而非纹理的特征识别能力,提高了在图像分类、语义分割和目标检测等视觉任务中的性能和鲁棒性。论文通过广泛的实验验证了WTConv的有效性,并展示了其在不同视觉任务中的应用潜力。

论文地址: https://arxiv.org/abs/2407.05848

摘要

论文指出,近年来尝试通过增加卷积核的大小来模仿视觉变换器(Vision Transformers, ViTs)自注意力模块的全局感受野,但这种方法很快遇到了上限,并且在达到全局感受野之前就饱和了。作者展示了通过利用小波变换(WT),实际上可以不遭受过度参数化的问题,获得非常大的感受野。例如,对于一个k×k的感受野,所提出方法中可训练参数的数量仅以k的对数级增长。提出的层名为WTConv,可以作为现有架构中的替代品,有效响应多频率,并随着感受野大小的增加而优雅地扩展。通过在ConvNeXt和MobileNetV2架构中展示WTConv层的有效性,以及作为下游任务的骨干网络,并展示了它带来的额外属性,如对图像损坏的鲁棒性增加以及对形状而非纹理的响应增加。

引言

引言指出了卷积神经网络(CNN)在计算机视觉领域的统治地位正受到视觉变换器(ViTs)的挑战,特别是由于ViTs的多头自注意力层能够实现全局特征混合。为了缩小CNN和ViTs之间的性能差距,研究人员尝试通过增大卷积核来增加感受野,但这种方法遇到了饱和问题。论文提出了一个问题:是否有可能在不增加过多参数的情况下,利用信号处理工具有效增加卷积的感受野,从而提高性能。

总结

论文成功地利用小波变换(WT)提出了WTConv层,这是一种新的CNN层,能够在不大幅增加参数的情况下显著增加感受野。WTConv层通过在小波域中进行卷积操作,实现了对输入数据的多频率响应,这使得网络能够更好地捕捉低频信息,从而提高了对形状的敏感性,并增强了网络的鲁棒性。实验结果表明,WTConv层在多个视觉任务中都取得了性能提升,证明了其有效性。

全文要点

WTConv

WTConv(Wavelet Transform Convolution)是一种基于小波变换的卷积层,它旨在为卷积神经网络(CNN)提供更大的感受野,同时避免因使用大卷积核而带来的参数数量急剧增加的问题。WTConv是一种创新的卷积神经网络层,它通过小波变换技术实现了对输入数据的深层次和多尺度分析。以下是WTConv的几个关键特点和工作原理的详细概括:

  1. 小波变换的应用:WTConv使用小波变换对输入信号进行分解,这允许网络在不同的频率和空间尺度上捕捉信息。小波变换提供了一种将信号分解为可提供时间和频率信息的组成部分的方法。

  2. 感受野的显著扩展:通过小波变换的多级分解,WTConv能够在保持参数数量相对较低的同时,实现对输入数据更大范围的覆盖。这意味着即使是小的卷积核也能够通过小波变换捕捉到更广泛的上下文信息。

  3. 参数效率与性能提升:WTConv的设计减少了模型参数的数量,与传统的大卷积核相比,它以参数数量的对数级增长实现了感受野的扩展。这种效率的提升使得WTConv在保持计算成本较低的同时,能够提高模型在图像分类、语义分割等任务上的性能。

  4. 多频率特征的独立处理:WTConv允许网络对分解出的不同频率特征进行独立的卷积处理,这增强了模型对信号中不同特征的响应能力,特别是对低频特征的捕捉,这对于理解图像中的形状和结构非常重要。

  5. 小波反变换的集成:在小波域中处理完信号后,WTConv利用小波反变换将处理后的信号重新组合,以生成最终的输出。这一步骤确保了信号的完整性,并允许网络在原始域中进行最终的特征整合。

WTConv通过这些设计,有效地结合了小波变换的多尺度分析能力和卷积神经网络的深度学习能力,为解决计算机视觉中的复杂问题提供了一种新的工具。

wt(Wavelet Transform)

小波变换(Wavelet Transform, WT)是一种数学变换,用于将信号分解成不同时间尺度上的成分,这些成分能够提供信号的时频信息。它广泛应用于信号处理、图像分析、数据压缩和其他许多领域。以下是小波变换的几个关键特点:

  1. 时频联合表示:与仅提供频率信息的傅里叶变换相比,小波变换能够同时提供信号的时间(或空间)和频率信息,使得它在分析非平稳信号时特别有用。

  2. 多尺度分析能力:小波变换通过在不同的尺度上分析信号,能够揭示信号在不同分辨率下的特性。这种多尺度分解使得小波变换能够适应信号的局部变化,捕捉到重要细节。

  3. 正交小波基:在某些小波变换中,如Haar小波变换,变换基是正交的,这允许无失真地从变换后的系数重构原始信号,保证了变换的逆过程的准确性。

  4. 稀疏性优势:小波变换通常能够产生稀疏的系数矩阵,其中许多系数为零或很小,这不仅有助于数据压缩,还可以在信号去噪和特征提取中发挥作用。

  5. 计算效率:小波变换可以通过快速算法实现,如快速小波变换(FWT),它减少了计算量,提高了处理速度。

小波变换的这些特性使其成为分析和处理信号的理想选择,特别是在需要同时考虑时间和频率信息的复杂场景中。

iwt

小波反变换(Inverse Wavelet Transform, IWT)是小波变换的逆过程,它用于从小波变换的系数中重构原始信号。以下是IWT的关键特点和工作原理:

  1. 信号重构:IWT的主要目的是将小波变换产生的系数转换回原始的信号或数据。这是通过使用小波变换时定义的相同小波函数来实现的,但是以相反的顺序。

  2. 逆过程:IWT是小波变换的逆操作,它利用了小波变换的正交性质,特别是当使用正交小波基时,可以确保信号的精确重构。

  3. 多尺度合成:在多级小波分解的情况下,IWT通过逐步合成不同尺度(或分辨率)上的细节信息来重构信号。这包括将低频和高频成分重新组合。

  4. 系数的整合:IWT通过整合小波变换产生的所有系数,包括近似系数(Approximation coefficients)和细节系数(Detail coefficients),来恢复原始数据。

  5. 计算流程:IWT的计算通常涉及从最粗糙的尺度开始,逐步向上细化至更高尺度的过程。每一步都涉及到将当前尺度的系数与小波函数相结合,以及将从更粗糙尺度上恢复的信息逐步添加进来。

  6. 稀疏性利用:如果小波变换产生了稀疏系数,IWT可以利用这一特性来减少计算量,因为许多接近零的系数可以被忽略或近似处理。

  7. 与WT的兼容性:IWT与小波变换紧密兼容,确保了变换和反变换过程的一致性,这对于保持信号的完整性至关重要。

小波反变换是小波分析中不可或缺的一部分,它确保了小波变换的实用性和有效性,特别是在需要从变换后的系数中恢复原始信号的场景中。

pytorch代码实现

源自:https://github.com/BGU-CS-VIL/WTConv

import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import pywt.data

from functools import partial

def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
    w = pywt.Wavelet(wave)
    dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
    dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
    dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)

    dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)

    rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
    rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
    rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)

    rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)

    return dec_filters, rec_filters

def wavelet_transform(x, filters):
    b, c, h, w = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
    x = x.reshape(b, c, 4, h // 2, w // 2)
    return x


def inverse_wavelet_transform(x, filters):
    b, c, _, h_half, w_half = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    x = x.reshape(b, c * 4, h_half, w_half)
    x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
    return x


class _ScaleModule(nn.Module):
    def __init__(self, dims, init_scale=1.0, init_bias=0):
        super(_ScaleModule, self).__init__()
        self.dims = dims
        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
        self.bias = None

    def forward(self, x):
        return torch.mul(self.weight, x)

class WTConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
        super(WTConv2d, self).__init__()

        assert in_channels == out_channels

        self.in_channels = in_channels
        self.wt_levels = wt_levels
        self.stride = stride
        self.dilation = 1

        self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)

        self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
        self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)

        self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
                                   groups=in_channels, bias=bias)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])

        self.wavelet_convs = nn.ModuleList(
            [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
                       groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
        )
        self.wavelet_scale = nn.ModuleList(
            [_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
        )

        if self.stride > 1:
            self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
            self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,
                                                   groups=in_channels)
        else:
            self.do_stride = None

    def forward(self, x):

        x_ll_in_levels = []
        x_h_in_levels = []
        shapes_in_levels = []

        curr_x_ll = x

        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                curr_x_ll = F.pad(curr_x_ll, curr_pads)

            curr_x = self.wt_function(curr_x_ll)
            curr_x_ll = curr_x[:, :, 0, :, :]

            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
            curr_x_tag = curr_x_tag.reshape(shape_x)

            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])

        next_x_ll = 0

        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop()
            curr_x_h = x_h_in_levels.pop()
            curr_shape = shapes_in_levels.pop()

            curr_x_ll = curr_x_ll + next_x_ll

            curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
            next_x_ll = self.iwt_function(curr_x)

            next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]

        x_tag = next_x_ll
        assert len(x_ll_in_levels) == 0

        x = self.base_scale(self.base_conv(x))
        x = x + x_tag

        if self.do_stride is not None:
            x = self.do_stride(x)

        return x


x = torch.randn((4, 64, 128, 128))
model = WTConv2d(in_channels=64, out_channels=64)
out = model(x)
print(out.shape)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2071453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

黑神话悟空不只是玩游戏 有人用它3天赚了85W

这几天你是不是在想办法升级电脑配置,买PS5玩黑神话悟空游戏,每一个男人看到那么好的游戏画面,都控制不住想玩,今天分享给大家一些资料,让你快速玩游戏的同时,还能挣点外快,黑神话悟空不只是玩游…

MATLAB 计算两点沿某个方向的间距(81)

MATLAB 计算两点沿某个方向的间距(81) 一、算法介绍二、算法实现1.代码2.效果一、算法介绍 上一章介绍了如何计算点到空间直线的距离,这里进一步的,我们也可以计算两个点,沿着某个方向的距离,这在很多处理中都会使用到,实际上就是将两点投影到该方向的直线,再计算间距…

线性表复习之初始化顺序表操作

线性表的顺序表示-初始化顺序表 代码 #include <stdio.h> #define MaxSize 10 // 定义最大长度typedef struct{int data[MaxSize]; // 申请空间&#xff08;静态&#xff09;int length; // 当前长度 }SqList;void InitList(SqList &L){for (int i 0; i < MaxS…

java-队列--黑马

队列 别看这个&#xff0c;没用&#xff0c;还是多刷力扣队列题 定义 队列是以顺序的方式维护一组数据的集合&#xff0c;在一端添加数据&#xff0c;从另一端移除数据。一般来讲&#xff0c;添加的一端称之尾&#xff0c;而移除一端称为头 。 队列接口定义 // 队列的接口定…

河南萌新联赛2024第(六)场:郑州大学

目录 A-装备二选一&#xff08;一&#xff09;_河南萌新联赛2024第&#xff08;六&#xff09;场&#xff1a;郑州大学 (nowcoder.com) 思路&#xff1a; 代码&#xff1a; B-百变吗喽_河南萌新联赛2024第&#xff08;六&#xff09;场&#xff1a;郑州大学 (nowcoder.com) …

3DsMax将两个模型的UV展到一个UV上面

3DsMax将两个模型的UV展到一个UV上面 3Dmax中的准备工作 创建一个方块&#xff0c;一个球体&#xff0c;模拟两个模型 添加修改器 打开UV编辑器&#xff0c;快速剥 使用缩放工具&#xff0c;缩放UV&#xff0c;放到一个位置 选择正方形&#xff1a;添加修改器&#xff0…

8.3 数据库基础技术-关系代数

并、交、差 笛卡尔积、投影、选择 自然连接 真题

宝塔面板配置node/npm/yarn/pm2....相关全局变量 npm/node/XXX: command not found

1.打开终端 , cd 到根目录 cd / 2.跳转至node目录下,我的node版本是v16.14.2 cd /www/server/nodejs/v16.14.2/bin 2.1 如果不知道自己node版本多少就跳转到 cd /www/server/nodejs 然后查找当前目录下的文件 ls 确定自己的node版本 cd /node版本/bin 3.继续查看bin…

天润融通助力呷哺呷哺:AI技术赋能3000万会员精细化运营

呷哺集团于1998年11月在北京成立&#xff0c;以“一人一锅”台式小火锅的用餐模式&#xff0c;以及其推出的多样化套餐与良好的用餐服务赢得了众多消费者的青睐&#xff0c;并迅速在市场上占据了一席之地。经过20多年的发展&#xff0c;呷哺呷哺已成为一个多品牌经营、全产业链…

基于Android的安全知识学习APP的设计与实现(论文+源码)_kaic

基于Android的安全知识学习APP的设计与实现 摘 要 随着科技的进步&#xff0c;智能手机已经成为人们工作、学习和生活的必需品。基于Android系统的强大功能&#xff0c;使用Java语言、Linux操作系统&#xff0c;搭配Android Studio&#xff0c;并配备Android开发插件&#…

Unet改进3:在不同位置添加NAMAttention注意力机制

本文内容:在不同位置添加NAMAttention注意力机制 目录 论文简介 1.步骤一 2.步骤二 3.步骤三 4.步骤四 论文简介 识别不太显著的特征是模型压缩的关键。然而,它在革命性的注意机制中尚未得到研究。在这项工作中,我们提出了一种新的基于归一化的注意力模块(NAM),它抑制…

广州自闭症学校哪家好?

在广州&#xff0c;选择一家适合自闭症儿童的康复学校是一个需要慎重考虑的决定。在众多机构中&#xff0c;星启帆自闭症儿童康复机构以其专业的师资团队、全面的康复服务以及温馨的学习环境脱颖而出&#xff0c;成为众多家长信赖的选择。 星启帆自闭症康复中心&#xff0c;作…

敦煌智旅:Serverless 初探,运维提效 60%

作者&#xff1a; 百潼 行业新趋势 在后疫情时代&#xff0c;文旅行业开始复苏&#xff0c;在行业的发展趋势中&#xff0c;我们看到了一个充满机遇和挑战的未来。通过不断创新和适应市场需求&#xff0c;文旅行业继续不断发展壮大&#xff0c;为消费者提供更加丰富多样的旅游…

UnQLite:多语言支持的嵌入式NoSQL数据库深入解析

文章目录 1. 引言2. Key/Value 存储接口2.1 关键函数2.2 使用示例2.3 高级操作&#xff1a;批量文件存储 3. 游标的使用4. UnQLite-Python使用示例4. UnQLite数据库引擎架构5.1 Key/Value存储层5.2 文档存储层5.3 可插拔的存储引擎5.4 事务管理器与分页模块5.5 虚拟文件系统 6.…

右值引用与左值引用

目录 1. 左值与右值2. 左值引用与右值引用 1. 左值与右值 2. 左值引用与右值引用

千益畅行,旅游卡,案例分享

旅游卡作为新旅游这个赛道&#xff0c;到处都是金矿&#xff0c;看你怎么去挖&#xff0c;商机无限。千益畅行旅游卡作为旅游卡源头&#xff0c;提供优质完善的服务&#xff0c;你只需要去铺卡搞钱&#xff0c;其他的售后交给我们&#xff01; #旅游卡服务#

使用静态IP为什么比动态IP的人多?

在网络世界中&#xff0c;IP地址就好比我们的身份证&#xff0c;用来标识我们在互联网上的唯一身份。而静态IP与动态IP&#xff0c;则是这“身份证”的两种不同分配方式。 一、静态IP与动态IP的区别 动态IP&#xff1a;动态IP地址如同租住的公寓&#xff0c;用户每次上网时&a…

【Qt】常见控件 —— QWidget(下)

文章目录 QWidget 的 windowlcon 属性使用 qrc文件管理资源qrc的使用方式在项目中创建 qrc文件把图片 导入到qrc 文件中 QWidget 的 windowOpacity属性 QWidget 的 windowlcon 属性 windowlcon 表示 一个窗口的图标 ( 只能针对 顶层窗口使用 ) windowlcon() 表示 获取到控件的…

吴恩达机器学习课后作业-05偏差与方差

偏差与方差 题目欠拟合改进欠拟合影响偏差和方差因素训练集拟合情况训练集和测试集代价函数选择最优lamda 整体代码 训练集:训练模型 验证集︰模型选择&#xff0c;模型的最终优化 测试集:利用训练好的模型测试其泛化能力 #训练集 x_train,y_train data[X],data[ y]#验证集 …

【C++ Primer Plus习题】4.9

问题: 解答: #include <iostream> using namespace std;typedef struct _CandyBar {string brand;float weight;int calorie; }CandyBar;int main() {CandyBar* snack new CandyBar[3];snack[0] {"德芙",2.1,20};snack[1] { "箭牌",2.2,16 };sna…