ResNet
目的
Resnet网络是为了解决深度网络中的退化问题,即网络层数越深时,在数据集上表现的性能却越差。
原理
ResNet的单元结构如下:
类似动态规划的选择性继承,同时会在训练过程中逐渐增大(/缩小)该单元中权重层的参数,主要取决于是否是直接继承前面块更优。
实现
对于ResNet50及以上来说,采用的单元块是Bottleneck模块。
在实现Bottleneck模块前,需要先对ResNet中使用到的卷积核进行简化定义。
-
首先是卷积核kernel_sizef分别为1和3的定义:
def conv1x1(in_channel, out_channel, stride=1): return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False) def conv3x3(in_channel, out_channel, stride=1): return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
-
接着定义Bottleneck单元模块:
这里有一个涉及到梯度是否能计算的问题,如果在bn3之后又进行了一次relu操作,然后再自加等,由于relu操作是限定为原地进行的,这就会导致在反向推导时无法计算出梯度,具体原因有待考究。
class Bottleneck(nn.Module): extension = 4 # Bottleneck only decrease the [h,w] in conv1 when stride > 1, # so the [h,w] is to be [(h-1)/stride+1,(w-1)/stride+1]. # the in_channel will be change to channel*extension. # channel is the temp variable. def __init__(self, in_channel, channel, stride, downsample=None): super(Bottleneck, self).__init__() self.conv1 = conv1x1(in_channel, channel, stride) self.bn1 = nn.BatchNorm2d(channel) self.conv2 = conv3x3(channel, channel) self.bn2 = nn.BatchNorm2d(channel) self.conv3 = conv1x1(channel, channel * self.extension) self.bn3 = nn.BatchNorm2d(channel * self.extension) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = self.conv3(x) x = self.bn3(x) if self.downsample is not None: identity = self.downsample(identity) x += identity x = self.relu(x) return x
-
最后是ResNet的主体,主体包含前向传播函数和构造集合体模块层函数:
ResNet总共可看作6层结构。
-
第一层为大卷积核层,主要是以大卷积核进行卷积,同时将通道数上升到64。[h,w]=[h/2,w/2]。
-
第二至五层是残差模块,其中残差模块由多层Bottleneck组成。多层Bottleneck的第一层的in_channel为上一个模块的out_channel,中间的in_channel则为多层Bottleneck的上一层out_channel,每个Bottleneck的plane为其in_channel的1/2。
第二层的stride为1,但是有maxpool来使得图片尺寸缩小,其他层则通过stride=2使得图片尺寸缩小。
-
第六层则是全连接层,使用torch.flatten进行缩维度处理。
class ResNet(nn.Module):
# size / 32 / 7
def __init__(self, block, layers, num_class):
super(ResNet, self).__init__()
# the first layer changes the channel to 64,
# and the [h,w] will be change to [(h-1)/stride+1,(w-1)/stride+1] after the first layer.
self.in_channel = 64
self.block = block
self.layers = layers
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_channel)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# there are four block layers, each layers contains more than one block.
self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
# in the end, there will be a linear layer to classify all the classes.
# self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.extension, num_class)
def make_layer(self, block, plane, block_num, stride=1):
block_list = []
downsample = None
# if the in_channel isn't equal to the out_channel,
# downsample will be needed to process the in_channel to same size as the out_channel
# so that the in_channel can be added to the out_channel to achieve the resnet struct.
if stride != 1 or self.in_channel != plane * block.extension:
downsample = nn.Sequential(
conv1x1(self.in_channel, plane * block.extension, stride),
nn.BatchNorm2d(plane * block.extension)
)
conv_block = block(self.in_channel, plane, stride, downsample=downsample)
# the first block's in_channel is different to the another block_num-1 in_channel.
block_list.append(conv_block)
# modify the in_channel for the next stage layer.
self.in_channel = plane * block.extension
for _ in range(1, block_num):
block_list.append(block(self.in_channel, plane, stride=1))
return nn.Sequential(*block_list)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
# x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
x = nn.Softmax(dim=1)(x)
return x