CVPR2017|Deep Feature Flow for Video Recognition论文复现(pytorch版)

news2025/1/11 9:10:23

🏆引言:深度卷积神经网络在图像识别任务中取得了巨大的成功。然而,将最先进的图像识别网络转移到视频上并非易事,因为每帧评估速度太慢且负担不起。我们提出了一种快速准确的视频识别框架——深度特征流DFF。它只在稀疏关键帧上运行昂贵的卷积网络,并通过流场将其深度特征映射传播到其他帧。它实现了显著的加速,因为流计算相对较快。整个体系结构的端到端训练显著提高了识别精度。深度特征流是灵活和通用的。在目标检测和语义分割两个视频数据集上进行了验证。它极大地推进了视频识别任务的实践。

文章目录

模型架构

DFF的流程图

image-20221226165948305

feat网络相当于是pspnet的backbone,然后这个task网络就是pspnet用于预测的head

伪代码说明

image-20221226170248748

N f e a t N_{feat} Nfeat

<原文>:我们使用ResNet模型,具体来说,ResNet-50和ResNet-101模型在ImageNet预训练,最后的1000路分类层被丢弃。按照DeepLab进行语义分割,R-FCN进行目标检测的做法,将特征步幅从32缩小到16,以产生更密集的特征图。第一个block的conv5层,步距由2改为1,并且conv5中的所有3×3卷积核上应用空洞卷积,以保持视场(dilation=2)。对conv5追加一个随机初始化的3×3卷积,将特征信道维数降至1024,其中还应用了空洞卷积。生成的1024维特征映射是后续任务的中间特征矩阵 N f e a t N_{feat} Nfeat

feat网络实际上就是一个语义分割模型,作者采用了DeeplabV2,本人采用了pspnet-r101。并且语义分割模型pspnet101已在cityscape数据集上与训练好了。单独的pspnet-r101在cityscape验证集上mIOU=69

dff还有一个scale层,论文里面有说,不过我这里没有加进去,因为scale层容易干扰我后面的实验分析

输入:

gt = [batch_size,1,512,1024]
im_flow_list = [batch_size,3,2,512,1024]
im_seg_list = [batch_size,3,2,512,1024]

输出

pred.shape = [1,19,64,128]

N t a s k N_{task} Ntask

<原文>:在中间特征矩阵上应用随机初始化的1 × 1卷积层,得到(C+1)分图,其中C为类别数,1为背景类别。然后通过softmax层输出逐像素概率。因此,任务网络只有一个可学习的权重层。整体网络架构类似于DeepLab

task是一个分类器,作者采用了task=1*1conv+softmax,作者通过实验发现,有没有这个1*1conv效果差不多,使用0层基本上等同于使用1层,无论是精度还是速度。我们选择1层作为默认值,因为在特征传播之后会留下一些可调参数,这可能更通用。

image-20221226171055248

# net_task = 1*1 Conv + softmax
# 论文里面说有没有这个1x1conv没什么区别,多加一个conv可以为以后需要时调参数,或者说更常规
self.net_task = nn.Conv2d(num_classes, num_classes, kernel_size=1, stride=1, padding=0)

f l o w n e t flownet flownet

使用的是flownet,这里我直接搬用flownet的api。

实验

根据数据集中视频帧率的不同,评估时cityscape分割的关键帧时长l默认为5,ImageNet VID检测的关键帧时长l默认为10

视频语义分割的评价指标使用 mIoU,在计算 mIoU 时,设置了三个传播距 {1,5,9} 来 刻画传播精度, 其中 {1,5,9} 分别表示当前帧与关键帧距离为相邻帧、隔了 4 帧和隔 8 帧。

不过验证时我采用的传播距离固定是5

  • 训练时传播距离固定为5

  • 训练时传播距离在1-5之间随机采样

  • 在每个小批量中,随机抽取一对附近的视频帧 [ I k , I i ] , 0 ≤ i − k ≤ 9 [I_k, I_i] ,0≤i−k≤9 [Ik,Ii],0ik9

dis(传播距离)5random(1~5)random(0~9)
mIOU60.9860.9160.90

可以看到,三者实际上是差不多的。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmseg.models import build_segmentor
from mmcv.utils import Config

from pspnet import pspnet_res101
from flownet import FlowNets
from warpnet import warp


class DFF(nn.Module):
    def __init__(self, num_classes=19, weight_res101=None, weight_flownet=None):
        super(DFF, self).__init__()

        # reference branch选用pspnet_res50
        # TODO(12.26):预训练pspnet-r50模型
        self.net_feat = pspnet_res101()

        # net_task = 1*1 Conv + softmax
        # 论文里面说有没有这个1x1conv没什么区别,多加一个conv可以为以后需要时调参数,或者说更常规
        self.net_task = nn.Conv2d(num_classes, num_classes, kernel_size=1, stride=1, padding=0)

        # 光流场‘O()’选择FlowNets,预测的光流图
        self.flownet = FlowNets()

        # 用于传播关键帧到当前帧的可学习函数‘W()’,即将预测的光流图和关键帧的语义分割图进行融合
        self.warp = warp()

        # 权重初始化
        # TODO:将res101 -> res50
        self.weight_init(weight_res101, weight_flownet)

        # 交叉熵损失函数
        # FIXME:ignore_index=255?
        self.criterion_semantic = nn.CrossEntropyLoss(ignore_index=255)

    def weight_init(self, weight_res101, weight_flownet):
        if weight_res101 is not None:

            # 加载预训练权重
            weight = torch.load(weight_res101, map_location='cpu')
            weight = weight['state_dict']

            # 加载预训练权重
            self.net_feat.model.load_state_dict(weight, False)

            # 冻结backdone的参数,仅调整decode_head的参数
            self.net_feat.fix_backbone()

        if weight_flownet is not None:
            weight = torch.load(weight_flownet, map_location='cpu')
            self.flownet.load_state_dict(weight, True)

        # 为了使得网络中信息更好的流动,每一层输出的方差应该尽量相等
        nn.init.xavier_normal_(self.net_task.weight)
        self.net_task.bias.data.fill_(0)
        print('pretrained weight loaded')

    # ---------------------------------Input-----------------------------------------------
    # gt = [batch_size,1,512,1024]
    # im_flow_list = [batch_size,3,2,512,1024]
    # im_seg_list = [batch_size,3,2,512,1024]
    # -------------------------------------------------------------------------------------
    def forward(self, im_seg_list, im_flow_list, gt=None):

        # 输入的视频数据参数值,依次为 bastchsize, 通道, 关键帧间隔时间, 帧高度, 帧宽度
        n, c, t, h, w = im_seg_list.shape

        # 推理关键帧的语义结果
        pred = self.net_feat(im_seg_list[:, :, 0, :, :])
        # pred.shape = [1,19,64,128]

        # 双线性插值等比放大2倍
        pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        # pred.shape = [1,19.128.256]

        # 计算关键帧的光流传播:首先将关键帧和当前帧的tensor在通道处堆叠,然后传入flownet,从而根据关键帧和当前帧计算光流,
        flow = self.flownet(torch.cat([im_flow_list[:, :, -1, :, :], im_flow_list[:, :, 0, :, :]], dim=1))

        # 将关键帧的pred传入warp(),然后和当前帧的flow继续一个W()函数,输出pred
        pred_result = self.warp(pred, flow)

        # 将经过warp输出的pred放到task网络里面
        pred_result = self.net_task(pred_result)
        # 双线性插值放大4倍
        pred_result = F.interpolate(pred_result, scale_factor=4, mode='bilinear', align_corners=False)
        # pred_result.shape = [1,19,512,1024]

        if gt is not None:
            loss = self.criterion_semantic(pred_result, gt)

            # .unsqueeze(0) 表示,在第一个位置增加维度
            loss = loss.unsqueeze(0)
            return loss
        else:
            return pred_result

    def evaluate(self, im_seg_list, im_flow_list):
        out_list = []
        t = im_seg_list.shape[2]
        pred = self.net_feat(im_seg_list[:, :, 0, :, :])
        # pred.shape = [1,19,64,128]

        pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        # pred.shape = [1,19,128,256]

        # 将经过net_feat的关键帧,再经过net_task处理
        out = self.net_task(pred)

        # 长宽均放大4倍
        out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=False)
        # out.shape = [1,19,512,1024]

        # 输入:out.shape = torch.Size([1, 19, 512, 1024])
        out = torch.argmax(out, dim=1)
        # 输出:out.shape = torch.Size([1, 512, 1024])

        out_list.append(out)

        # FIXME:eval时也不需要for循环吗?

        # 当前帧和关键帧做一个光流估计
        flow = self.flownet(torch.cat([im_flow_list[:, :, -1, :, :], im_flow_list[:, :, 0, :, :]], dim=1))

        # 扔进‘W()’函数里
        pred_result = self.warp(pred, flow)

        # 对堆叠结果进行卷积
        pred_result = self.net_task(pred_result)
        pred_result = F.interpolate(pred_result, scale_factor=4, mode='bilinear', align_corners=False)

        # 取最大的可能性结果
        out = torch.argmax(pred_result, dim=1)
        out_list.append(out)

        return out_list

    def set_train(self):
        self.net_feat.eval()
        self.net_feat.model.decode_head.conv_seg.train()
        self.net_task.train()
        self.flownet.train()


if __name__ == '__main__':
    model = DFF(weight_res101=None, weight_flownet=None)
    model.cuda().eval()

    im_seg_list = torch.rand([1, 3, 5, 512, 1024]).cuda()
    im_flow_list = torch.rand([1, 3, 5, 512, 1024]).cuda()
    with torch.no_grad():
        out_list = model.evaluate(im_seg_list, im_flow_list)
        print(len(out_list), out_list[0].shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/118317.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据通信基础 - 调制技术

文章目录1 概述2 调制技术2.1 分类2.2 N 相调制3 网工软考真题1 概述 #mermaid-svg-ZTF6pPysJlmUes01 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-ZTF6pPysJlmUes01 .error-icon{fill:#552222;}#mermaid-svg-ZTF…

谷歌用量子处理器发现:光子能在混沌中保持稳健的束缚态

一圈超导量子比特可以容纳微波光子的“束缚态”&#xff0c;其中光子往往聚集在相邻的量子比特位点上。图片来源&#xff1a;Google Quantum AI 使用量子处理器&#xff0c;研究人员可以使微波光子具有异常的“粘性”。在诱使它们聚集成束缚态后&#xff0c;他们发现这些光子簇…

谷歌 Chrome 浏览器弹窗境外广告的解决方法

谷歌的 Chrome 浏览器是我非常喜欢的一款的浏览器&#xff0c;用了它之后就不想再用其它浏览器。可是不知道从什么时候开始&#xff0c;Chrome 浏览器居然时不时地在右下角弹出广告&#xff0c;仔细一看&#xff0c;还是境外的 VPN 广告&#xff0c;如下图。有弹出过几次了&…

如何通过创建 SSH key 来进行Git 代码管理

1.检查你的电脑是否已经有SSH Key&#xff1f; 运行如下命令查看&#xff1a; $ cd ~/.ssh $ ls如果存在id_rsa.pub或 id_dsa.pub 文件&#xff0c;说明你的电脑已经有 SSH Key &#xff0c;可以直接拿来用&#xff0c;如果没有的话需要创建。 2.创建SSH Key 配置全局的nam…

zookeeper入门篇

文章目录前言介绍安装与启动配置说明节点节点类型PERSISTENT&#xff08;持久化节点&#xff09;PERSISTENT_SEQUENTIAL&#xff08;持久化顺序节点&#xff09;EPHEMERAL&#xff08;临时节点&#xff09;EPHEMERAL_SEQUENTIAL&#xff08;临时顺序节点&#xff09;Container&…

用Java实现简单的图书管理系统(Java系列7)

目录 前言&#xff1a; 1.基础框架的搭建 1.1图书 1.1.1书 1.1.2书架 1.2用户 1.2.1抽象类 1.2.2普通用户 1.2.3管理员 1.3操作 1.3.1新增图书 1.3.2借阅图书 1.3.3删除图书 1.3.4退出图书 1.3.5查找图书 1.3.6归还图书 1.3.7显示图书 2.具体内容的实现 2.1Ma…

<flutter>跨平台开发新手入坑指南 dart dio pubspec.yaml json_annotation 打包 小坑指南

1.资源文件和依赖三方包&#xff08;pubspec.yaml&#xff09;&#xff1a; pubspec.yaml文件可以说是和安卓的gradle文件差不多&#xff0c;它用来描述版本号、sdk、依赖等的。 在资源导入方面同安卓不一样的是&#xff0c;flutter需要在pubspec.yaml中声名&#xff0c;不然…

【PCB专题】Allegro元件库路径设置方法

正常Layout拉线前,需要将原理图导出的网表导入到Allegro里,Allegro就会自动将元件导入。如果库路径没有设置或都软件找不到器件,将会非常的卡顿,并且报Completed with warnings/errors。如下图所示: 在弹出的错误报告View of file:netrev.lst中会提示很多器件找不到封装。…

js知识点

js有预解析阶段&#xff0c;变量声明提升只提升定义&#xff0c;不提升值 console.log(a);//undefined var a10; 基本数据类型 Number、String、Boolean、Undefined和Null 复杂数据类型 Object、Array、Function、RegExp、Date、Map、Set等 使用typeof运算符可以检测值或…

刷爆力扣之验证回文串 II

刷爆力扣之验证回文串 II HELLO&#xff0c;各位看官大大好&#xff0c;我是阿呆 &#x1f648;&#x1f648;&#x1f648; 今天阿呆继续记录下力扣刷题过程&#xff0c;收录在专栏算法中 &#x1f61c;&#x1f61c;&#x1f61c; 该专栏按照不同类别标签进行刷题&#xff…

第六章 作业【数据库原理】

第六章 作业【数据库原理】前言推荐第六章 作业第6章第1题&#xff08;简答题&#xff09;第6章第2题&#xff08;简答题&#xff09;第6章第3题&#xff08;设计题&#xff09;第6章第4题&#xff08;设计题&#xff09;最后前言 2022-12-27 16:05:55 以下内容源自数据库原理…

最大连续子序列的和问题(算法)

问题描述 给定一个有n&#xff08;n≥1&#xff09;个整数的序列&#xff0c;要求求出其中最大连续子序列的和。 蛮力法 暴力枚举 /*** 时间复杂度&#xff1a;O(n^3)* param arr 序列[数组]* param n 数组大小* return int */ int maxSubSum1(int arr[], int n) {int thi…

美团餐饮SaaS基于StarRocks构建商家数据中台的探索

作者&#xff1a;何启航&#xff0c;美团餐饮SaaS数据专家&#xff08;文章整理自作者在 StarRocks Summit Asia 2022 的分享&#xff09; 随着社会经济的发展&#xff0c;餐饮连锁商家越来越大&#xff0c;“万店时代”来临。对于美团餐饮 SaaS 来说&#xff0c;传统的 OLTP …

LeetCode 324 周赛

2506. 统计相似字符串对的数目 给你一个下标从 0 开始的字符串数组 words 。 如果两个字符串由相同的字符组成&#xff0c;则认为这两个字符串 相似 。 例如&#xff0c;"abca" 和 "cba" 相似&#xff0c;因为它们都由字符 a、b、c 组成。然而&#xff…

HQChart实战教程54-renko砖形K线图

HQChart实战教程54-renko砖形K线图 Renko砖形图效果图使用HQChart创建Renko初始化创建Renko配置参数说明ClassNameOption动态修改Renko配置参数完成demo代码Renko砖形图 Renko砖形图是仅测量价格变动的图表类型。 “ renko”一词源自日语单词“ renga”,意为“砖”。并非巧合…

day30【代码随想录】分割回文串、复原IP地址、子集

文章目录前言一、分割回文串&#xff08;力扣131&#xff09;二、复原IP地址&#xff08;力扣93&#xff09;三、子集&#xff08;力扣78&#xff09;总结前言 1、分割回文串 2、复原IP地址 3、子集 一、分割回文串&#xff08;力扣131&#xff09; 给你一个字符串 s&#xf…

前端开发:关于鉴权的使用总结

前言 前端开发过程中&#xff0c;关于鉴权&#xff08;权限的控制&#xff09;是非常重要的内容&#xff0c;尤其是前端和后端之间数据传递时候的请求鉴权校验。前端鉴权的本质就是控制前端视图层的显示和前端向后台所发送的请求&#xff0c;但是只有前端鉴权&#xff0c;没有后…

MyGDI+

文章目录[toc]界面设计Form窗口MenuStrip画笔其他选项界面美化整体框架设计DataStructureCPointPolylinePolygonSingletonGraphicFunctionForm事件处理成员变量事件处理总结界面设计 Form窗口 首先添加MenuStrip控件&#xff0c;随后在Form窗口属性界面根据个人爱好修改其图标…

请收下这份数字IC面试超强攻略!(内附大厂面试题目)

2022年马上就要结束了&#xff0c;想必今年有很多同学也已经感受到IC行业的门槛在不断提升&#xff0c;这一点尤其在面试的过程中感受明显。 前两年的时候&#xff0c;面试官有可能问一些比较简单的问题就能通过&#xff0c;今年可就没那么简单了&#xff0c;必须提前做好相关…