视频去噪网络BSVD的实现

news2024/10/3 20:23:15

前些天写了视频去噪网络BSVD论文的理解,详情请点击这里,这两个星期动手实践了一下,本篇就来记录一下这个模型的实现。

这个网络的独特之处在于,它的训练和推理在实现上有所差别。在训练阶段,其使用了TSM(Time Shift Module)结构,而在推理时则使用了BBB(Bidirectional Buffer Block)结构。训练时,网络是一个MIMO(多输入多输出)形式,而在推理时,则将其设计成了单输入、单输出的流式形式。推理时,由于网络中存在16个双向buffer,即BBB,因此,前16帧会输出空数据,16帧之后开始正常输出去噪视频帧,到视频序列结束后,还会继续输出16帧的去噪视频帧,也就是,流式推理整体存在16帧的延迟。这在一些对实时性要求不太高的应用中可以推广,但对于实时性要求严格,并且存储资源有限的应用中,就无法有效应用了。

下面,我们就通过对官方代码的理解,来聊一聊BSVD的实现。

官方代码地址:GitHub - ChenyangQiQi/BSVD: [ACM MM 2022] Real-time Streaming Video Denoising with Bidirectional Buffers

BSVD网络采用了两个UNet级联的方式。

1. 训练阶段的网络实现

在训练阶段,网络的实现如下:

class WNet(nn.Module):
    def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False):
    # def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', blind=False):
        super(WNet, self).__init__()
        
        self.stage_num = stage_num
        self.nets_list = nn.ModuleList()
        for i in np.arange(stage_num):
            if i == 0:
                stage_in_ch = in_ch
            else:
                stage_in_ch = mid_ch
            if i == (stage_num-1):
                stage_out_ch = out_ch
            else:
                stage_out_ch = mid_ch
                
            # self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))
            
            if i == 0:
                self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch))
            else:
                self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch,
                                           in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))
        # self.temp2 = DenBlock(chns=chns, in_ch=mid_ch, shift_input=shift_input)

        # Init weights
        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def reset_params(self):
        for _, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x, debug=False):
        # if debug: x_in = x
        # x = self.temp1(x)
        for i in np.arange(self.stage_num):
            if debug: x_temp1 = x
            x = self.nets_list[i](x)
        # if debug: x_temp2 = x
        return x

网络由两个DenBlock组成,每个DenBlock是一个UNet结构:


class DenBlock(nn.Module):
    """ Definition of the denosing block of FastDVDnet.
    Inputs of constructor:
        num_input_frames: int. number of input frames
    Inputs of forward():
        xn: input frames of dim [N, C, H, W], (C=3 RGB)
        noise_map: array with noise map of dim [N, 1, H, W]
    """

    def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', interm_ch=30, blind=False):
    # def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', blind=False):
        super(DenBlock, self).__init__()
        self.chs_lyr0, self.chs_lyr1, self.chs_lyr2 = chns
        
        # if stage2: in_ch=3
        if shift_input:
            self.inc = CvBlock(in_ch=in_ch, out_ch=self.chs_lyr0, norm=norm, bias=bias, act=act)
        else:
            self.inc = InputCvBlock(
                num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, interm_ch=interm_ch, blind=blind)
                # num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, blind=blind)
        self.downc0 = DownBlock(in_ch=self.chs_lyr0, out_ch=self.chs_lyr1, norm=norm, bias=bias, act=act)
        self.downc1 = DownBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr2, norm=norm, bias=bias, act=act)
        self.upc2 = UpBlock(in_ch=self.chs_lyr2, out_ch=self.chs_lyr1, norm=norm, bias=bias,    act=act)
        self.upc1 = UpBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr0, norm=norm, bias=bias,    act=act)
        self.outc = OutputCvBlock(in_ch=self.chs_lyr0, out_ch=out_ch, norm=norm, bias=bias,     act=act)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def reset_params(self):
        for _, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, in1):
        '''Args:
            inX: Tensor, [N, C, H, W] in the [0., 1.] range
            noise_map: Tensor [N, 1, H, W] in the [0., 1.] range
        '''
        # Input convolution block
        x0 = self.inc(in1)
        # Downsampling
        x1 = self.downc0(x0)
        x2 = self.downc1(x1)
        # Upsampling
        x2 = self.upc2(x2)
        x1 = self.upc1(x1+x2)
        # Estimation
        x = self.outc(x0+x1)

        # Residual
        x[:, :3, :, :] = in1[:, :3, :, :] - x[:, :3, :, :]

        return x

这段代码与论文中的UNet结构相对应(见下图),包含一个输入层,两个下采样层,两个上采样层,一个输出层。

输入层没什么特别可说的,主要是两个Conv2d=>BN=>ReLU的组合;输出层也是常规实现,Con2d=>BN=>ReLU=>Con2d,需要注意的是,作者在实现过程中,BN层是没有使用的,是透传通过。

需要花心思理解的是下采样层和上采样层的实现,因为这两个模块在训练和推理过程中,是有所不同的。

两个模块的初始实现很简单,定义如下:

class DownBlock(nn.Module):
    '''Downscale + (Conv2d => BN => ReLU)*2'''

    def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):
        super(DownBlock, self).__init__()
        norm_fn = get_norm_function(norm)
        act_fn = get_act_function(act)
        self.convblock = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3,
                      padding=1, stride=2, bias=bias),
            norm_fn(out_ch),
            act_fn(inplace=True),
            CvBlock(out_ch, out_ch, norm=norm, bias=bias, act=act)
        )

    def forward(self, x):
        return self.convblock(x)


class UpBlock(nn.Module):
    '''(Conv2d => BN => ReLU)*2 + Upscale'''

    def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):
        super(UpBlock, self).__init__()
        # norm_fn = get_norm_function(norm)
        self.convblock = nn.Sequential(
            CvBlock(in_ch, in_ch, norm=norm, bias=bias, act=act),
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1, bias=bias),
            nn.PixelShuffle(2)
        )

        return self.convblock(x)

关键在于两者共同调用的子模块CvBlock的实现,在定义时,CvBlock被常规定义为:

class CvBlock(nn.Module):
    '''(Conv2d => BN => ReLU) x 2'''

    def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):
        super(CvBlock, self).__init__()
        norm_fn = get_norm_function(norm)
        act_fn = get_act_function(act)
        self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size=3,
                            padding=1, bias=bias)
        self.b1 = norm_fn(out_ch)
        self.relu1 = act_fn(inplace=True)
        self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size=3,
                            padding=1, bias=bias)
        self.b2 = norm_fn(out_ch)
        self.relu2 = act_fn(inplace=True)

    def forward(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.relu1(x)
        x = self.c2(x)
        x = self.b2(x)
        x = self.relu2(x)
        return x

但接下来,上述定义中的c1和c2则被替换成了TSM实现:

其中,shift模块的核心实现代码如下,对输入的channels分别向左和向右移动了一定单位(fold)。

def shift(x, n_segment, shift_type, fold_div=3, stride=1, inplace=False):
    nt, c, h, w = x.size()
    n_batch = nt // n_segment
    x = x.view(n_batch, n_segment, c, h, w)

    fold = c // fold_div # 32/8 = 4

    if inplace:
        # Due to some out of order error when performing parallel computing. 
        # May need to write a CUDA kernel.
        print("WARNING: use inplace shift. it has bugs")
        raise NotImplementedError  
        
    else:
        out = torch.zeros_like(x)
        if not 'toFutureOnly' in shift_type:
            out[:, :-stride, :fold] = x[:, stride:, :fold]  # backward (left shift)
            out[:, stride:, fold: 2 * fold] = x[:, :-stride, fold: 2 * fold]  # forward (right shift)
        else:
            out[:, stride:, : 2 * fold] = x[:, :-stride, : 2 * fold] # right shift only
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift

    return out.view(nt, c, h, w)

2. 推理阶段的网络实现

在推理阶段,网络实现就显得复杂一些了。大致的网络结构没变,但由于内部的TSM替换成了BBB, 因此没办法严格进行整体网络的加载,只能每一层单独加载训练出来的state_dict。并且,网络推理变成了流式推理,整个网络的定义显得比较凌乱,结构如下:

class BSVD(nn.Module):
    """
        Bidirection-buffer based framework with pipeline-style inference
    """
    def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False, 
                 pretrain_ckpt='./experiments/pretrained_ckpt/bsvd-64.pth'):
        super(BSVD, self).__init__()
        self.temp1 = DenBlock(chns=chns, out_ch=mid_ch, in_ch=in_ch,  shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)
        self.temp2 = DenBlock(chns=chns, out_ch=out_ch, in_ch=mid_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)

        self.shift_num = self.count_shift()
        # Init weights
        self.reset_params()
        if pretrain_ckpt is not None:
            self.load(pretrain_ckpt)
 
    def reset(self):
        self.temp1.reset()
        self.temp2.reset()
    def load(self, path):
        ckpt = torch.load(path)
        print("load from %s"%path)
        ckpt_state = ckpt['params']
        # split the dict here
        if 'module' in list(ckpt_state.keys())[0]:
            base_name = 'module.base_model.'
        else:
            base_name = 'base_model.'
        ckpt_state_1 = extract_dict(ckpt_state, string_name=base_name+'nets_list.0.')
        ckpt_state_2 = extract_dict(ckpt_state, string_name=base_name+'nets_list.1.')
        self.temp1.load_from(ckpt_state_1)
        self.temp2.load_from(ckpt_state_2)
            
    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def reset_params(self):
        for _, m in enumerate(self.modules()):
            self.weight_init(m)

    def feedin_one_element(self, x):
        x   = self.temp1(x)
        x   = self.temp2(x)
        return x
    
    def forward(self, input, noise_map=None):
        # N, F, C, H, W -> (N*F, C, H, W)
        if noise_map != None:
            input = torch.cat([input, noise_map], dim=2)
        N, F, C, H, W = input.shape
        input = input.reshape(N*F, C, H, W)
        base_out = self.streaming_forward(input)
        NF, C, H, W = base_out.shape
        base_out = base_out.reshape(N, F, C, H, W)
        return base_out
    
    def streaming_forward(self, input_seq):
        """
        pipeline-style inference

        Args:
            Noisy video stream

        Returns:
            Denoised video stream
        """
        out_seq = []
        if isinstance(input_seq, torch.Tensor):
            n,c,h,w = input_seq.shape
            input_seq = [input_seq[i:i+1, ...] for i in np.arange(n)]
        assert type(input_seq) == list, "convert the input into a sequence"
        _,c,h,w = input_seq[0].shape
        with torch.no_grad():
            for i, x in enumerate(input_seq):
 
                x_cuda = x.cuda()
                x_cuda = self.feedin_one_element(x_cuda)
                # if x_cuda is not None: x_cuda = x_cuda.cpu()
                if isinstance(x_cuda, torch.Tensor):
                    out_seq.append(x_cuda)
                else:
                    out_seq.append(x_cuda)

            end_out = self.feedin_one_element(None)

            out_seq.append(end_out)

            # end stage
            while 1:
                end_out = self.feedin_one_element(None)
                
                if len(out_seq) == (self.shift_num+len(input_seq)):
                    break

                out_seq.append(end_out)

            # number of temporal shift is 2, last element is 0
            # TODO fix init and end frames
            out_seq_clip = out_seq[self.shift_num:]
            self.reset()
            return torch.cat(out_seq_clip, dim=0)

    def count_shift(self):
        count = 0
        for name, module in self.named_modules():
            # print(type(module))
            if "BiBufferConv" in str(type(module)):
                count+=1
        return count

两个UNet的定义(DenBlock)大体上没发生变化,但下采样模块和上采样模块的定义发生了改变。

下采样层如下,原来带有TSM的CvBlock换成了MemCvBlock:

上采样模块也类似:

 

而MemCvBlock则调用了BBB模块,BBB模块的实现如下,这是整个算法的核心:

class BiBufferConv(nn.Module):
    def __init__(self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            bias=True
        ) -> None:
        super(BiBufferConv, self).__init__()
        self.op = ShiftConv(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias
        )
        self.out_channels = out_channels
        self.left_fold_2fold = None
        # self.zero_tensor = None
        self.center = None
        
    def reset(self):
        self.left_fold_2fold = None
        self.center = None
        
    def forward(self, input_right, verbose=False):
        fold_div = 8
        if input_right is not None:
            self.n, self.c, self.h, self.w = input_right.size()
            self.fold = self.c//fold_div
        # Case1: In the start or end stage, the memory is empty
        if self.center is None:
            self.center = input_right
            # if verbose:
            
            if input_right is not None:
                if self.left_fold_2fold is None:
                    # In the start stage, the memory and left tensor is empty

                    self.left_fold_2fold = torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda'))
                if verbose: print("%f+none+%f = none"%(torch.mean(self.left_fold_2fold), torch.mean(input_right)))
            else:
                # in the end stage, both feed in and memory are empty
                if verbose: print("%f+none+none = none"%(torch.mean(self.left_fold_2fold)))
                # print("self.center is None")
            return None
        # Case2: Center is not None, but input_right is None
        elif input_right is None:
            # In the last procesing stage, center is 0
            output =  self.op(self.left_fold_2fold, self.center, torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda')))
            if verbose: print("%f+%f+none = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(output)))
        else:
            
            output =  self.op(self.left_fold_2fold, self.center, input_right)
            if verbose: print("%f+%f+%f = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(input_right), torch.mean(output)))
            # if output == 57:
                # a = 1
        self.left_fold_2fold = self.center[:, self.fold:2*self.fold, :, :]
        self.center = input_right
        return output

这样,通过BBB模块,就实现了16个双向Buffer的填充、更新和清空。

限于篇幅,先梳理出个大体的思路,实际上还有很多细节需要特别关注,留待下一篇来写吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1125184.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中国象棋棋盘识别

当象棋爱好者在挑战中国象棋残局或者在阅读象棋杀法书籍的时候遇到问题,往往需要通过象棋软件来辅助提示,此时要将该棋局在象棋软件中摆好,软件才能进行分析,为实现自动识别棋局图片,并导出为标准化FEN象棋文件格式&am…

Lec08 Page faults笔记总结

当一个用户应用程序触发了page fault,page fault会使用与Robert教授上节课介绍的相同的trap机制,将程序运行切换到内核,同时也会将出错的地址存放在STVAL寄存器中。 在SCAUSE(注,Supervisor cause寄存器,保…

MyBatis-Plus实现逻辑删除[MyBatis-Plus系列] - 492篇

历史文章(文章累计490) 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 M…

[Spring] SpringBoot2 简介(二)—— 高级配置

目录 一、Conditional 注解 1、SpringBoot 如何获取 Bean 对象 2、SpringBoot 创建 Condition 类 3、切换内置 web 服务器 二、EnableXXX 注解 1、SpringBoot 不能直接获取其他 jar 包/工程中的 Bean 2、原因分析 3、封装 Import 4、Import 注解 5、SpringBoot 自动配…

前端领域的插件式设计

插件,是一个常见的概念。 例如,当我们需要把我们前端代码中的 css 样式提取打包,我们可以用 webpack 的 mini-css-extract-plugin,或者你如果用 rollup 的话,可以选择 rollup-plugin-postcss。 再比如我们可以给 bab…

Python 数组和列表:创建、访问、添加和删除数组元素

Python 没有内置支持数组,但可以使用 Python 列表来代替。 数组 本页将向您展示如何使用列表作为数组,但要在 Python 中使用数组,您需要导入一个库,比如 NumPy 库。数组用于在一个变量中存储多个值: 示例&#xff0…

VSCode 开发 Vue 语法提示

一. 打开应用商店,搜索 vetur ,选择第一个,点击安装。 二. 安装完成后,还可以下载 Vue Language Features 解决代码警告的问题。 最后重启 VSCode 就可以使用啦。另外输入 按回车键还可以自动生成 vue 代码格式哦。 原创作者&…

GPT-3 内幕机制可视化解析

GPT-3 内幕机制可视化解析 GPT-3是一个基于Transformer的语言模型,通过不同的层次提取语言不同层面的特性,构建整个语言的语义信息,它学习的过程跟人类正常学习的过程是类似的,开始的时候是一个无监督预训练,如图5-5所示,GPT-3模型可以将网络上的所有文档下载下来,包含 …

AD9371 官方例程HDL详解之JESD204B TX侧时钟生成 (三)

AD9371 系列快速入口 AD9371ZCU102 移植到 ZCU106 : AD9371 官方例程构建及单音信号收发 ad9371_tx_jesd -->util_ad9371_xcvr接口映射: AD9371 官方例程之 tx_jesd 与 xcvr接口映射 AD9371 官方例程 时钟间的关系与生成 : AD9371 官方…

Sui提供dApp Kit 助力快速构建React Apps和dApps

近日,Mysten Labs推出了dApp Kit,这是一个全新的解决方案,可用于在Sui上开发React应用程序和去中心化应用程序(dApps)。mysten/dapp-kit是专门为React定制的全新SDK,旨在简化诸如连接钱包、签署交易和从RPC…

Python生成词云

成品: 代码: import os# 下面的两个包大家注意别导错了 from imageio.v2 import imread from wordcloud import wordcloud# mytext文本是字符串类型的 mytext str() # os.getcwd()是获得当前目录的路径,好像没啥用 读取 with open(os.getcw…

35岁运维工程师到底该何去何从?

你是否经常在网上看到类似的帖子: “运维35岁被裁”、“35岁运维找不到工作”,这样的字眼频频出现在新闻中。如何度过35岁职场危机呢,不妨看看这篇文章,或许对你有启发! 一、35岁被称为运维半衰期,究竟为何…

性能测试:系统架构性能优化思路

今天谈下业务系统性能问题分析诊断和性能优化方面的内容。这篇文章重点还是谈已经上线的业务系统后续出现性能问题后的问题诊断和优化重点。 系统性能问题分析流程 我们首先来分析下如果一个业务系统上线前没有性能问题,而在上线后出现了比较严重的性能问题&#x…

【算法练习Day26】分发饼干摆动序列 最大子数组和

​📝个人主页:Sherry的成长之路 🏠学习社区:Sherry的成长之路(个人社区) 📖专栏链接:练题 🎯长路漫漫浩浩,万事皆有期待 文章目录 分发饼干摆动序列最大子数组…

金蝶云星空企业版v8.0内网穿透配置详解:实现便捷的异地远程访问

文章目录 前言1. 金蝶云星空企业版v8.0安装下载1.1 登录金蝶官网下载安装包1.2 常见的安装下载问题 2. 金蝶云星空配置SQL Sever数据库2.1 创建数据管理中心2.2 创建完成后在服务器登录管理站点 3. 下载安装注册cpolar3.1 公网访问测试 4. 固定连接公网地址 前言 金蝶云星空专注…

关于AES加密输出密文不为128位的倍数的原因

今天尝试用AES-256-OFB加密一个flag结果输出的密文是43字节,不是128位(16字节)的倍数,代码如下: import os from Crypto.Cipher import AES databflag{a7ba7128-3917-4551-8260-b3499e9dd7b12} aes AES.new(os.urand…

如何用Pytest做性能测试?5个步骤轻松学会!

Pytest其实也是可以做性能测试或者基准测试的。是非常方便的。 可以考虑使用Pytest-benchmark类库进行。 安装pytest-benchmark 首先,确保已经安装了pytest和pytest-benchmark插件。可以使用以下命令安装插件: pip install pytest pytest-benchmark …

Apollo生态系统探索:更多工具与框架的介绍

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 「推荐专栏」: ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄,vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄&#xff…

选择合适的软件管理视频制作排期

如果你是一名专业的视频创作者,那么你一定知道一个清晰、高效的项目管理对于视频制作的重要性。那么如何使用Zoho Projects项目管理软件来管理的视频制作项目,以便更好地规划和执行每一个细节呢? 这款项目管理软件具有丰富的自定义字段功能&a…