本文代码主要来自于OpenMMLab提供的MMEditing开源工具箱中的BasicVSR++代码。第一部分的解读主要是针对每一个部分是在做什么提供一个解释,便于后续细读每一个块的细节代码。
(1)导入库
basicvsr_plusplus_net中主要继承了torch,mmcv,mmengine,mmedit这几个三方的库,还继承了上一作模型basicvsr,主要是继承了其中的特征提取和光流这两个部分的内容。
(2)初始化
使用def __init__
和super().__init__
(调用父类方法)去进行了初始化,主要是对一些训练参数,是否采用预训练光流和对后面各个模块进行了一定的初始化。
(3)镜像扩展序列检测模块def check_if_mirror_extended(self, lqs)
这一块主要是检测输入的数据lqs(low quality sequence)低质量序列是否是镜像扩展的。我们知道对于一个图片,把他变成张量主要有4个参数(n,c,h,w)主要就是一个bitch的图片数,通道数,长,宽。但是在该模型中输入的数据是(n,t,c,h,w)多了一个t维的数据,在检测中也是依据这个维度进行检测的(check if the sequence is augmented by flipping,检查序列是否通过翻转增广)。
关于检测输入是否为镜像扩展,最主要的作用就是为后续计算光流做准备,因为图像是从前往后,还是从后往前的输入,它的下一帧的光流是不一样的,而在basicvsr++中既要从前往后,也要从后往前,所以这一步是必要的,
(4)光流计算模块def compute_flow(self, lqs):
这里会根据预训练光流模型,根据视频帧流是正向还是反向决定返回的光流是正向光流还是反向光流。
(5)传播模块def propagate(self, feats, flows, module_name):
根据Basicvsr++模型,主要将步骤分为了传播、对齐、聚合、上采样这4步。而在代码中传播、对齐、聚合这三步被放在了一个模块中(该模块)进行训练。根据Basicvsr++模型
我们可以直观的看到他有蓝色,绿色,它们分别是指的跨一帧,跨两帧的前向或后向特征信息,所以我们在写代码的时候要区分好每一步的名字。故在输入参数里面就包含了特征,光流,模型名字这三个参数。另外这个模块的返回值也是feats也就是新的特征,他是计算完成一步后,传送到下一步的特征,也就是此前所有步骤融合后的特征。此模块中包含二阶可变形对齐和光流引导可变形卷积这两个非常重要的创新点。最后也是进行了聚合这个步骤。
(6)上采样模块def upsample(self, lqs, feats):
这里就是根据模型论文写的一样,进行了上采样,输入的低质量序列的shape为(n, t, c, h, w),最后输出的HR序列的shape为(n, t, c, 4h, 4w)。
(7)前向传播模块def forward(self, lqs):
这里就是传统神经网络所做的前向传播模块。
(8)两阶段对齐模块class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
传播模块中会调用某个参数,那个参数在初始化的时候是根据这个模块产生的,总之可以理解为这个模块定义了什么是两阶段对齐,并在传播模块中被使用到了。