This post doesn't go through the paper in detail; it only looks at the network structure and the source-code implementation.
Let's look at the structure of ShuffleNetV2 first.
The image goes in with 3 channels and passes through conv1 and maxpool,
and stage2~4 then form the main body, each stacking one stride = 2 ShuffleV2Block followed by several stride = 1 blocks.
The ShuffleV2Block is shown below: the left one is the stride = 1 version, the right one the stride = 2 version.
For example, stage2 of the 1.0x model has in_channel 24 and out_channel 116.
Each block has 2 branches, so the channels have to be split between them, e.g. 58 for the left branch and 58 for the right; after the final Concat that gives 116.
With stride = 2 the number of channels doubles, with stride = 1 it stays the same.
You will see this in the code below.
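Before reading the code, here is a rough shape trace (a minimal sketch; it assumes a 224x224 input and the 1.0x channel table, neither of which is fixed by the code itself):

# Rough shape trace for the 1.0x model with an assumed 224x224 input.
# conv1, maxpool and the first (stride=2) block of each stage all halve the spatial size.
size, channels = 224, 3
for name, out_c in [("conv1", 24), ("maxpool", 24), ("stage2", 116), ("stage3", 232), ("stage4", 464)]:
    size //= 2
    channels = out_c
    print(name, "->", channels, "x", size, "x", size)
# conv1 -> 24 x 112 x 112
# maxpool -> 24 x 56 x 56
# stage2 -> 116 x 28 x 28
# stage3 -> 232 x 14 x 14
# stage4 -> 464 x 7 x 7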
class ShuffleNetV2(nn.Module):
def __init__(
self,
model_size="1.5x",
out_stages=(2, 3, 4),
with_last_conv=False,
kernal_size=3,
activation="ReLU",
pretrain=True,
):
super(ShuffleNetV2, self).__init__()
# out_stages can only be a subset of (2, 3, 4)
assert set(out_stages).issubset((2, 3, 4))
print("model size is ", model_size) #1.0x
self.stage_repeats = [4, 8, 4]
self.model_size = model_size
self.out_stages = out_stages
self.with_last_conv = with_last_conv #False
self.kernal_size = kernal_size
self.activation = activation #LeakyReLU
if model_size == "0.5x":
self._stage_out_channels = [24, 48, 96, 192, 1024]
elif model_size == "1.0x":
self._stage_out_channels = [24, 116, 232, 464, 1024]
elif model_size == "1.5x":
self._stage_out_channels = [24, 176, 352, 704, 1024]
elif model_size == "2.0x":
self._stage_out_channels = [24, 244, 488, 976, 2048]
else:
raise NotImplementedError
# building first layer
input_channels = 3
output_channels = self._stage_out_channels[0] #24
        # 3x3 conv, stride=2, 3 -> 24 channels, padding=1, then BN and the activation (LeakyReLU in this config)
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
act_layers(activation),
)
input_channels = output_channels #24
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]  # stage2~4 in the paper
        # zip packs the elements of several sequences together, stopping at the shortest one
for name, repeats, output_channels in zip(
stage_names, self.stage_repeats, self._stage_out_channels[1:]
):
            # per the table in the paper: the first block of each stage has stride=2, the remaining repeats use stride=1
seq = [
ShuffleV2Block(
input_channels, output_channels, 2, activation=activation
)
]
for i in range(repeats - 1):
seq.append(
ShuffleV2Block(
output_channels, output_channels, 1, activation=activation
)
)
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
if self.with_last_conv:
conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
act_layers(activation),
)
self.stage4.add_module("conv5", conv5)
self._initialize_weights(pretrain)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
output = []
for i in range(2, 5):
stage = getattr(self, "stage{}".format(i))
x = stage(x)
if i in self.out_stages:
output.append(x)
return tuple(output)
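As a sanity check, the backbone can be run on a dummy input (a hypothetical usage sketch; it assumes the repo's act_layers and _initialize_weights helpers are available, and uses pretrain=False so no weights are downloaded):

import torch

model = ShuffleNetV2(model_size="1.0x", out_stages=(2, 3, 4), pretrain=False)
model.eval()
feats = model(torch.randn(1, 3, 224, 224))
for stage_id, f in zip((2, 3, 4), feats):
    print("stage{}:".format(stage_id), tuple(f.shape))
# stage2: (1, 116, 28, 28)
# stage3: (1, 232, 14, 14)
# stage4: (1, 464, 7, 7)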
Next, the ShuffleV2Block itself:
class ShuffleV2Block(nn.Module):
def __init__(self, inp, oup, stride, activation="ReLU"):
super(ShuffleV2Block, self).__init__()
if not (1 <= stride <= 3):
raise ValueError("illegal stride value")
self.stride = stride
        branch_features = oup // 2  # each branch gets half of the output channels
assert (self.stride != 1) or (inp == branch_features << 1)
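        # branch1 is the extra path of the downsampling block: depthwise 3x3 (stride 2) + pointwise 1x1.
        # For stride=1 it stays an empty Sequential and that half of the input is passed through unchanged.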
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1
),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
)
else:
self.branch1 = nn.Sequential()
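        # branch2 is always built: 1x1 conv -> depthwise 3x3 -> 1x1 conv.
        # With stride=1 it only receives half of the input channels (branch_features); see forward() below.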
self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x):
if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)  # split into 2 chunks along dim=1 (the channel dim)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
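A quick check of the two stride variants (a hypothetical usage sketch; it assumes the repo's act_layers helper is importable):

import torch

blk_s2 = ShuffleV2Block(116, 232, stride=2)  # downsampling block: both branches are used, channels double
blk_s1 = ShuffleV2Block(232, 232, stride=1)  # basic block: input is split in half, channels stay the same
x = torch.randn(1, 116, 28, 28)
y = blk_s2(x)
z = blk_s1(y)
print(y.shape)  # torch.Size([1, 232, 14, 14])
print(z.shape)  # torch.Size([1, 232, 14, 14])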
Channel shuffle was already covered in an earlier post.
The channels are reshaped into a matrix, the matrix is transposed, and the result is flattened back.
def channel_shuffle(x, groups):
# type: (torch.Tensor, int) -> torch.Tensor
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
    # channels become groups x channels_per_group
x = x.view(batchsize, groups, channels_per_group, height, width)
    # transpose to channels_per_group x groups
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
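A tiny concrete example (6 channels, 2 groups) showing what the shuffle does to the channel order:

import torch

x = torch.arange(6).float().view(1, 6, 1, 1)  # channels labelled 0..5
y = channel_shuffle(x, 2)
print(y.view(-1).tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]

Channels from the two groups (i.e. the two concatenated branches) end up interleaved, which is exactly the cross-branch information exchange the block relies on after the Concat.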