昇思MindSpore进阶教程--二阶优化器THOR

news2024/11/26 11:55:42

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

正文介绍

深度学习训练过程可以看成损失函数损失值下降过程,合适的优化器可以让深度学习训练时间大大减少。优化器可以分为一阶优化器和二阶优化器,目前业界主流使用的仍然是一阶优化器,二阶优化器因为单步训练时间过久而没有被广泛应用,而近年来,将二阶优化应用到深度学习训练中有了理论突破,并取得了不错的结果。

本文会介绍下优化器的背景,以及MindSpore团队自研二阶优化器THOR。

优化器背景介绍

假设训练样本数据集: D = ( x 1 , y 1 ) , . . . , ( x i , y i ) , . . . , ( x N , y N ) , x i ∈ X , y i ∈ Y D = {(x_1,y_1),...,(x_i,y_i),...,(x_N,y_N)},x_i \in X,y_i\in Y D=(x1,y1),...,(xi,yi),...,(xN,yN),xiX,yiY,参数θ表述的深度神经网络模型为: y ^ = f ( x ; θ ) , x ∈ X \hat{y} = f(x;\theta),x\in{X} y^=f(x;θ),xX,定义在模型输出和真实标签y之间的损失函数为: L ( y , y ^ ) , y ∈ Y L(y,\hat y),y \in Y L(y,y^),yY,网络参数学习的过程是最小化损失函数的过程: min ⁡ θ L ( y , y ^ ) \min\limits_{\theta}L(y,\hat{y}) θminL(y,y^) 。给定数据集、模型、损失函数后,深度学习训练问题归结为优化问题,深度神经网络训练优化问题参数规模巨大,需要大量的计算,难以计算出解析解。因此该过程也常常被比喻成下山,如图1 所示,一个人站在山顶的时候如何在有限视距内寻找最快路径下山呢?
在这里插入图片描述
而优化器就是在做这件事情,业界的优化算法可分为一阶优化算法和二阶优化算法。下面简单介绍下业界的优化器情况。

一阶优化器

梯度下降算法(Gradient Descent, GD)是机器学习中最经典的一阶优化算法,也是众多机器学习算法中最常用的优化算法。常用的一阶优化算法(比如SGD算法)中对参数的更新采用如下规则:
,其中
是需要更新的参数,
是学习率,
是损失函数对于参数的梯度。

但是主流随机梯度下降方法有以下问题:太小的学习率会导致网络收敛过于缓慢;学习率太高可能会影响收敛,并导致损失函数在最小值上波动,甚至出现发散,对参数比较敏感;容易收敛到局部最优,难以跳出鞍点。

因此业界提出了很多随机梯度下降方法的改良算法,例如Momentum、Nesterov、AdaGrad、RMSprop、Adadelta和Adam等。这些改进后的优化算法可以利用随机梯度的历史信息来自适应地更新步长,使得它们更容易调参,而且方便使用。

二阶优化器

二阶优化算法利用目标函数的二阶导数进行曲率校正来加速一阶梯度下降。与一阶优化器相比,其收敛速度更快,能高度逼近最优值,几何上下降路径也更符合真实的最优下降路径。

例如,二阶优化算法中的牛顿法就是用一个二次曲面去拟合你当前所处位置的局部曲面,而梯度下降法是用一个平面去拟合当前的局部曲面,通常情况下,二次曲面的拟合会比平面更好,所以牛顿法选择的下降路径会更符合真实的最优下降路径。如图2 所示,左边下降路径表示牛顿法的下降曲线,右边表示一阶梯度的下降曲线,二阶算法与一阶算法相比,可以更快的走到目的地,从而加速收敛。
在这里插入图片描述

THOR的介绍

当前业界已有的二阶优化算法计算量较大,与一阶相比优势不明显或者使用场景较为单一。MindSpore提出了自研算法THOR(Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation), 算法已被AAAI录用,THOR在多个场景中均有明显收益,如在BERT和ResNet50中,收敛速度均有明显优势。THOR主要做了以下两点创新:

降低二阶信息矩阵更新频率

通过实验观察费雪矩阵的F范数(Frobenius norm)在前期变化剧烈,后期逐渐变稳定,从而假设 { F k } k = 1 n \Big\{{F^k}\Big\}^{n}_{k=1} {Fk}k=1n是一个马尔可夫过程,可以收敛到一个稳态分布π,其中 F k F^k Fk代表第k个迭代时的费雪矩阵。因此,在训练过程中逐步增大费雪矩阵的更新间隔,可以在不影响收敛速度的情况下,减少训练时间。例如在ResNet50中,更新间隔步数随着训练的进行越来越大,到后期每个epoch只需更新一次二阶信息矩阵。

THOR受KFAC启发,将费雪矩阵按层解耦来降低矩阵复杂度,分别针对每一层的费雪矩阵做实验,发现有些层的费雪矩阵趋于稳态的速度更快,因此在统一的更新间隔上,更加细粒度的去调整每一层的更新频率。THOR使用矩阵的迹作为判断条件,当迹的变化情况大于某一阈值时更新该层的二阶信息矩阵,否则沿用上一个迭代的二阶信息矩阵,并且引入了停止更新机制,当迹的变化量小于某个阈值时停止更新该层二阶信息矩阵,具体更新公式如下: { u p d a t e   F i k ,    i f   Δ k ∈ ( ω 1 , + ∞ ) d o   n o t   u p d a t e   F i k   a n d   s e t   F i k = F i k − 1 ,   i f   Δ k ∈ [ ω 2 , ω 1 ] s t o p   u p d a t e   F i k   a n d   s e t   F t k + t ≡ F i k − 1   f o r   a l l   t = 1 , 2 , . . . i f   Δ k ∈ [ 0 , ω 2 ) \begin{split} \begin{cases} update\ F^{k}_{i} , \qquad\qquad\qquad\qquad\qquad\qquad\qquad\qquad\qquad\quad\ \ if \ \Delta^{k} \in (\omega_{1},+\infty)\\ do\ not\ update\ F^{k}_{i}\ and\ set \ F^{k}_{i}=F^{k-1}_{i}, \ \quad\qquad\qquad\qquad\quad if \ \Delta^{k} \in [\omega_{2},\omega_{1}]\\ stop\ update\ F^{k}_{i}\ and\ set \ F^{k+t}_{t}\equiv F^{k-1}_{i}\ for\ all\ t=1,2,...\quad if \ \Delta^{k} \in [0,\omega_{2}) \end{cases} \end{split} update Fik,  if Δk(ω1,+)do not update Fik and set Fik=Fik1, if Δk[ω2,ω1]stop update Fik and set Ftk+tFik1 for all t=1,2,...if Δk[0,ω2)

其中:

Δ k = ∣ ∣ t r ( F i k + λ I ) ∣ − ∣ t r ( F i k − 1 + λ I ) ∣ ∣ ∣ t r ( F i k − 1 + λ I ) ∣ \Delta^k=\frac{||tr(F^k_i+\lambda I)|-|tr(F^{k-1}_i+\lambda I)||}{|tr(F^{k-1}_i+\lambda I)|} Δk=tr(Fik1+λI)∣∣tr(Fik+λI)tr(Fik1+λI)∣∣

硬件感知矩阵切分

THOR在将费雪矩阵按层解耦的基础上,进一步假设每个网络层中的输入和输出块之间也是独立的,例如将每层网络的输入输出切分为n个块,这n个块之间即是独立的,根据该假设对二阶信息矩阵做进一步的切分,从而提高了计算效率。THOR结合矩阵信息损失数据和矩阵性能数据确定了矩阵分块维度,从而大大提升费雪矩阵求逆时间。

那么如何确定矩阵分块维度的呢。具体方法为:

(1)根据费雪矩阵中维度最大的那一层,确定矩阵切分维度,拿ResNet50举例,网络层中的最大维度为2048,确定矩阵切分维度为[1,16,32,64,128,256,512,1024,2048]。

(2)根据确定的矩阵维度,根据谱范数计算每个维度下的矩阵损失,具体公式为:
L = 1 − λ m a x    ( A ^ A ^ T ) λ m a x    ( A A T ) L=1-\sqrt{\frac{\lambda_{max}\ \ (\hat{A}\hat{A}^T)}{\lambda_{max}\ \ (AA^T)}} L=1λmax  (AAT)λmax  (A^A^T)
3)根据确定的矩阵维度,计算每个维度下的矩阵求逆时间,再通过公式

得到每个维度下标准化后性能数据,其中
表示维度最小的矩阵的性能数据,
表示第n个维度下的性能数据。

(4)根据标注化后的矩阵损失信息和标准化后的性能数据绘图,如以ResNet50为例,如图3所示,图中下降曲线为性能曲线,上升曲线表示矩阵损失曲线,图中交叉点为106,与128最接近,最后确定矩阵切分维度为128。
在这里插入图片描述

实验结果

图4展示了THOR在ResNet50+ImageNet,batchsize为256时一二阶上的训练曲线图,其中train loss表示训练误差,test accuracy表示测试精度,epoch表示迭代数,wall-clock time表示时间,其中下降较快的曲线和上升较快的曲线是本算法曲线,另外差距较明显的曲线是momentum的训练曲线。
在这里插入图片描述
THOR还测试了在不同batchsize下ResNet50+ImageNet的收敛结果,结果见下图5,其中Hardware表示硬件平台,Software是指使用的深度学习框架,Batch size是每次训练的图片数量,Optimizer表示使用的优化器,Time指总体训练时间,Accuracy是指最后收敛精度。当batchsize为8192,使用256块Atlas训练系列产品时,只需2.7分钟精度即可收敛到75.9%。
在这里插入图片描述
在BERT+WIkipedia中,THOR也有不错的表现效果,以MLPerf为标准,精度达到71.2%,与一阶相比端到端提升30%,实验结果见图6,图中横坐标表示训练时间,纵坐标表示测试精度,上升较快的曲线是THOR的训练曲线,另一条为Lamb的训练曲线。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187498.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Neo4j CQL语句 使用教程

CREATE命令 : CREATE (<node-name>:<label-name>{ <Property1-name>:<Property1-Value>........<Propertyn-name>:<Propertyn-Value>} )字段说明 CREATE (dept:Dept { deptno:10,dname:“Accounting”,location:“Hyderabad” })&#…

ATLAS/ICESat-2 L3B 每 3 个月网格动态海洋地形图 V001

目录 简介 摘要 代码 引用 网址推荐 0代码在线构建地图应用 机器学习 ATLAS/ICESat-2 L3B Monthly 3-Month Gridded Dynamic Ocean Topography V001 ATLAS/ICESat-2 L3B 每月 3 个月网格动态海洋地形图 V001 简介 该数据集包含中纬度、北极和南极网格上动态海洋地形&…

[Offsec Lab] ICMP Monitorr-RCE+hping3权限提升

信息收集 IP AddressOpening Ports192.168.52.218TCP:22,80 $ nmap -p- 192.168.52.218 --min-rate 1000 -sC -sV -Pn PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 7.9p1 Debian 10deb10u2 (protocol 2.0) | ssh-hostkey: | 2048 de:b5:23:89:bb:9f:d4:1…

Kubernetes-Kind篇-01-kind搭建测试集群

1、Kind 介绍 官方文档地址&#xff1a;https://kind.sigs.k8s.io/ github仓库地址&#xff1a;https://github.com/kubernetes-sigs/kind 国内镜像仓库地址&#xff1a;https://gitcode.com/gh_mirrors/ki/kind/overview kind 是一种使用 Docker 容器 nodes 运行本地 Kubern…

算法日记-链表翻转

文章目录 场景&#xff1a;解法一&#xff1a;迭代步骤完整代码 解法二&#xff1a;递归步骤完整代码 重温力扣常规算法&#xff0c;记录算法的演变&#xff0c;今天介绍的是链表翻转 场景&#xff1a; 现在有一条单项链表&#xff0c;链表节点存在一个数据和指向下一个节点的…

MySQL--三大范式(超详解)

目录 一、前言二、三大范式2.1概念2.2第一范式&#xff08;1NF&#xff09;2.3第二范式&#xff08;2NF&#xff09;2.3第三范式&#xff08;3NF&#xff09; 一、前言 欢迎大家来到权权的博客~欢迎大家对我的博客进行指导&#xff0c;有什么不对的地方&#xff0c;我会及时改进…

AI不可尽信

看到某项目有类似这样的一段代码 leaves : make([]int, 10) leaves leaves[:0]没理解这样的连续两行,有何作用? 初始化一个长度和容量都为10的切片,接着把切片长度设置为0 即如下demo: (在线地址) package mainimport "fmt"func main() {leaves : make([]int, 1…

【2023工业3D异常检测文献】CPMF: 基于手工制作PCD描述符和深度学习IAD结合的AD方法

Complementary Pseudo Multimodal Feature for Point Cloud Anomaly Detection 1、Background 早期的点云异常检测(PCD)表示是手工制作的&#xff0c;依赖于启发式设计。随着深度学习的发展&#xff0c;最近的方法采用了基于学习的PCD特征。尽管与基线相比有相当大的改进&…

欧几里得算法--(密码学基础)

根基&#xff1a;gcd(a,b)gcd(b,a mod b) 先举个例子吧&#xff0c;gcd(16,6)gcd(6,4)gcd(4,2)gcd(2,0)2 学习这个定理的时候我想了几个问题. 第一个问题&#xff1a;为什么求出的就一定是他们两个数的公约数&#xff1f; 这个问题很简单我们只需要通过几何来计较即可&#x…

MyBatis——ORM

MyBatis——ORM 验证映射配置ResultType本质是ResultMap具体的转换逻辑 概括的说&#xff0c;MyBatis中&#xff0c;对于映射关系的声明是由开发者在xml文件手动完成的。比如对查询方法而言&#xff0c;你需要显式声明ResultType或ResultMap&#xff0c;这里其实就是在定义数据…

Java JUC(三) AQS与同步工具详解

Java JUC&#xff08;三&#xff09; AQS与同步工具详解 一. ReentrantLock 概述 ReentrantLock 是 java.util.concurrent.locks 包下的一个同步工具类&#xff0c;它实现了 Lock 接口&#xff0c;提供了一种相比synchronized关键字更灵活的锁机制。ReentrantLock 是一种独占…

【Kubernetes】常见面试题汇总(五十三)

目录 118. pod 状态为 ErrlmagePull &#xff1f; 119.探测存活 pod 状态为 CrashLoopBackOff &#xff1f; 特别说明&#xff1a; 题目 1-68 属于【Kubernetes】的常规概念题&#xff0c;即 “ 汇总&#xff08;一&#xff09;~&#xff08;二十二&#xff09;” 。…

uniapp使用字体图标 ttf svg作为选项图标,还支持变色变图按

在staic目录下有一些ttf文件&#xff0c;如uni.ttf&#xff0c;iconfont.ttf 这些文件中保存这字体svg的源码们&#xff0c;我们也可以在网上找其他的。这些就是我们要显示的突图标的 显示来源。这样不用使用png图标&#xff0c;选中不选中还得用两个图片 我的具体使用如下 &q…

Python入门--循环语句

目录 1. while循环基础语法 2. while循环的嵌套 3. while实现九九乘法表 4. for循环基础语法 5. for循环的嵌套 6. for循环实现九九乘法表 7. 循环中断&#xff1a;break和continue 循环普遍存在于日常生活中&#xff0c;同样&#xff0c;在程序中&#xff0c;循环功能也…

thinkphp6入门(25)-- 分组查询 GROUP_CONCAT

假设表名为 user_courses&#xff0c;字段为 user_id 和 course_name&#xff0c;存储每个用户选修的课程&#xff0c;想查询每个学生选修的所有课程 SQL 原生查询 SELECT user_id, GROUP_CONCAT(course_name) as courses FROM user_courses GROUP BY user_id; ThinkPHP 代码…

python常用库总结(argparse、re、matlpotlab.plot)

文章目录 1.argparse库字符串&#xff08;str&#xff09;布尔值&#xff08;bool&#xff09;选择&#xff08;choices&#xff09;计数&#xff08;count&#xff09;常量&#xff08;store_const 和 store_true&#xff09;多个值&#xff08;nargs&#xff09;可选参数&…

用Python实现运筹学——Day 9: 线性规划的灵敏度分析

一、学习内容 1. 灵敏度分析的定义与作用 灵敏度分析&#xff08;Sensitivity Analysis&#xff09; 是在优化问题中&#xff0c;分析模型参数变化对最优解及目标函数值的影响。它帮助我们了解在线性规划模型中&#xff0c;当某些参数&#xff08;如资源供应量、成本系数等&a…

SQLServer CXPACKET等待事件

文章目录 SQL Server 中的 CXPACKET 等待类型是最容易被误解的等待统计之一。CXPACKET 这个术语来源于 “Class Exchange Packet”&#xff08;类交换包&#xff09;。其本质可以描述为在单个进程的两个并行线程之间交换数据行的过程。其中一个线程是“生产者线程”&#xff0c…

理解Matplotlib构图组成

介绍 Matplotlib 是 Python 中最流行的数据可视化库之一。它提供了一系列丰富的工具&#xff0c;可以绘制高度自定义且适用于各种应用场景的图表。无论你是数据科学家、工程师&#xff0c;还是需要处理数据图形表示的任何人&#xff0c;理解如何操作和定制 Matplotlib 中的图表…

ElasticSearch 备考 -- 备份和恢复

一、题目 备份集群下的索引 task&#xff0c;存储快照名称为 snapshot_1 二、思考 这个涉及的是集群的备份&#xff0c;主要是通过创建快照&#xff0c;涉及到以下2步骤 Setp1&#xff1a;注册一个备份 snapshot repository Setp2&#xff1a;创建 snapshot 可以通过两种方…