Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks
论文地址:
1. 论文解决的问题
2. 解决问题的方法
3. PConv 的适用范围
4. PConv 在目标检测中的应用
5. 评估方法
6. 潜在挑战
7. 未来研究方向
8.即插即用代码
论文地址:
2303.03667https://arxiv.org/pdf/2303.03667
1. 论文解决的问题
这篇论文主要解决的是神经网络的运行速度问题。尽管近年来神经网络的性能突飞猛进,但其高延迟和高计算量也限制了其在实际应用中的推广。为了解决这个问题,研究者们通常关注降低浮点运算次数 (FLOPs),但论文指出,单纯降低 FLOPs 并不一定能带来相应的延迟降低。
2. 解决问题的方法
论文分析了导致低延迟的主要原因,发现是运算符频繁的内存访问导致的。因此,论文提出了一个新的运算符——部分卷积 (PConv),它通过减少冗余计算和内存访问来更有效地提取空间特征。
PConv 的原理:
-
PConv 只对输入通道的一部分应用常规卷积,而其余通道则保持不变。
-
通过这种方式,PConv 在降低 FLOPs 的同时,也减少了内存访问次数,从而提高了运行速度。
-
为了充分利用所有通道的信息,PConv 通常与逐点卷积 (PWConv) 结合使用,形成一个 T 形的感受野,更专注于中心位置。
3. PConv 的适用范围
PConv 可以应用于各种需要提取空间特征的神经网络任务,例如:
-
图像分类: PConv 可以替代现有的卷积运算符,例如深度可分离卷积 (DWConv) 和分组卷积 (GConv),从而提高运行速度。
-
目标检测: PConv 可以用于特征提取网络,例如骨干网络,从而提高检测速度。
-
语义分割: PConv 可以用于特征提取网络,例如编码器,从而提高分割速度。
4. PConv 在目标检测中的应用
PConv 在目标检测中的应用位置:
-
骨干网络: PConv 可以用于替代骨干网络中的 DWConv 或 GConv,从而提高特征提取速度。
-
特征金字塔网络 (FPN): PConv 可以用于替代 FPN 中的 DWConv 或 GConv,从而提高多尺度特征提取速度。
-
注意力机制: PConv 可以用于改进注意力机制,例如 Squeeze-and-Excitation (SE) 块,从而提高注意力机制的效率。
PConv 在目标检测中的优势:
-
提高检测速度: PConv 可以降低目标检测的推理时间,从而提高检测速度。
-
提高检测精度: PConv 可以提取更丰富的特征,从而提高检测精度。
-
降低计算量: PConv 可以降低目标检测的计算量,从而降低对计算资源的需求。
5. 评估方法
为了评估 PConv 在目标检测中的应用效果,可以使用以下指标:
-
平均精度 (AP): 评估目标检测算法的精度。
-
平均精度均值 (mAP): 评估目标检测算法的平均精度。
-
推理时间: 评估目标检测算法的运行速度。
-
计算量: 评估目标检测算法的计算复杂度。
6. 潜在挑战
尽管 PConv 在目标检测中具有很大的潜力,但也存在一些潜在挑战:
-
参数调整: PConv 的性能可能受到参数设置的影响,例如部分比例和卷积核大小。
-
与现有模型的兼容性: PConv 需要与现有的目标检测模型进行整合,这可能需要进行一些修改。
-
训练时间: PConv 可能需要更长的训练时间才能达到最佳性能。
7. 未来研究方向
未来研究方向可以包括:
-
改进 PConv 的设计: 探索更有效的 PConv 设计,例如不同的部分比例和卷积核大小。
-
将 PConv 应用于其他目标检测模型: 将 PConv 应用于其他目标检测模型,例如 YOLO 和 SSD。
-
探索 PConv 在其他视觉任务中的应用: 探索 PConv 在其他视觉任务中的应用,例如图像检索和视频理解。
PConv 是一种很有潜力的运算符,可以用于提高目标检测的速度和精度。将 PConv 应用于目标检测模型,可以降低推理时间、提高检测精度,并降低对计算资源的需求。未来研究可以进一步探索 PConv 的设计、与其他模型的兼容性,以及在其他视觉任务中的应用。
8.即插即用代码
from torch import nn
import torch
class Partial_conv3(nn.Module):
def __init__(self, dim, n_div, forward):
super().__init__()
self.dim_conv3 = dim // n_div
self.dim_untouched = dim - self.dim_conv3
self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
if forward == 'slicing':
self.forward = self.forward_slicing
elif forward == 'split_cat':
self.forward = self.forward_split_cat
else:
raise NotImplementedError
def forward_slicing(self, x):
# only for inference
x = x.clone() # !!! Keep the original input intact for the residual connection later
x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
return x
def forward_split_cat(self, x):
# for training/inference
x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
x1 = self.partial_conv3(x1)
x = torch.cat((x1, x2), 1)
return x
if __name__ == '__main__':
block = Partial_conv3(64, 2, 'split_cat').cuda()
input = torch.rand(3, 64, 64, 64).cuda() #输入shape b c h w
output = block(input)
print(input.size(), output.size())
大家对于YOLO改进感兴趣的可以进群了解,群中有答疑,(QQ群:828370883)