【自制C++深度学习框架】表达式层的设计思路

news2024/12/28 18:31:23

表达式层的设计思路

在深度学习框架中,Expression Layer(表达式层)是指一个通用的算子,其允许深度学习网络的不同层之间结合和嵌套,从而支持一些更复杂的操作,如分支之间的加减乘除(elementAdd)等。

PNNXExpession Layer中给出的是一种抽象表达式,会对计算过程进行折叠,消除中间变量,并且将具体的输入张量替换为抽象输入@0, @1等。PNNX生成的抽象表达式是这样的:

add(@0,mul(@1,@2))
add(add(mul(@0,@1),mul(@2,add(add(add(@0,@2),@3),@4))),@5)

这就要求我们需要一个表达式解析语法树构建的功能。

  • 表达式解析:将表达式拆分成多个token。

  • 语法树构建:这是一个二叉树,以操作符作为根节点,可根据后序遍历生成逆波兰表达式。

语法树

因此,我们定义了如下类型:

  • TokenType 是一个枚举类型,用于表示不同的符号类型。该枚举包含了 TokenUnknown、TokenInputNumber、TokenComma、TokenAdd、TokenMul、TokenLeftBracket、TokenRightBracket 七种类型。

  • Token 是一个结构体,表示每个具体的 Token 的信息,包括 Token 类型、Token 开始和结束的位置 3 个属性。

  • TokenNode 是一个结构体,表示语法树的节点,包括节点存储的数值、左子节点和右子节点 3 个属性。

  • ExpressionParser 是一个生成符号表达式语法树的类,包括 Tokenizer()、Generate()、tokens()、token_strs() 等方法,用于对输入的字符串进行分析和处理,并生成符号表达式语法树。

    • 其中 Tokenizer() 方法进行词法分析,将输入字符串解析成多个 token,即各种符号的组合,以便后续进行语法分析;
    • Generate() 方法进行语法分析,利用词汇分析的结果生成符号表达式语法树;
    • tokens() 和 token_strs() 方法分别返回词法分析的结果和词语字符串。

表达式解析

表达式解析器 ExpressionParser 的 Tokenizer() 用于对输入的字符串进行词法分析。

即,词法分析将输入字符串切分为多个 token,使得后续的语法分析更具有可操作性。

具体如下:

  • Tokenizer() 方法的参数 re_tokenize 用于控制是否需要对已经存在的 tokens 进行重新处理,如果不需要重新处理且 tokens 已经存在,则直接返回。这样可以避免在不需要处理时多次调用词法分析函数。

  • 通过 std::isspace© 对每个字符进行检查,去除所有空格和制表符。

  • 对于不同的字符,Tokenizer() 方法使用 if/else 分支结构判断,生成不同类型的 token,包括 TokenAdd、TokenMul 和 TokenInputNumber。
    - 当字符为a的时候,我们的词法规定在token中以a作为开始的情况只有add,所以我们判断接下来的两个字符必须是d和d,如果不是的话就报错,如果是的话就初始化一个新的token并进行保存。
    - 当字符为m的时候,我们的语法规定token中以m作为开始的情况只有mul,所以我们判断接下来的两个字必须是u和l,如果不是的话就报错,是的话就初始化一个token进行保存mul。

  • 每当分析出一个 token 后,就将其记录到 tokens_ 和 token_strs_ 中,以备后续语法分析过程的使用。tokens_ 和 token_strs_ 分别存储 token 和对应的文本形式字符串。

Tokenizer() 方法的主要作用是对输入的字符串进行切割,将其划分为多个基础的符号,方便后续的语法分析进一步处理。

通过这种方式,表达式解析器可以更好地理解输入的表达式,从而为相应的计算程序提供原始数据和计算逻辑。

语法树的构建

ExpressionParser 的 Generate_() 方法用于生成符号表达式语法树。

Generate_() 方法是一个递归函数,用于从输入的 Token 序列中按照先序遍历顺序生成符号表达式语法树的节点,并返回根节点。

其中:

  • Generate_() 方法接受一个整型引用参数 index,用于指示当前正在处理的 token 的下标,在递归结构中传递以保证后续的深度遍历。

  • Generate_() 方法首先对 index 所指向的 token 进行检查,如果这个 token 不是数字、加号或乘号,则抛出异常。

  • 如果当前 token 是数字,则从 statement_ 字符串中提取数字并创建一个没有子节点的叶子节点,并返回该叶子节点的指针。

  • 如果当前 token 是加号或乘号,则需要生成一个新节点并将其 left 和 right 子节点连接起来。然后递归调用本身以生成其 left 和 right 子节点,并将指向子节点的指针赋值给新节点的相应成员变量。

  • 在处理加号或乘号情况时,需要保证当前 token 后面的第二个 token 是左括号,如果不是则抛出异常;而且还需要处理其中的 left 和 right 子节点,左节点与右节点之间用逗号隔开。

  • 最后,Generate_() 返回该树的根节点。

简单来说:
遍历tokens_容器中的每一个token
- 如果当前Token类型是数字,那么需要创建并返回TokenNode
- 如果当前Token类型是mul或者add,那么需要向下递归构建对应的左子树和右子树
- 1,判断是不是“(”
- 2,获取@num
- 3,递归调用自身,创建左子树
- 4,判断是不是“,”
- 5,获取@num
- 6,递归调用自身,创建右子树
- 7,判断是不是“)”

Generate_() 方法实现了从输入 Token 序列构建符号表达式语法树的逻辑,具有很好的可靠性和健壮性。

逆波兰表达式

ReversePolish() 函数接受一个指向符号表达式语法树根节点的指针和一个 vector,用于存储生成的后缀表达式。

其中:

  • ReversePolish() 函数采用后序遍历算法,先处理左子节点,然后是右子节点,最后才处理根节点。

  • 当 root_node 不为空时,递归调用 ReversePolish() 函数以处理其 left 和 right 子节点,使得 vector 中保存的各个节点的次序即根据后缀表达式的次序排列。

  • 对于根节点,将其推入 vector 的末尾,即实现了后缀表达式的构建。

该函数的主要作用是将一个给定符号表达式语法树转换成后缀表达式,并保存在 vector 中,方便后续基于后缀表达式的计算。

通常在表达式计算场景中,后缀表达式具有更好的可处理性和操作性,因此其转换成后缀表达式后可更快地求值。

ExpressionLayer的定义

我们定义了一个名为 ExpressionLayer 的类,继承自 Layer 类,用于将输入的表达式字符串解析为后缀表达式,并对给定的张量进行基于后缀表达式的计算。

其中:

  • ExpressionLayer 的构造函数接受一个字符串表达式 statement,创建一个 ExpressionParser 对象。

  • ExpressionLayer 的 Forward() 函数是重写 Layer 接口的纯虚函数,输入参数为一个浮点数类型的张量列表 inputs,输出参数为一个浮点数类型的张量列表 outputs。Forward()函数通过 ExpressionParser 对象解析出的后缀表达式计算结果,并更新到 outputs 参数。

  • GetInstance() 是 ExpressionLayer 的一个静态成员函数,用于创建 ExpressionLayer 对象并返回一个指向该对象的智能指针。该函数接受一个 RuntimeOperator 指针对象作为参数,并利用 op->param_attr 中包含的表达式串创建ExpressionLayer 对象,然后返回该对象的智能指针。

ExpressionLayer 类主要作用是将用户输入的表达式字符串解析为后缀表达式,并提供基于后缀表达式的计算功能。

ExpressionLayer的实例化

ExpressionLayer 类中的静态成员函数 GetInstance() 用于从 op 中解析出表达式字符串,创建 ExpressionLayer 对象,并返回 ParseParameterAttrStatus 枚举类型。

  • 首先,检查 RuntimeOperator 对象 op 是否为空,如果为空则抛出异常。

  • 然后,检查 params map 中是否包含 “expr” 字符串作为 key 。如果不包含,则返回 kParameterMissingExpr 枚举。

  • 接下来,检查 params[“expr”] 对象是否为语句字符串类型的运行参数。如果不是,则返回 kParameterMissingExpr 枚举。

  • 最后,通过 std::make_shared 创建 ExpressionLayer 指针对象,传入 statement_param->value 来初始化 ExpressionLayer 类的一个新实例;同时返回 kParameterAttrParseSuccess 枚举,说明解析成功。

此外, LayerRegistererWrapper kExpressionGetInstance("pnnx.Expression", ExpressionLayer::GetInstance)使得 ExpressionLayer::GetInstance() 函数在工厂类中注册绑定到名称为 “pnnx.Expression”,以便根据之前已经被注册的名称获取对应的层实例。

ExpressionLayer的前向传播

ExpressionLayer 类的成员函数 Forward() 用于根据输入张量列表计算输出张量,并返回 InferStatus 枚举类型。

  • 首先,对输入和输出张量进行检查。如果输入张量为空或者输出张量为空,则分别返回 kInferFailedInputEmpty 和 kInferFailedInputOutSizeAdaptingError 两种错误码。

  • 随后,检查解析器对象指针是否为空,表达式解析并获取 tokens,如果解析失败,输出错误。

  • 接下来,检查所有的输入张量是否为非空且一个维度中至少有一个元素,如果条件不满足,则返回 kInferFailedInputEmpty 错误码。

  • 使用 batch_size 获取输出张量数量,并在遍历输出张量时对每个张量进行初始化(赋值为0)。

  • 然后创建一个栈用于保存从张量张量取出来的操作数。接着,构建语法树,之后进入循环,依次遍历 TokenNode 树中的每个节点。

    • 遇到操作数时,获取 num_index 属性(即操作数的索引号),并根据当前的 batch_size 进行计算,从输入张量列表中截取 num_index 所表示的张量,并按照批次号将取出来的同一批次的张量打包成 vector 加入栈中。
    • 当遇到运算符时,需要取两个元素出栈进行计算,并将计算结果再次压入栈中。
    • 在此过程中,为了支持 batch 操作,对于取出的每对操作数也应按批次进行访问,计算结果张量也应作为一个显式的 vector 同时加入输出张量列表 outputs 中。多线程实现可以通过OpenMP加速。
  • 最后,当遍历结束时,此时栈中仅剩下一个元素,即为计算结果。将其出栈,并将输出张量列表中的操作数更新为该元素即可。

ExpressionLayer实现了逆波兰表达式的计算过程,支持二元运算符加、乘,并能够处理 batch 式数据。

阅读的代码

  • include
    • parser
      • parse_expression.hpp
  • source
    • parser
      • parse_expression.cpp
    • layer
      • details
        • expression.hpp
        • expression.cpp

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/609117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PyTorch 深度学习 || 专题二:PyTorch 实验框架的搭建

PyTorch 实验框架的搭建 1. PyTorch简介 PyTorch是由Meta AI(Facebook)人工智能研究小组开发的一种基于Lua编写的Torch库的Python实现的深度学习库,目前被广泛应用于学术界和工业界,PyTorch在API的设计上更加简洁、优雅和易懂。 1.1 PyTorch的发展 “…

Numpy---生成数组的方法、从现有数组中生成、生成固定范围的数组

1. 生成数组的方法 np.ones(shape, dtypeNone, orderC) 创建一个所有元素都为1的多维数组 参数说明: shape : 形状; dtypeNone: 元素类型; order : {‘C’,‘F’},可选,默认值:C 是否在内…

BPMN2.0自动启动模拟流程

思路:BPMN的流程模拟启动,主要是通过生成令牌,并启动令牌模拟 流程模拟的开启需要关键性工具:bpmn-js-token-simulation,需要先行下载 注:BPMN2.0的流程模拟工具版本不同,启动方式也不一样&am…

Kafka某Topic的部分partition无法消费问题

今天同事反馈有个topic出现积压。于是上kfk管理平台查看该topic对应的group。发现6个分区中有2个不消费,另外4个消费也较慢,总体lag在增长。查看服务器日志,日志中有rebalance 12 retry 。。。Exception,之后改消费线程停止。 查…

chatgpt赋能python:Python实现数据匹配的方法

Python实现数据匹配的方法 在数据分析和处理中,经常需要将两组数据进行匹配。Python作为一门强大的编程语言,在数据匹配方面也有着其独特的优势。下面我们将介绍Python实现数据匹配的方法。 数据匹配 数据匹配通常指的是将两组数据根据某些特定的规则…

理解calico容器网络通信方案原理

0. 前言 Calico是k8s中常用的容器解决方案的插件,本文主要介绍BGP模式和IPIP模式是如何解决的,并详细了解其原理,并通过实验加深理解。 1. 介绍Calico Calico是属于纯3层的网络模型,每个容器都通过IP直接通信,中间通…

试验SurfaceFlinger 中Source Crop

在 SurfaceFlinger 中,Source Crop 是用于指定源图像的裁剪区域的一个概念。Source Crop 可以理解为是一个矩形区域,它定义了源图像中要被渲染到目标区域的部分。在 Android 中,Source Crop 通常用于实现屏幕分辨率适应和缩放等功能。 在 Sur…

【Java基础篇】逻辑控制练习题与猜数字游戏

作者简介: 辭七七,目前大一,正在学习C/C,Java,Python等 作者主页: 七七的个人主页 文章收录专栏:Java.SE,本专栏主要讲解运算符,程序逻辑控制,方法的使用&…

2023_Python全栈工程师入门教程目录

2023_Python全栈工程师入门教程 该路线来自慕课课程,侵权则删,支持正版课程,课程地址为:https://class.imooc.com/sale/python2021 学习路线以三个项目推动,一步步夯实技术水平,打好Python开发基石 目录: 1.0 Python基础入门 2.0 Python语法进阶 3.0 Python数据…

windows系统典型漏洞分析

内存结构 缓冲区溢出漏洞 缓冲区溢出漏洞就是在向缓冲区写入数据时,由于没有做边界检查,导致写入缓冲区的数据超过预先分配的边界,从而使溢出数据覆盖在合法数据上而引起系统异常的一种现象。 ESP、EPB ESP:扩展栈指针&#xff08…

React.memo()、userMemo 、 userCallbank的区别及使用

本文是对以下课程的笔记输出,总结的比较简洁,若大家有不理解的地方,可以通过观看课程进行详细学习; React81_React.memo_哔哩哔哩_bilibili React76_useEffect简介_哔哩哔哩_bilibili React136_useMemo_哔哩哔哩_bilibili Rea…

直播录音时准备一副监听耳机,实现所听即所得,丁一号G800S上手

有些朋友在录视频还有开在线会议的时候,都会遇到一个奇怪的问题,就是自己用麦克风收音的时候,自己的耳机和别人的耳机听到的效果不一样,像是音色、清晰度不好,或者是缺少伴奏以及背景音嘈杂等,这时候我们就…

2023贵工程团体程序设计赛

A这是一道数学题&#xff1f; 道路有两边。 #include<bits/stdc.h> using namespace std; int main(){int n,m;cin>>n>>m;cout<<(n/m1)*2;return 0; } BCPA的团体赛 直接输出 。 #include <bits/stdc.h> using i64 long long; #define IOS…

Docker基本管理与网络以及数据管理

目录 一、Docker简介1、Docker简述2、什么是容器3、容器的优点4、Docker的logo及设计宗旨5、Docker与虚拟机的区别6、Docker的2个重要技术7、Docker三大核心概念 二、Docker的安装及管理1、安装Docker2、配置Docker加速器3、Docker镜像相关基础命令①搜索镜像②拉取镜像③查看镜…

Linux 配置Tomcat环境(二)

Linux 配置Tomcat环境 二、配置Tomcat1、创建一个Tomcat文件夹用于存放Tomcat压缩包2、把Tomcat压缩包传入服务器3、解压并启动Tomcat4、CentOS开放8080端口 二、配置Tomcat 1、创建一个Tomcat文件夹用于存放Tomcat压缩包 输入指令 cd /usr/local 进入到 usr/local 输入指令 …

[LsSDK][tool] ls_syscfg_gui2.0

文章目录 一、简介1.工具的目的2. 更新点下个更新 三、配置文件 一、简介 1.工具的目的 ① 可视化选择IO口功能。 ② 自由配置IO支持的功能。 ③ 适用各类MCU&#xff0c;方便移植和开发。 ④ 功能配置和裁剪&#xff08;选项-syscfg-待完成–需要适配keil语法有些麻烦&#…

Node.js: express + MySQL + Vue实现图片上传

前段时间用Node.js: express MySQL Vue element组件做了一个小项目&#xff0c;记录一下图片上传的实现。 将图片存入数据库有两种方法&#xff1a; 1&#xff0c;将图片以二进制流的方式存入数据库&#xff08;数据库搬家容易&#xff0c;比较安全&#xff0c;但数据库空间…

微服务实战项目-学成在线-媒资管理模块(有项目实战实现)

学成在线-媒资管理模块 1 模块需求分析 1.1 模块介绍 媒资管理系统是每个在线教育平台所必须具备的&#xff0c;查阅百度百科对它的定义如下&#xff1a; 媒体资源管理(Media Asset Management&#xff0c;MAM)系统是建立在多媒体、网络、数据库和数字存储等先进技术基础上…

SpringCloud服务接口调用

SpringCloud服务接口调用 OpenFeign 是什么? 能干啥? 两者区别 OpenFeign使用 接口注解 微服务调用接口FeignClient Feign在消费端使用 新建cloud-consumer-feign-order80 导入eureka和openfeign依赖: <dependency><groupId>org.springframework.cloud&l…

Nginx 中的Rewrite讲解

这里写目录标题 常用的Nginx正则表达式locationelocation 分类location 常用的匹配规则location 优先级 总结RewriteRewrite全局变量是什么?rewrite 执行顺序如下&#xff1a;语法格式&#xff1a;rewrite \<regex> \<replacement> [flag];flag标记说明基于域名的…