论文笔记Neural Ordinary Differential Equations

news2025/1/15 12:57:26

论文笔记Neural Ordinary Differential Equations

  • 概述
  • 参数的优化
  • 连续标准化流(Continuous Normalizing Flows)
  • 生成式的隐轨迹时序模型(A generative latent function time-series model)

这篇文章有多个版本,在最初的版本中存在一些错误,建议下载2019年的最新版。

概述

在残差网络中有下面的形式:
h t + 1 = h t + f ( h t , θ t ) (1) \mathbf h_{t+1} = \mathbf h_{t} + f(\mathbf h_{t}, \theta_t) \tag{1} ht+1=ht+f(ht,θt)(1)
连续的动态系统通常可以用常微分方程(ordinary differential equation, ODE)表示为:
d h ( t ) d t = f ( h ( t ) , t , θ ) (2) \frac{d\mathbf h(t)}{dt} = f(\mathbf h(t), t, \theta) \tag{2} dtdh(t)=f(h(t),t,θ)(2)如果动态系统中的 f f f用神经网络的模块表示,就得到了神经常微分方程Neural ODE,公式(1)可以看做是公式(2)的欧拉离散化(Euler discretization)。
输入是 h ( 0 ) \mathbf h(0) h(0),输出是 h ( T ) \mathbf h(T) h(T),也就是常微分方程初值问题在T时刻的解。

值得注意的是这里的 t t t不代表时间,而是代表网络的层数。但在某些问题下,如时间预测问题下, t t t也可以代表时间。

下图所示是残差网络和神经常微分方程的区别。纵轴代表 t t t(depth),残差网络的状态变化是离散的,在整数位置计算状态的值,而神经常微分方程的状态是连续变化的,计算状态值的位置由求解常微分方程的算法决定。
实际上Neural ODE中的depth的定义并不简单,这在论文第3部分有说,并不是t为多少就是多深,Neural ODE中的depth应该是和隐含状态计算的次数相关的。比如下图中depth到5,resnet确实只计算了5次隐含状态,但Neural ODE其实计算了很多次的隐含状态。隐含状态计算的次数和终点t有关,和ODE的求解算法也有关。

在这里插入图片描述
Neural ODE就是用神经网络模块来表示常微分方程里的 f f f,同时Neural ODE又可以把常微分方程作为一个模块嵌入大的神经网络中。

参数的优化

普通的常微分方程中的参数 θ \theta θ是固定的,但是在Neural ODE中是神经网络的参数,所以需要优化。神经网络的参数用反向传播进行优化,神经常微分方程作为神经网络的一个模块,也需要支持反向传播。因为不只需要优化神经常微分方程中的参数,要需要优化神经常微分方程之前的模块的参数,所以需要求损失函数关于 z ( t 0 ) , t 0 , t 1 , θ \mathbf z(t_0), t_0, t_1, \theta z(t0),t0,t1,θ的梯度。

直接对积分的前向过程做反向传播理论上是可行的,但是需要大量的内存并会导致额外的数值误差。
为了解决这些问题,论文提出使用adjoint sensitivity method来求梯度。adjoint法可以通过求解另一个ODE来计算反传时需要的梯度。
考虑优化一个标量损失函数,这个损失函数的输入是ODE的结果。

在这里插入图片描述
定义伴随状态(adjoint state)为 a ( t ) = − ∂ L ∂ z ( t ) \mathbf a(t)=-\frac{\partial L}{\partial \mathbf z(t)} a(t)=z(t)L
adjoint state满足另一个ODE:
d a ( t ) d t = − a ( t ) ⊤ ∂ f ( z ( t ) , t , θ ) ∂ z \frac{d \mathbf a(t)}{dt} = -\mathbf a(t)^\top \frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z} dtda(t)=a(t)zf(z(t),t,θ)论文在附录中给出了证明。
通过伴随状态,损失函数关于 z ( t 0 ) , t 0 , t 1 , θ \mathbf z(t_0), t_0, t_1, \theta z(t0),t0,t1,θ的梯度都可以通过求解ODE得到。
∂ L ∂ z ( t 0 ) = a ( t 1 ) − ∫ t 1 t 0 a ( t ) ⊤ ∂ f ( t , z ( t ) , θ ) ∂ z ( t ) d t \frac{\partial L}{\partial \mathbf z(t_0)} = \mathbf a(t_1) - \int_{t_1}^{t_0} \mathbf a(t)^{\top}\frac{\partial f(t,\mathbf z(t), \theta)}{\partial \mathbf z(t)} dt z(t0)L=a(t1)t1t0a(t)z(t)f(t,z(t),θ)dt其中 a ( t 1 ) \mathbf a(t_1) a(t1)是损失函数对最后时刻的隐藏状态的梯度,可以由下一层神经网络的BP获得。

a θ ( t ) = ∂ L ∂ θ ( t ) ,   a t ( t ) = ∂ L ∂ t ( t ) \mathbf a_\theta(t) = \frac{\partial L}{\partial\theta(t)}, \ a_t(t) = \frac{\partial L}{\partial t(t)} aθ(t)=θ(t)L, at(t)=t(t)L
∂ L ∂ θ ( t 0 ) = a θ ( t 1 ) − ∫ t 1 t 0 a ( t ) ⊤ ∂ f ( t , z ( t ) , θ ) ∂ θ d t \frac{\partial L}{\partial\theta(t_0)} = \mathbf a_\theta(t_1) - \int_{t_1}^{t_0} \mathbf a(t)^{\top}\frac{\partial f(t, \mathbf z(t), \theta)}{\partial\theta} dt θ(t0)L=aθ(t1)t1t0a(t)θf(t,z(t),θ)dt其中令 a θ ( t 1 ) = 0 \mathbf a_\theta(t_1)=0 aθ(t1)=0,这一点我目前没有看懂为啥这么设置, θ \theta θ是不随着 t t t而变的。
∂ L ∂ t 1 = ∂ L ∂ z ( t 1 ) ∂ z ( t 1 ) ∂ t 1 = a ( t 1 ) ⊤ f ( t 1 , z ( t 1 ) , θ ) = a t ( t 1 ) \frac{\partial L}{\partial t_1} = \frac{\partial L}{\partial \mathbf z(t_1)} \frac{\partial \mathbf z(t_1)}{\partial t_1} = \mathbf a(t_1)^{\top} f(t_1, \mathbf z(t_1), \theta) = a_t(t_1) t1L=z(t1)Lt1z(t1)=a(t1)f(t1,z(t1),θ)=at(t1) ∂ L ∂ t 0 = a t ( t 1 ) − ∫ t 1 t 0 a ( t ) ⊤ ∂ f ( t , z ( t ) , θ ) ∂ t d t \frac{\partial L}{\partial t_0} = a_t(t_1) - \int_{t_1}^{t_0} \mathbf a(t)^{\top}\frac{\partial f(t, \mathbf z(t), \theta)}{\partial t} dt t0L=at(t1)t1t0a(t)tf(t,z(t),θ)dt
这些导数可以整合放到一个ODE方程中去求解,如下面的算法所示:
在这里插入图片描述
实际使用中不需要考虑梯度计算的问题,因为这些在库(https://github.com/rtqichen/torchdiffeq)中都已经写好了,只需要定义好 f f f直接调用积分算法就可以了。

连续标准化流(Continuous Normalizing Flows)

公式(1)中这种形式也出现在标准化流中(normalizing flows)。
normalizing flows是一种生成算法,可以学习模型生成指定分布的数据,目前广泛用于图像的生成。
normalizing flows要求变换是双射(bijective fucntion),这样就可以利用change of variables theorem直接计算概率。
在这里插入图片描述

为了满足双射的要求,变换需要是精心设计的。normalizing flows有不同的变种方法,其中一种planar normalizing flow有下面的变换:
在这里插入图片描述
主要的运算量来着于计算 ∂ f ∂ z \frac{\partial f}{\partial \mathbf z} zf有趣的是当离散的变换变为连续的变换时,概率的计算变得简单了,不再需要det的计算。
论文给出了下面的定理:

在这里插入图片描述
值得注意的是,后面火起来的生成模型diffusion model,可以扩展为probability flow ODE,也可以使用这个定理。

生成式的隐轨迹时序模型(A generative latent function time-series model)

在时序模型中 t t t可以表示时间。用Neural ODE建模时间序列的好处是可以建模连续的状态,天然适合非规则采样的时间序列(irregularly-sampled data)。
假设每一个时间序列由一个隐轨迹决定。隐轨迹是由初始状态和一组隐含的动态决定的。有观测时间点 t 0 , t 1 , ⋯   , t N t_0,t_1,\cdots,t_N t0,t1,,tN和初始状态 z t 0 z_{t_0} zt0,生成模型如下:
在这里插入图片描述
这里 f f f被定义为一个不随着时间变换的神经网络。外推(Extrapolating)可以得到时间点往前或者往后的预测结果。

这本质是一个隐变量生成模型,所以可以用variational autoencoder(VAE)的算法优化。只不过这里的观测变量时间序列,而传统VAE的观测变量是图像。
为了能表示时间序列,这里encoder使用的是RNN模型。生成初始隐含状态后,由Neural ODE生成其他时间点的隐含状态,再由一个decoder网络计算 p ( x ∣ z ) p(x|z) p(xz)
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/145451.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式系统IO体系简述

前言: CPU的主要职责是负责运算,而计算机是需要各种外设的,否则无法和人进行交互。早期x86体系的CPU,需要使用前端总线(fsb)和北桥芯片相连,北桥再和南桥相连。南北桥是一种架构的划分&#xff…

数据结构与算法——算法分析(3)

算法的时间复杂度计算 算法基本操作执行的次数还会随着问题输入的数据集不同而不同 最坏时间复杂度:在最坏的情况下,算法的时间复杂度 平均时间复杂度:所有可能输入在等概率的情况下算法的期望运行时间 最好时间复杂度:在最好的…

企业内训方案|数据治理/项目管理/敏捷项目管理/产品管理

企业内训方案|数据治理/项目管理/敏捷项目管理/产品管理 》》数据治理 数据管理基础 数据处理伦理 数据治理 数据架构 数据建模和设计 数据安全 数据集成和互操作 文件和内容管理 参考数据和主数据 数据仓库和商务智能 元数据管理 数据质量 大数据和数据科学 数据管理成熟度评…

网络协议HTTP:了解Web及网络基础

文章整理自图书图解Http第一章:使用Http协议访问Web第二章:Http的诞生第三章:网络基础TCP/IP协议一:应用层二:传输层三:网络层四:链路层五:TCP/IP通信传输流第四章:IP、T…

http与https的区别我真的知道吗

之前每次看到类似“http与https的区别?”的问题时,都会自己思考一下答案,好像只是浅显地知道https比http安全,但究竟为什么更安全,却又似乎说不出个所以然,或者说很多细节地方自己都是不清楚的。为了搞清楚…

Linux权限shell命令以及运行原理

文章目录一、Linux权限的概念二、Linux权限管理2.1.文件访问者的分类(角色)2.2文件属性2.3文件访问权限的相关设置方法2.4访问者角色的修改2.5目录权限含义2.6默认权限三、粘滞位四、 shell命令以及运行原理一、Linux权限的概念 权限的概念通常是指行事…

ffmpeg录制H265格式的桌面视频

ffmpeg本身不支持H265,如果需要支持,需要事先编译出libx265,读者可以到libx265的官方网站https://www.videolan.org/developers/x265.html上找到下载地址,本人下载的是x265_3.5.tar.gz。 编译libx265时,定位到其目录下…

java基础 网络编程

网络编程概念: 让程序可以和网络上的其他设备中的程序进行数据交互。 网络通信基本模式: CS:Client-Server 自己写客户端和服务器交流 BS:Browser/Server 通过浏览器和服务器交流 实现网络编程关键的三要素…

python中的函数与变量

一、函数python中函数的基本格式则为:def函数名参数名函数体返回,python作为一门面向对象的语言,同样可分为类函数、实例函数。 # 定义一个函数 def add(x, y):"""函数的说明:param x: 参数x的作用:param y: 参数y的作用:return: 函数返…

碱性环境吸钯树脂技术

汞和贵金属的选择性去除回收离子交换树脂 Tulsimer CH-95S 是一款为了从工业废水中去除回收汞和贵金属而开发的螯合树脂。 Tulsimer CH-95S是一款拥有聚乙烯异硫脲官能基的大孔树脂,这种树脂对汞有的选择性。它也选择其他的贵金属,如黄金,铂…

消息收发弹性——生产集群如何解决大促场景消息收发的弹性降本诉求

作者:宸罡 产品介绍—什么是消息收发弹性 大家好,我是来自阿里云云原生消息团队的赖福智,花名宸罡,今天来给大家分享下阿里云 RocketMQ5.0 实例的消息弹性收发功能,并且通过该功能生产集群是如果解决大促场景消息收发…

JS in CSS:一键支持响应式布局

前言 如今网速不再成为适配移动端时选择响应式设计的限制因素,在资源充足的条件下,针对各端各自设计应用界面能达到应用最佳用户体验,毕竟不同类型的设备交互体验是不同的,但在团队前端资源拮据时,相比无脑自适应&…

基于Java+jquery+SpringMVC校园网站平台设计和实现

基于JavajquerySpringMVC校园网站平台设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取源码联…

临时白名单

临时白名单介绍 相关常量 临时白名单列表介绍 前两个临时白名单可以豁免后台启动Service、豁免uid后台1min后进入idle状态等,最后一个临时白名单可以后台启动FGS。 // 由于高优先级消息而暂时允许逃避后台检查的一组应用程序 ID,短信/彩信 Composite…

【Vue路由】路由守卫、生命周期钩子、路由器工作模式

文章目录生命周期钩子案例实现总结路由守卫全局路由守卫独享守卫组件内守卫总结路由器的两种工作模式总结生命周期钩子 我们在News组件列表中的第一行加一个渐变文字。同时原来的路由缓存功能也要保存。 案例分析: 我们实现这个渐变的效果,是使用周期定…

Go select底层原理

在对Channel的读写方式上&#xff0c;除了我们通用的读 i <- ch, i, ok <- ch&#xff0c;写 ch <- 1 这种阻塞访问方式&#xff0c;还有select关键字提供的非阻塞访问方式。 在日常开发中&#xff0c;select语句还是会经常用到的。可能是channel普通读写的使用频率比…

基于Node.js和vue的博客系统的设计与实现

摘要随着互联网技术的高速发展&#xff0c;人们生活的各方面都受到互联网技术的影响。现在人们可以通过互联网技术就能实现不出家门就可以在线发布博客文章&#xff0c;简单、快捷的方便了人们的日常生活。同样的&#xff0c;在人们的工作生活中&#xff0c;也需要互联网技术来…

【Java寒假打卡】Java基础-日期类对象

【Java寒假打卡】Java基础-日期类对象Date概述Date类常用成员方法SimpleDateFormat案例:秒杀活动案例&#xff1a;在当前时间加上一天时间JDK8新增日期类获取时间中的一个值LocalDateTime转换方法LocalDateTime格式化和解析LocalDateTime 增加或者减少时间的方法修改时间的方法…

【JAVA程序设计】(C00099)基于SpringBoot的外卖订餐小程序(原生开发)

基于SpringBoot的外卖订餐小程序&#xff08;原生开发&#xff09;项目简介项目获取开发环境项目技术运行截图项目简介 基于SpringBootvue开发的原生外卖点餐微信小程序&#xff0c;包括用户小程序登录以及网页端的商家登录。本系统分为三个权限&#xff1a;商家、用户和游客&…

第七章.机器学习 Scikit-Learn—最小二乘法回归,岭回归,支持向量机,K_means聚类算法

第七章.机器学习 Scikit-Learn 7.1 Scikit-Learn简介 Scikit-Learn简称(SKlearn)是Python的第三方模块&#xff0c;是机器学习领域当中知名的Python模块之一&#xff0c;对常用的机器学习算法进行了封装&#xff0c;包括回归(Regression)&#xff0c;降维(Dimensionality Redu…