论文复述
这篇论文介绍了一种创新的卷积神经网络层——WTConv,它通过小波变换技术显著扩展了CNN的感受野,同时保持了参数效率。WTConv层能够实现对输入数据的多频率响应,增强了模型对形状而非纹理的特征识别能力,提高了在图像分类、语义分割和目标检测等视觉任务中的性能和鲁棒性。论文通过广泛的实验验证了WTConv的有效性,并展示了其在不同视觉任务中的应用潜力。
论文地址: https://arxiv.org/abs/2407.05848
摘要
论文指出,近年来尝试通过增加卷积核的大小来模仿视觉变换器(Vision Transformers, ViTs)自注意力模块的全局感受野,但这种方法很快遇到了上限,并且在达到全局感受野之前就饱和了。作者展示了通过利用小波变换(WT),实际上可以不遭受过度参数化的问题,获得非常大的感受野。例如,对于一个k×k的感受野,所提出方法中可训练参数的数量仅以k的对数级增长。提出的层名为WTConv,可以作为现有架构中的替代品,有效响应多频率,并随着感受野大小的增加而优雅地扩展。通过在ConvNeXt和MobileNetV2架构中展示WTConv层的有效性,以及作为下游任务的骨干网络,并展示了它带来的额外属性,如对图像损坏的鲁棒性增加以及对形状而非纹理的响应增加。
引言
引言指出了卷积神经网络(CNN)在计算机视觉领域的统治地位正受到视觉变换器(ViTs)的挑战,特别是由于ViTs的多头自注意力层能够实现全局特征混合。为了缩小CNN和ViTs之间的性能差距,研究人员尝试通过增大卷积核来增加感受野,但这种方法遇到了饱和问题。论文提出了一个问题:是否有可能在不增加过多参数的情况下,利用信号处理工具有效增加卷积的感受野,从而提高性能。
总结
论文成功地利用小波变换(WT)提出了WTConv层,这是一种新的CNN层,能够在不大幅增加参数的情况下显著增加感受野。WTConv层通过在小波域中进行卷积操作,实现了对输入数据的多频率响应,这使得网络能够更好地捕捉低频信息,从而提高了对形状的敏感性,并增强了网络的鲁棒性。实验结果表明,WTConv层在多个视觉任务中都取得了性能提升,证明了其有效性。
全文要点
WTConv
WTConv(Wavelet Transform Convolution)是一种基于小波变换的卷积层,它旨在为卷积神经网络(CNN)提供更大的感受野,同时避免因使用大卷积核而带来的参数数量急剧增加的问题。WTConv是一种创新的卷积神经网络层,它通过小波变换技术实现了对输入数据的深层次和多尺度分析。以下是WTConv的几个关键特点和工作原理的详细概括:
-
小波变换的应用:WTConv使用小波变换对输入信号进行分解,这允许网络在不同的频率和空间尺度上捕捉信息。小波变换提供了一种将信号分解为可提供时间和频率信息的组成部分的方法。
-
感受野的显著扩展:通过小波变换的多级分解,WTConv能够在保持参数数量相对较低的同时,实现对输入数据更大范围的覆盖。这意味着即使是小的卷积核也能够通过小波变换捕捉到更广泛的上下文信息。
-
参数效率与性能提升:WTConv的设计减少了模型参数的数量,与传统的大卷积核相比,它以参数数量的对数级增长实现了感受野的扩展。这种效率的提升使得WTConv在保持计算成本较低的同时,能够提高模型在图像分类、语义分割等任务上的性能。
-
多频率特征的独立处理:WTConv允许网络对分解出的不同频率特征进行独立的卷积处理,这增强了模型对信号中不同特征的响应能力,特别是对低频特征的捕捉,这对于理解图像中的形状和结构非常重要。
-
小波反变换的集成:在小波域中处理完信号后,WTConv利用小波反变换将处理后的信号重新组合,以生成最终的输出。这一步骤确保了信号的完整性,并允许网络在原始域中进行最终的特征整合。
WTConv通过这些设计,有效地结合了小波变换的多尺度分析能力和卷积神经网络的深度学习能力,为解决计算机视觉中的复杂问题提供了一种新的工具。
wt(Wavelet Transform)
小波变换(Wavelet Transform, WT)是一种数学变换,用于将信号分解成不同时间尺度上的成分,这些成分能够提供信号的时频信息。它广泛应用于信号处理、图像分析、数据压缩和其他许多领域。以下是小波变换的几个关键特点:
-
时频联合表示:与仅提供频率信息的傅里叶变换相比,小波变换能够同时提供信号的时间(或空间)和频率信息,使得它在分析非平稳信号时特别有用。
-
多尺度分析能力:小波变换通过在不同的尺度上分析信号,能够揭示信号在不同分辨率下的特性。这种多尺度分解使得小波变换能够适应信号的局部变化,捕捉到重要细节。
-
正交小波基:在某些小波变换中,如Haar小波变换,变换基是正交的,这允许无失真地从变换后的系数重构原始信号,保证了变换的逆过程的准确性。
-
稀疏性优势:小波变换通常能够产生稀疏的系数矩阵,其中许多系数为零或很小,这不仅有助于数据压缩,还可以在信号去噪和特征提取中发挥作用。
-
计算效率:小波变换可以通过快速算法实现,如快速小波变换(FWT),它减少了计算量,提高了处理速度。
小波变换的这些特性使其成为分析和处理信号的理想选择,特别是在需要同时考虑时间和频率信息的复杂场景中。
iwt
小波反变换(Inverse Wavelet Transform, IWT)是小波变换的逆过程,它用于从小波变换的系数中重构原始信号。以下是IWT的关键特点和工作原理:
-
信号重构:IWT的主要目的是将小波变换产生的系数转换回原始的信号或数据。这是通过使用小波变换时定义的相同小波函数来实现的,但是以相反的顺序。
-
逆过程:IWT是小波变换的逆操作,它利用了小波变换的正交性质,特别是当使用正交小波基时,可以确保信号的精确重构。
-
多尺度合成:在多级小波分解的情况下,IWT通过逐步合成不同尺度(或分辨率)上的细节信息来重构信号。这包括将低频和高频成分重新组合。
-
系数的整合:IWT通过整合小波变换产生的所有系数,包括近似系数(Approximation coefficients)和细节系数(Detail coefficients),来恢复原始数据。
-
计算流程:IWT的计算通常涉及从最粗糙的尺度开始,逐步向上细化至更高尺度的过程。每一步都涉及到将当前尺度的系数与小波函数相结合,以及将从更粗糙尺度上恢复的信息逐步添加进来。
-
稀疏性利用:如果小波变换产生了稀疏系数,IWT可以利用这一特性来减少计算量,因为许多接近零的系数可以被忽略或近似处理。
-
与WT的兼容性:IWT与小波变换紧密兼容,确保了变换和反变换过程的一致性,这对于保持信号的完整性至关重要。
小波反变换是小波分析中不可或缺的一部分,它确保了小波变换的实用性和有效性,特别是在需要从变换后的系数中恢复原始信号的场景中。
pytorch代码实现
源自:https://github.com/BGU-CS-VIL/WTConv
import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import pywt.data
from functools import partial
def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
w = pywt.Wavelet(wave)
dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
return dec_filters, rec_filters
def wavelet_transform(x, filters):
b, c, h, w = x.shape
pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
x = x.reshape(b, c, 4, h // 2, w // 2)
return x
def inverse_wavelet_transform(x, filters):
b, c, _, h_half, w_half = x.shape
pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
x = x.reshape(b, c * 4, h_half, w_half)
x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
return x
class _ScaleModule(nn.Module):
def __init__(self, dims, init_scale=1.0, init_bias=0):
super(_ScaleModule, self).__init__()
self.dims = dims
self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
self.bias = None
def forward(self, x):
return torch.mul(self.weight, x)
class WTConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
super(WTConv2d, self).__init__()
assert in_channels == out_channels
self.in_channels = in_channels
self.wt_levels = wt_levels
self.stride = stride
self.dilation = 1
self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)
self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
groups=in_channels, bias=bias)
self.base_scale = _ScaleModule([1, in_channels, 1, 1])
self.wavelet_convs = nn.ModuleList(
[nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
)
self.wavelet_scale = nn.ModuleList(
[_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
)
if self.stride > 1:
self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,
groups=in_channels)
else:
self.do_stride = None
def forward(self, x):
x_ll_in_levels = []
x_h_in_levels = []
shapes_in_levels = []
curr_x_ll = x
for i in range(self.wt_levels):
curr_shape = curr_x_ll.shape
shapes_in_levels.append(curr_shape)
if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
curr_x_ll = F.pad(curr_x_ll, curr_pads)
curr_x = self.wt_function(curr_x_ll)
curr_x_ll = curr_x[:, :, 0, :, :]
shape_x = curr_x.shape
curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
curr_x_tag = curr_x_tag.reshape(shape_x)
x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
next_x_ll = 0
for i in range(self.wt_levels - 1, -1, -1):
curr_x_ll = x_ll_in_levels.pop()
curr_x_h = x_h_in_levels.pop()
curr_shape = shapes_in_levels.pop()
curr_x_ll = curr_x_ll + next_x_ll
curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
next_x_ll = self.iwt_function(curr_x)
next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
x_tag = next_x_ll
assert len(x_ll_in_levels) == 0
x = self.base_scale(self.base_conv(x))
x = x + x_tag
if self.do_stride is not None:
x = self.do_stride(x)
return x
x = torch.randn((4, 64, 128, 128))
model = WTConv2d(in_channels=64, out_channels=64)
out = model(x)
print(out.shape)