论文链接:
https://arxiv.org/abs/2307.08388
代码链接:
https://github.com/YaoleiQi/DSCNet
下面直接上代码。源码中也给出了测试用例,DSConv 是一个即插即用的模块。
import os
import torch
import numpy as np
from torch import nn
import warnings
warnings.filterwarnings("ignore")
"""
This code is mainly the deformation process of our DSConv
"""
class DSConv(nn.Module):
    """Dynamic Snake Convolution layer for 2D feature maps.

    A small 3x3 convolution predicts per-pixel offsets, the input is
    resampled along a "snake" of ``kernel_size`` points by the ``DSC``
    helper, and a strided 1D-shaped convolution then collapses the
    resampled points back to the input's spatial size. The result goes
    through GroupNorm and ReLU.
    """

    def __init__(self, in_ch, out_ch, kernel_size, extend_scope, morph,
                 if_offset, device=None):
        """
        The Dynamic Snake Convolution
        :param in_ch: input channel
        :param out_ch: output channel
        :param kernel_size: the size of kernel (number of sampled points;
            the sampling grid is only integer-spaced for odd sizes)
        :param extend_scope: the range to expand (default 1 for this method)
        :param morph: the morphology of the convolution kernel is mainly divided into two types
            along the x-axis (0) and the y-axis (any other value) (see the paper for details)
        :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
        :param device: kept for backward compatibility and stored on the
            module; ``forward`` builds its coordinate maps on the device of
            the incoming tensor, so this may be omitted
        """
        super(DSConv, self).__init__()
        # use the <offset_conv> to learn the deformable offset
        # (2 * kernel_size channels: one offset per kernel point and axis)
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size

        # two types of the DSConv (along x-axis and y-axis); both are always
        # registered so that state_dict keys stay stable across morph values
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )

        self.gn = nn.GroupNorm(self._num_groups(out_ch), out_ch)
        self.relu = nn.ReLU(inplace=True)

        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset
        self.device = device

    @staticmethod
    def _num_groups(out_ch):
        """Return the largest divisor of ``out_ch`` not exceeding ``out_ch // 4``.

        Equals ``out_ch // 4`` whenever that value is a valid group count,
        and falls back to a smaller divisor (at least 1) instead of letting
        ``nn.GroupNorm`` raise for small or non-divisible channel counts.
        """
        groups = max(out_ch // 4, 1)
        while out_ch % groups:
            groups -= 1
        return groups

    def forward(self, f):
        """Apply the snake convolution to ``f`` (indexed as a 4-D tensor)."""
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        # We need a range of deformation between -1 and 1 to mimic the snake's swing
        offset = torch.tanh(offset)

        # Build the sampler on the input's own device so the layer keeps
        # working after ``.to(...)`` / ``.cuda()`` is called on the module.
        dsc = DSC(f.shape, self.kernel_size, self.extend_scope, self.morph,
                  f.device)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)

        # morph == 0 samples are laid out along dim 2, otherwise along dim 3
        conv = self.dsc_conv_x if self.morph == 0 else self.dsc_conv_y
        x = conv(deformed_feature)
        x = self.gn(x)
        x = self.relu(x)
        return x
# Core code, for ease of understanding, we mark the dimensions of input and output next to the code
class DSC(object):
    """Sampler that deforms a 2D feature map along a snake-shaped kernel.

    Naming follows the original code: ``width`` is the size of dim 2 and
    ``height`` is the size of dim 3 of the ``[B, C, W, H]`` input, and the
    "y" coordinate indexes dim 2 while "x" indexes dim 3.

    B: Batch size  C: Channel  W: Width  H: Height  K: Kernel size
    """

    def __init__(self, input_shape, kernel_size, extend_scope, morph, device):
        """
        :param input_shape: shape of the feature map, indexed as [B, C, W, H]
        :param kernel_size: number of sampled points K per output position
        :param extend_scope: multiplier applied to the accumulated offsets
        :param morph: 0 spreads the kernel along x (dim 3); any other value
            spreads it along y (dim 2)
        :param device: device on which coordinate maps and outputs are placed
        """
        self.num_points = kernel_size
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph
        self.device = device
        self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope

        # define feature map shape
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    @staticmethod
    def _cumulative_offset(offset):
        """Accumulate per-point offsets outward from the kernel centre.

        The centre point gets offset 0; every other point gets the sum of
        the raw offsets between it and the centre (inclusive of itself), so
        the kernel bends like a chain ("offset is an iterative process").
        This covers the outermost points too.

        :param offset: tensor [B, K, W, H] of raw offsets
        :return: tensor [B, K, W, H] of accumulated offsets
        """
        center = offset.shape[1] // 2
        # Left arm: accumulate from the centre outwards, hence flip/cumsum/flip.
        left = torch.cumsum(offset[:, :center].flip(1), dim=1).flip(1)
        right = torch.cumsum(offset[:, center + 1:], dim=1)
        pivot = torch.zeros_like(offset[:, center:center + 1])
        return torch.cat([left, pivot, right], dim=1)

    def _coordinate_map_3D(self, offset, if_offset):
        """Build the sampling coordinates for every kernel point.

        input: offset [B,2*K,W,H]  (first K channels are <y_offset>, last K
               channels are <x_offset>)
        output when morph == 0: y and x maps, each [B, K*W, H]
        output otherwise:       y and x maps, each [B, W, K*H]

        :param offset: learned offsets, expected in (-1, 1)
        :param if_offset: when False the straight (undeformed) kernel is used
        :return: tuple ``(y_new, x_new)`` of float coordinate maps
        """
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

        num_points = int(self.num_points)
        half = int(self.num_points // 2)
        full_shape = (self.num_batch, num_points, self.width, self.height)

        # Centre coordinates: y varies along dim 2, x varies along dim 3.
        y_center = torch.arange(0, self.width, device=self.device).float()
        y_center = y_center.reshape(1, 1, self.width, 1)
        x_center = torch.arange(0, self.height, device=self.device).float()
        x_center = x_center.reshape(1, 1, 1, self.height)

        # Straight kernel: -K//2 ~ K//2 along the morph axis, 0 on the other.
        spread = torch.linspace(-half, half, num_points, device=self.device)
        spread = spread.reshape(1, num_points, 1, 1)
        zeros = torch.zeros(full_shape, device=self.device)

        if self.morph == 0:
            # kernel spreads along x; the learned swing is applied to y
            y_new = y_center + zeros
            x_new = x_center + spread + zeros
            if if_offset:
                swing = self._cumulative_offset(y_offset).to(self.device)
                y_new = y_new.add(swing.mul(self.extend_scope))
            # [B, K, W, H] -> [B, W, K, H] -> [B, K*W, H]
            out_shape = [
                self.num_batch, self.num_points * self.width, self.height
            ]
            y_new = y_new.permute(0, 2, 1, 3).reshape(out_shape)
            x_new = x_new.permute(0, 2, 1, 3).reshape(out_shape)
            return y_new, x_new

        # kernel spreads along y; the learned swing is applied to x
        y_new = y_center + spread + zeros
        x_new = x_center + zeros
        if if_offset:
            swing = self._cumulative_offset(x_offset).to(self.device)
            x_new = x_new.add(swing.mul(self.extend_scope))
        # [B, K, W, H] -> [B, W, H, K] -> [B, W, K*H]
        out_shape = [
            self.num_batch, self.width, self.num_points * self.height
        ]
        y_new = y_new.permute(0, 2, 3, 1).reshape(out_shape)
        x_new = x_new.permute(0, 2, 3, 1).reshape(out_shape)
        return y_new, x_new

    def _bilinear_interpolate_3D(self, input_feature, y, x):
        """Bilinearly sample ``input_feature`` at the given coordinates.

        input: input feature map [B,C,W,H]; coordinate maps as returned by
               ``_coordinate_map_3D``
        output: deformed feature map, [B,C,K*W,H] when morph == 0 and
                [B,C,W,K*H] otherwise

        Samples that fall entirely outside the map contribute zero because
        their interpolation weights cancel or vanish.
        """
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()

        max_y = self.width - 1
        max_x = self.height - 1

        # find the 4 neighbouring grid locations
        y0 = torch.floor(y).long()
        y1 = y0 + 1
        x0 = torch.floor(x).long()
        x1 = x0 + 1

        # clip gather indices so they stay inside the feature map
        y0_idx = torch.clamp(y0, 0, max_y)
        y1_idx = torch.clamp(y1, 0, max_y)
        x0_idx = torch.clamp(x0, 0, max_x)
        x1_idx = torch.clamp(x1, 0, max_x)

        # [B, C, W, H] -> [B*W*H, C] so one row index selects all channels
        flat = input_feature.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        flat = flat.permute(0, 2, 3, 1).reshape(-1, self.num_channels)

        # Row offset of each sample's own image inside ``flat``. Kept as
        # int64 and actually added to the index, so every batch element
        # reads from its own feature map.
        pixels_per_image = self.height * self.width
        samples_per_image = self.num_points * pixels_per_image
        base = torch.arange(self.num_batch, device=y.device) * pixels_per_image
        base = base.repeat_interleave(samples_per_image)

        def gather(rows, cols):
            index = base + rows * self.height + cols
            return flat[index.to(flat.device)].to(self.device)

        # get the 4 grid values
        value_a0 = gather(y0_idx, x0_idx)
        value_c0 = gather(y0_idx, x1_idx)
        value_a1 = gather(y1_idx, x0_idx)
        value_c1 = gather(y1_idx, x1_idx)

        # For the weights the clip is one step wider than for the indices;
        # this is what makes out-of-range samples evaluate to zero.
        y0_float = torch.clamp(y0, 0, max_y + 1).float()
        y1_float = torch.clamp(y1, 0, max_y + 1).float()
        x0_float = torch.clamp(x0, 0, max_x + 1).float()
        x1_float = torch.clamp(x1, 0, max_x + 1).float()

        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(self.device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(self.device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(self.device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(self.device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
        # channels-last -> channels-first
        outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        """Resample ``input`` along the (optionally deformed) snake kernel.

        :param input: feature map [B,C,W,H]
        :param offset: offsets [B,2*K,W,H]
        :param if_offset: whether to apply the learned deformation
        :return: deformed feature map (see ``_bilinear_interpolate_3D``)
        """
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature
# Code for testing the DSConv
if __name__ == '__main__':
    # Smoke test: push one random batch through a DSConv layer and print
    # the output shape.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random float32 input indexed as [batch=4, channels=5, 6, 7].
    A = np.random.rand(4, 5, 6, 7)
    A = A.astype(dtype=np.float32)
    A = torch.from_numpy(A)

    conv0 = DSConv(
        in_ch=5,
        out_ch=10,
        kernel_size=15,
        extend_scope=1,
        morph=0,
        if_offset=True,
        device=device)

    # Moving to ``device`` is a no-op when it is the CPU.
    A = A.to(device)
    conv0 = conv0.to(device)

    out = conv0(A)
    print(out.shape)
谢谢小伙伴们的支持!对了,如果希望我出一期把该模块改进并应用到不同模型中的教程,可以在评论区留言哦。