车道线检测模型 RESA
该模型的关键只有一个,就是 RESA 模块;把这个模块想清楚,其余部分就没什么难点了。下面看代码:
class RESA(nn.Module):
    """Recurrent shift-and-add feature aggregation module of the RESA lane detector.

    For each of four directions ('d', 'u', 'r', 'l') and each of ``self.iter``
    steps, the feature map is cyclically shifted along one spatial axis,
    passed through a bias-free convolution and a ReLU, scaled by
    ``self.alpha`` and added back onto the feature map.

    The hyper-parameters are hard-coded (the ``cfg`` argument is accepted but
    not read):

    * ``iter`` = 5 shift steps per direction
    * 64 input/output channels for every convolution
    * feature map height 96, width 160
    * ``alpha`` = 2.0
    * convolution kernel length 9

    Attributes created per step ``i`` in ``range(self.iter)``:
        conv_d{i}, conv_u{i}: ``nn.Conv2d`` with a (1, 9) kernel.
        conv_r{i}, conv_l{i}: ``nn.Conv2d`` with a (9, 1) kernel.
        idx_d{i}, idx_u{i}: row index tensors of length ``self.height``.
        idx_r{i}, idx_l{i}: column index tensors of length ``self.width``.
    """

    def __init__(self, cfg):
        """Build the per-direction convolutions and the cyclic shift indices.

        Args:
            cfg: Unused; kept so callers can continue to pass a config object.
        """
        super().__init__()
        self.iter = 5
        chan = 64
        self.height = 96
        self.width = 160
        self.alpha = 2.0
        conv_stride = 9
        pad = conv_stride // 2  # keeps the spatial size unchanged

        # (direction suffix, kernel size, padding). The order d, u, r, l is
        # the order in which the layers are created and registered.
        layouts = (
            ('d', (1, conv_stride), (0, pad)),
            ('u', (1, conv_stride), (0, pad)),
            ('r', (conv_stride, 1), (pad, 0)),
            ('l', (conv_stride, 1), (pad, 0)),
        )
        rows = torch.arange(self.height)
        cols = torch.arange(self.width)

        for i in range(self.iter):
            for direction, kernel, padding in layouts:
                conv = nn.Conv2d(
                    chan, chan, kernel,
                    padding=padding, groups=1, bias=False)
                setattr(self, 'conv_' + direction + str(i), conv)

            # The shift doubles every step: size // 2**iter up to size // 2.
            shift_h = self.height // 2 ** (self.iter - i)
            shift_w = self.width // 2 ** (self.iter - i)

            # The modulo makes the shift cyclic (indices wrap around).
            setattr(self, 'idx_d' + str(i), (rows + shift_h) % self.height)
            setattr(self, 'idx_u' + str(i), (rows - shift_h) % self.height)
            setattr(self, 'idx_r' + str(i), (cols + shift_w) % self.width)
            setattr(self, 'idx_l' + str(i), (cols - shift_w) % self.width)

    def forward(self, x):
        """Aggregate features along the four directions.

        Args:
            x: Feature map whose trailing dimensions are
                (64, ``self.height``, ``self.width``); the shift indices and
                convolutions are built for exactly that size.

        Returns:
            A new tensor of the same shape as ``x``. The input is cloned
            first, so the caller's tensor is not modified.
        """
        x = x.clone()

        # Vertical passes: gather rows via the index on the height axis.
        for direction in ('d', 'u'):
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        # Horizontal passes: gather columns via the index on the width axis.
        for direction in ('r', 'l'):
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x
上述代码中的超参数是我自己直接写死的(没有从 config 读取),这样便于阅读。这个模块的关键在于 x.add_ 到底是怎么加的:这里用到了一些索引,下面我们具体 debug 看一下。
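按 i = 0 … iter-1 循环(这里 iter=5,height=96,width=160),每一轮的偏移量为 height // 2**(iter - i)(行方向)和 width // 2**(iter - i)(列方向):
iter=0时:行偏移 3,列偏移 5
iter=1时:行偏移 6,列偏移 10
iter=2时:行偏移 12,列偏移 20
iter=3时:行偏移 24,列偏移 40
iter=4时:行偏移 48,列偏移 80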
看到这里大家应该就明白了:RESA 主要实现的是特征图的错位(循环移位)相加,并且按上面的顺序依次执行(先 d、u,再 r、l,偏移量逐轮翻倍)。这样就实现了文中所说的消息传递,效果比普通 CNN 更好。