Octave Conv
Octave Convolution 代码详解_octconv代码_zghydx1924的博客-CSDN博客
def forward(self, x):
X_h, X_l = x
if self.stride ==2:
X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
X_h2l = self.h2g_pool(X_h)
# X_h2l指的是对输入进行下采样,下采样的方法时卷积核大小2×2,步长为2的平均池化。
end_h_x = int(self.in_channels*(1- self.alpha_in))
end_h_y = int(self.out_channels*(1- self.alpha_out))
# 假设输入的通道数为256,输出的通道数为512,alpha_in=alpha_out=0.75。那么end_h_x=64,end_h_y=128。
X_h2h = F.conv2d(X_h, self.weights[0:end_h_y, 0:end_h_x, :,:], self.bias[0:end_h_y], 1,
self.padding, self.dilation, self.groups)
X_l2l = F.conv2d(X_l, self.weights[end_h_y:, end_h_x:, :,:], self.bias[end_h_y:], 1,
self.padding, self.dilation, self.groups)
X_h2l = F.conv2d(X_h2l, self.weights[end_h_y:, 0: end_h_x, :,:], self.bias[end_h_y:], 1,
self.padding, self.dilation, self.groups)
X_l2h = F.conv2d(X_l, self.weights[0:end_h_y, end_h_x:, :,:], self.bias[0:end_h_y], 1,
self.padding, self.dilation, self.groups)
X_l2h = F.upsample(X_l2h, scale_factor=2, **self.up_kwargs)
#低频分量的分辨率为高频分量的一般,因此需要上采样后进行计算
X_h = X_h2h + X_l2h
X_l = X_l2l + X_h2l
return X_h, X_l