【光流估计】【深度学习】Windows11下FastFlowNet代码Pytorch官方实现与源码讲解

news2024/11/25 14:44:38

【光流估计】【深度学习】Windows11下FastFlowNet代码Pytorch官方实现与源码讲解

提示:最近开始在【光流估计】方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法。


文章目录

  • 【光流估计】【深度学习】Windows11下FastFlowNet代码Pytorch官方实现与源码讲解
  • 前言
  • FastFlowNet模型运行环境搭建
  • FastFlowNet模型运行
    • 模型权重
    • FastFlowNet测试
      • 对运行速度进行基准测试并计算模型参数
      • 预测光流的演示
  • 总结


前言

FastFlowNet是由上海交通大学的Kong, Lingtong等人在《FastFlowNet: A Lightweight Network for Fast Optical Flow Estimation【ICRA-2021】》【论文地址】一文中提出的针对移动设备优化的光流估计网络,通过头部增强池金字塔(HEPP)特征提取、中心密集膨胀相关(CDDC)层和高效的洗牌块解码器(SBD)实现高效计算。该网络在保持大搜索半径的同时降低了参数数量和计算成本,适合在资源有限的平台上运行。
在详细解析FastFlowNet网络之前,首要任务是搭建FastFlowNet【Pytorch-demo地址】所需的运行环境,并完成模型训练和测试工作,展开后续工作才有意义。


FastFlowNet模型运行环境搭建

注意需要安装Visual Studio,博主安装的是 Visual Studio 2019版本。
在win11环境下装anaconda环境,方便搭建专用于FastFlowNet模型的虚拟环境。

  • 查看主机支持的cuda版本(最高)

    # 打开cmd,执行下面的指令查看CUDA版本号
    nvidia-smi
    

  • 安装GPU版本的torch【官网】
    博主的cuda版本是12.2,博主选的11.8也没问题。

    其他cuda版本的torch在【以前版本】找对应的安装命令。

  • 博主安装环境参考

    # 创建虚拟环境
    conda create -n FastFlowNet python=3.10
    # 查看新环境是否安装成功
    conda env list
    # 激活环境
    activate FastFlowNet 
    # 分别安装pytorch和torchvision
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    # 安装 opencv
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python
    # 安装Correlation模块,进入工程的correlation_package路径下
    cd ./models/correlation_package/
    python setup.py build
    python setup.py install
    # 查看所有安装的包
    pip list
    conda list
    

    安装Correlation模块出现的问题解决方案:“ “getCurrentCUDAStream”: 不是 “at::Context” 的成员”

    解决方案: 找到models/correlation_package路径下的correlation_cuda.cc文件,修改里面内容。

    // 注释 #include <ATen/ATen.h> 
    // 更改为
    #include <ATen/cuda/CUDAContext.h> 
    // 注释 at::globalContext().getCurrentCUDAStream()
    // 更改为
    at::cuda::getCurrentCUDAStream()
    



FastFlowNet模型运行

模型权重

原作者预训练权重放置在当前工程目录下的checkpoints目录下

FastFlowNet测试

源代码未开源训练部分代码

对运行速度进行基准测试并计算模型参数

python benchmark.py

可能出现的问题:“RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method”

解决方式: 正在使用旧版本的 PyTorch autograd 系统,需要修改models/correlation_package/correlation.py内容的部分代码内容,使用新的方式定义自定义的 autograd 函数。
修改部分1:修改前的原始代码

class CorrelationFunction(Function):

    def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        super(CorrelationFunction, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply
        # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)

    def forward(self, input1, input2):
        self.save_for_backward(input1, input2)

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 
                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return output

    def backward(self, grad_output):
        input1, input2 = self.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return grad_input1, grad_input2

修改部分1:修改后的新代码

class CorrelationFunction(Function):
    @staticmethod
    def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        ctx.pad_size = pad_size
        ctx.kernel_size = kernel_size
        ctx.max_displacement = max_displacement
        ctx.stride1 = stride1
        ctx.stride2 = stride2
        ctx.corr_multiply = corr_multiply
        ctx.save_for_backward(input1, input2)

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)

        return grad_input1, grad_input2, None, None, None, None, None

修改部分2:修改前的原始代码

    def forward(self, input1, input2):

        result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)

        return result

修改部分2:修改后的新代码

    def forward(self, input1, input2):

        result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)

        return result

预测光流的演示

python demo.py

在代码里可以修改自己想要测试的图片,这里博主使用源代码测试数据,效果如下:


总结

尽可能简单、详细的介绍了FastFlowNet的安装流程以及FastFlowNet的使用方法。后续会根据自己学到的知识结合个人理解讲解FastFlowNet的原理和代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2035847.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VUE系列之极速入门与实践教程

系列博客专栏&#xff1a; JVM系列博客专栏SpringBoot系列博客 实验环境 npm v10.8.1 node v20.16.0 vue.js v3.4.37 VSCODE 1.88.1 什么是Vue.js Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套用于构建用户界面的渐进式框架。Vue.js 是一套构建用户界面的框架&…

聚合平台项目优化(门面模式,适配器模式,注册器模式)

前言&#xff1a; 这篇文章的思路就是抛出问题&#xff0c;再思考解决方案&#xff0c;最后利用设计模式解决问题 项目背景&#xff1a; 聚合搜索平台的主要功能就是一个有强大搜索能力的一个项目 用户输入一个词&#xff0c;同时可以搜索出用户&#xff0c;文章和图片这种…

【AI学习】具身智能的技术发展、商业路径等有趣观点

阅读了两篇有关具身智能的文章&#xff0c;有好多话语&#xff0c;挺有趣&#xff0c;做一些摘录。 一篇是&#xff1a;腾讯研究院的《具身智能的10个真问题&#xff5c;3万字圆桌实录》&#xff08;链接&#xff1a;https://mp.weixin.qq.com/s/peIi0YOJGKFV3fpLURDyyQ&#x…

一天搞定Vue3——包含Axios、ElementUI Plus、Vuex的使用!!!

前言,本篇文章是依据bilibili博主(波波酱老师)的学习笔记,波波酱老师讲的很好,很适合速成!!! 本篇文章会与vue2进行对比学习,并且也有很多的JavaScript知识点&#xff0c;要提前掌握他们才能学的效果更佳,见效更快。&#x1f973; 文章目录 Vue基础Vue的底层原理el挂载点data数…

Linux中网卡收发包的流程

进来在一个RTOS上移植开发网卡驱动&#xff0c;最终DMA收发包流程打通之后&#xff0c;在使用过程中觉得RTOS的处理逻辑太差了&#xff0c;因此有想法来梳理下Linux中对收发包流程处理&#xff0c;来给一些参考。 一、Linux接收网络包的流程 网卡是一个计算机的硬件&#xff0…

浅谈线性表——顺序表

文章目录 一、List接口二、线性表2.1、什么是线性表&#xff1f; 三、顺序表ArrayList什么是顺序表&#xff1f; 一、List接口 从上图看到List接口继承自Collection接口&#xff0c;而 ArrayList、LinkedList、Stack 类都实现了List接口&#xff0c;List是个接口&#xff0c;不…

论文新体验!分享8款人工智能AI软件论文网站

最近看到这个AI工具推广做的比较多&#xff0c;号称长文写的比kimi还要好&#xff01;难道大学生的救星下凡了&#xff1f;&#xff1f; 本文将分享8款优秀的AI软件论文网站&#xff0c;并重点推荐千笔-AIPassPaPer&#xff0c;这是一款备受用户好评的AI原创论文写作平台。 1…

C++ | 掌握C++异常处理:从基础到自定义异常体系的全面指南

09--异常 1、C语言传统的错误处理方式&#xff1a; 包括终止程序和返回错误码两种方式。 直接使用assert终止程序过于粗暴&#xff1a;用户无意的小错误也会造成程序结束运行。 return返回错误码&#xff0c;再通过错误码查找错误类型&#xff1a;过程繁琐&#xff0c;对用…

可视化基础的设计四大原则

一个好的数据可视化设计可以帮助观众迅速理解数据背后的意义。然而&#xff0c;如何确保我们的可视化设计既美观又简单易懂呢&#xff1f;本文将介绍四大设计原则——亲密原则、对比原则、对齐原则和重复原则。 1、 亲密原则&#xff08;Proximity&#xff09; 定义与应用&am…

学习大数据DAY34 面向对象思想深化练习 将从豆瓣爬取的数据置入自己搭建的网站上

目录 查看电影类型的电影列表 添加电影 修改电影 上机练习 13 使用三层架构完善 web 系统 查看电影类型的电影列表 DAL.py 文件 class MovieDAL(DBHelper): def getMovieByTid(self,typeid): sqlf"""select id,title,release_date,score,tname from Mo…

YOLOv8 | 融合改进 | C2f融合可变核卷积AKConv【附代码+小白可上手】

秋招面试专栏推荐 &#xff1a;深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 &#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 专栏目录 &#xff1a;《YOLOv8改进有效…

【JavaEE】深入理解Spring IoC与DI:从传统开发到依赖注入的转变

目录 IoC & DI ⼊⻔什么是Spring什么是容器什么是IoCIoC介绍传统程序开发问题分析解决方案IoC程序开发IoC优势 IoC & DI ⼊⻔ IoC&#xff1a;Inversion of Control (控制反转) DI&#xff1a;Dependency Injection 在前⾯我们学习了Spring Boot和Spring MVC的开发, 可…

DNS相关内容

一、dns的两种解析方式 1. 正向解析 将域名解析为 ip 地址 2. 逆向解析 将 ip 地址解析为域名 设置解析方式&#xff0c;都是在 zone 文件中 named.conf 解决权限 named.rfc1912.zone 解决解析方式 3.DNS 方向解析 把 192.168.71.145 这个 ip 地址逆向解析为 www.yuany…

Android逆向题解攻防世界-easyjava-难度6

纯Java实现&#xff0c;不涉及so, flag加密之后与指定字符串 “wigwrkaugala"比较判断&#xff0c;循环一个个字符加的&#xff0c;那可以一个个字符对应还原。 加密算法就在a,b类里面&#xff0c;代码直接复制到idea &#xff0c;枚举暴力破解。 每一位输入范围a-z , 找…

Lua脚本 快速掌握

1.Lua脚本概述 Lua是一种轻量级的编程语言&#xff0c;由巴西里约热内卢天主教大学开发。设计初衷是为了嵌入应用程序中&#xff0c;提供灵活的配置和脚本能力。Lua具有简洁的语法和强大的扩展性&#xff0c;使得它在多个领域得到了广泛应用。 Lua的特点包括动态类型、自动内…

The Sandbox 游戏制作教程第 4 章|使用装备制作游戏,触发独特互动

欢迎回到我们的系列&#xff0c;我们将记录 The Sandbox Game Maker 的 “On-Equip”&#xff08;装备&#xff09;功能的多种用途。 如果你刚加入 The Sandbox&#xff0c;On-Equip 功能是 “可收集组件”&#xff08;Collectable Component&#xff09;中的一个多功能工具&a…

C++ list【常用接口、模拟实现等】

1. list的介绍及使用 1.1 list的介绍 1.list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。 2.list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xff0c;在节点中通过指针指向其前…

MyBatisPlus 第二天

常用注解 1 TableName:数据库表名和实体类名不同时,会出现以下报错 在实体类上添加 TableName("t_user") 在开发的过程中&#xff0c;我们经常遇到以上的问题&#xff0c;即实体类所对应的表都有固定的前缀&#xff0c;例如t_或tbl_此时&#xff0c;可以使用MyBa…

el-tree自定义节点内容

<el-tree :data"data" :props"defaultProps" ref"treeRef" show-checkbox check-change"handleCheckChange"><!-- 自定义节点内容 --><template #default"{ node, data, store }"><span class"tr…

无人值守人工智能智慧系统数据分析:深度洞察与未来展望

无人值守人工智能智慧系统数据分析&#xff1a;深度洞察与未来展望 随着科技的飞速发展&#xff0c;人工智能&#xff08;AI&#xff09;技术已逐渐渗透到社会经济的各个领域&#xff0c;其中无人值守人工智能智慧系统作为AI技术应用的前沿阵地&#xff0c;正引领着一场深刻的…