文章目录
- 一、SPP模块
- 二、使用pytorch实现
一、SPP模块
SPP模块是指定空间特征金字塔模块,是由何凯明在2014年的论文中所提出的。
论文地址如下:
论文地址
该模块的主要作用是:在分类网络中,通过分类器之后,与全连接层连接时,全连接层的形状是固定的,所以必须将输入网络的图片resize成224224,否则当数据传输到全连接层时,权重不匹配,会发生错误。而将所以图片都resize成224224,可能会使图片失真等等。因此SPP模块提出的作用就是,将分类器最后的pooling层变成SPP模块,这样网络可以接受任意尺寸的输入,不需要将输入图片resize成224*224。
SPP模块的结构如下:
主要理解论文中这句话即可:
These spatial bins have sizes proportional to the image size, so the number of bins is fixed regardless of the image size.
翻译:这些空间箱的大小与图像大小成正比,因此无论图像大小如何,箱的数量都是固定的。
池化窗口的大小和步长都是跟随输入的h和w所变化的,所导致的结果就是,池化之后的h和w一定是4乘以4,2乘以2,1乘以1。
箱指的是小网格
整个的流程如下:
1: 经过分类器的feature map 的尺寸是channel h w(忽略batch)
2: 首先经过第一个最大池化,得到的结果是44大小的,然后经过第二个最大池化,得到的结果是22大小的, 然后经过第三个最大池化,得到的结果是1*1大小的。然后将其展平,拼接起来,就会得到21列的向量。
不论输入的图像尺寸为多少,最后在全连接层之前feature map都会都会变成256 乘以 21大小,其中256是channel。
二、使用pytorch实现
实际上关键就是,动态的求解出池化窗口的k和s大小
class SPP(torch.nn.Module):
def __init__(self, input):
super(SPP, self).__init__()
self.pool_param = [(4, 4), (2, 2), (1, 1)]
# 假设h和w相等,不相等的情况,h和w单独处理即可
h = input.shape[2]
w = input.shape[3]
s1 = h // self.pool_param[0][0]
k1 = h - s1 * (self.pool_param[0][0] - 1)
self.pool_4_4 = torch.nn.MaxPool2d(kernel_size=(k1, k1), stride=(s1, s1))
s2 = h // self.pool_param[1][0]
k2 = h - s2 * (self.pool_param[1][0] - 1)
self.pool_2_2 = torch.nn.MaxPool2d(kernel_size=(k2, k2), stride=(s2, s2))
s3 = h // self.pool_param[2][0]
k3 = h - s2 * (self.pool_param[2][0] - 1)
self.pool_1_1 = torch.nn.MaxPool2d(kernel_size=(k3, k3), stride=(s3, s3))
def forward(self, x):
x1 = self.pool_4_4(x)
x1 = torch.flatten(x1, start_dim=-2, end_dim=-1)
x2 = self.pool_2_2(x)
x2 = torch.flatten(x2, start_dim=-2, end_dim=-1)
x3 = self.pool_1_1(x)
x3 = torch.flatten(x3, start_dim=-2, end_dim=-1)
x = torch.cat((x1, x2, x3), dim=-1)
return x
if __name__ == "__main__":
vgg_model = vgg16_bn(weights = VGG16_BN_Weights.DEFAULT)
# print(vgg_model)
# print(list(vgg_model.features.children()))
test = torch.rand(8, 3, 16, 16)
model = SPP(test)
output = model(test)
print(output.shape)
输入为(8,3,16,16),经过SPP模块之后,大小为
这里16,16刚好是4,2,1的整数倍,更换其他数字
输入为(8,3,15,15),经过SPP模块之后,大小为
可以再尝试其他数字,只要大于等于4,得到的尺寸都是统一的。