tensorflow算子注册以及op详解

news2025/1/12 22:55:50

在自定义的算子时,经常遇到一些函数和宏,这里介绍一下常见的函数和宏

REGISTER_OP

首先我们来思考REGISTER_OP的作用是什么?当我们定义一个tensorflow的算子,首先我们需要tensorflow知道这个算子,也就是说我们要把这个算子注册到tensorflow的算子库中。这个注册过程包括3部分的工作:

1. 我们提供注册的算子的基本信息,其中包括:

算子的名字肯定是要提供的,字符串格式

算子的输入输出的名称,输入输出的类型。字符串格式

提供输出的形状,这个取决于输入的形状,我们需要提供一个输出形状的推断方法,这个方法根据输入的形状计算输出的形状,所以这个推断方法必然是一个接口或者函数。

2. 我们提供的信息是抽象的,都是字符串或者函数,还不能用于生成op。好在tensorflow已经提供好了接口,我们只需要调用这个接口,输入我们提供的信息,即可创建一个对象,这个对象使得用户可以生成一个op。这个对象就是opdefbuilder。这也是为什么我们在自定义算子的时候要include很多tensorflow的代码,就是为了把这些接口include进来。

3. 把生成opdefbuilder 注册到tensorflow中,使其能够被调用。

其中第一部分工作,肯定是我们完成,REGISTER_OP 就是为了完成第2部分的工作。我们来看REGISTER_OP 是怎么完成第二部分工作的。

REGISTER_OP的最常用的方法是

REGISTER_OP("ZeroOutKsz")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

上面的字符串ZeroOutKsz, to_zero: int32  以及SetShapeFn 后面的lambda函数 其实就是我们提供的信息,即第一部分工作

我们来看第一行REGISTER_OP("ZeroOutKsz") ,REGISTER_OP是一个宏,定义在tensorflow/core/framework/op.h。这个宏经过几次其他宏的转换,最终是为了创建OpDefBuilderWrapper。OpDefBuilderWrapper的代码如下

namespace register_op {

class OpDefBuilderWrapper {
 public:
  explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
  OpDefBuilderWrapper& Attr(std::string spec) {
    builder_.Attr(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper& Input(std::string spec) {
    builder_.Input(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper& Output(std::string spec) {
    builder_.Output(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper& SetIsCommutative() {
    builder_.SetIsCommutative();
    return *this;
  }
  OpDefBuilderWrapper& SetIsAggregate() {
    builder_.SetIsAggregate();
    return *this;
  }
  OpDefBuilderWrapper& SetIsStateful() {
    builder_.SetIsStateful();
    return *this;
  }
  OpDefBuilderWrapper& SetDoNotOptimize() {
    // We don't have a separate flag to disable optimizations such as constant
    // folding and CSE so we reuse the stateful flag.
    builder_.SetIsStateful();
    return *this;
  }
  OpDefBuilderWrapper& SetAllowsUninitializedInput() {
    builder_.SetAllowsUninitializedInput();
    return *this;
  }
  OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
    builder_.Deprecated(version, std::move(explanation));
    return *this;
  }
  OpDefBuilderWrapper& Doc(std::string text) {
    builder_.Doc(std::move(text));
    return *this;
  }
  OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
    builder_.SetShapeFn(std::move(fn));
    return *this;
  }
  OpDefBuilderWrapper& SetIsDistributedCommunication() {
    builder_.SetIsDistributedCommunication();
    return *this;
  }

  OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
    builder_.SetTypeConstructor(std::move(fn));
    return *this;
  }

  OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) {
    builder_.SetForwardTypeFn(std::move(fn));
    return *this;
  }

  const ::tensorflow::OpDefBuilder& builder() const { return builder_; }

  InitOnStartupMarker operator()();

 private:
  mutable ::tensorflow::OpDefBuilder builder_;
};

}

在构造函数中,入参是字符串name,构造函数创建一个OpDefBuilder类builder_, 然后后面所有的类函数都是直接传给builder_, 然后类函数返回OpDefBuilderWrapper 本身,以方面用上面的链式定义。

主要用到的类函数包括

Attr:attr输入为一个字符串,attr的用法是

REGISTER_OP("opName")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {half, float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .Attr("XlaCompile: bool=true")

 "Attr"是一个允许自定义的值, 比如XLA引擎就根据自身需求提供了"XlaCompile", 如果一个Op将该值设置为true, 就会强制XLA引擎将其编译. 当然, 也可以设置一些无用的值, 就像函数声明里有一个并没有实际使用的参数, 除了浪费存储空间没有其他用途.

Input: 输入为一个字符串,input用于设置算子的输入
REGISTER_OP("ZeroOutKsz")
    .Input("input1: int32")
    .Input("input2: int32)
);
Output: 输入为一个字符串,用于设置算子的输出
REGISTER_OP("ZeroOutKsz")
    .Output("zeroed: int32")
);
SetShapeFn:  输入是一个接口或者函数,用于设置输出的形状,例如下面的例子就是输入了一个lambda 函数为了在创建图的时候就能够实现tensor形状的自洽。改接口经常以shape_inference::InferenceContext 为输入, InferenceContext后面会单独讲

用法是

REGISTER_OP("ZeroOutKsz")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    })
);

从上面的分析可以看到REGISTER_OP 的主要过程就是通过几层宏,生成了一个OpDefBuilderWrapper,而OpDefBuilderWrapper的构造函数和几个主要函数都是为了生成与修改OpDefBuilder。我们光看名字就知道OpDefBuilder 可以用来生成一个op。至此REGISTER_OP 就完成了OpDefBuilder的生成。

那么第3部分工作,即把这个算子注册到tensorflow是什么时候完成的呢?这部分工作是在我们load动态链接库的时候tensorflow 自动完成的,有一个opDefBuilderReceiver 会读取.so中的OpDefBuilder ,然后注册到算子库中,这个过程不需要我们研究的太深入。

InferenceContext

InferenceContext用于在注册op时提前做形状推断时的输入,这里有一个概念非常容易混淆,一定要弄清楚。不同于OpDefBuilder, OpDefBuilderWrapper 这些在算子注册时候的类,InferenceContext是一个在算子调用的时候生成的类。也就是说InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!!重要的事说三遍

InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!

InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!

InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!

这也好理解,SetShapeFn 本身就是一个接口,肯定是有具体数据传进来的时候才有意义。理解了这一点,后面才能理解。

InferenceContext和很多已经定义好的接口都属作用域shape_inference,shape_inference中包含了很多接口和函数,例如已经定义好的shape_inference::MatMulShape,这些也包含了形状推断用到的所有信息。要想更好的了解InferenceContext 必须要要了解shape_inference 中的一些对象。

shape_inference 的具体代码可以见:tensorflow/core/framework/shape_inference.h, 其中包含6个主要对象: shapehandle,Shape,DimenionHandle,Dimension ; InferenceContext, shapemanager。

其中shapehandle,Shape,DimenionHandle,Dimension 四个对象的包含关系如下:

ShapeHandle:只有一个主要的属性就ptr,就是 Shape的指针

Shape 有两个主要属性,rank和dims, rank 是int类型表示tesnor的有多少个维度,dim表示类型是vector,表示这三个维度的宽度对象DimensionHandle

DimensionHandle 也只有一个主要属性,ptr,是指向Dimension的指针

Dimension : 则只有一个主要属性是int,表示宽度。

这样层层嵌套的关系有点像tensorflow的中的featurecolumn的结构。

举例来说  一个tensor的维度是[3,4,5] 那么 这个tensor的shape类的rank 就是3, DimensionHandle 本质就是三个指针,这三个指针指向的值就是3,4,5。而整个shape的指针就是shapehandle。

上面铺垫完了,正式讲InferenceContext,InferenceContext是一个对象,构造函数中,最重要的输入是

const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,

其中input_shapes表示输入的形状,input_tensors 表示输入的tensor。

最重要的几个属性如下

  std::vector<ShapeHandle> inputs_;
  std::vector<const Tensor*> input_tensors_;
  std::vector<bool> requested_input_tensor_;
  std::vector<ShapeHandle> outputs_;
  // Can have fewer elements than inputs_.
  std::vector<ShapeHandle> input_tensors_as_shapes_;
  std::vector<bool> requested_input_tensor_as_partial_shape_;


inputs_和outputs_ 是最核心的两个属性,inputs_在构造函数中被赋值,就是上面传入的input_shapes。 outputs_一开始为空,整个SetShapeFn 就是为了给outputs_  赋值,赋值以后形状推断即结束,outputs_ 就是输出的形状。这个形状会流向下一个节点,由下一个节点判断形状是否有问题。

算子生成和算子注册

我们自定义一个算子的本质是为了,利用这个算子生成一个op,所以必须要知道我们自定义的算子是怎么生成op的。

对于前面的过程我们已经理清楚了,过程是这样的:

调用REGISTER_OP 生成一个 OpDefBuilderWrapper 类,给OpDefBuilderWrapper 传入我们算子的信息:函数名,输入输出名称、格式,形状推断方法。OpDefBuilderWrapper 回生成一个OpDefBuilder,在load动态链接库时自动注册到tensorflow算子库中。

然后用户在Python侧调用这个函数的时候,就等于在调用相应的OpDefBuilder,这里需要注意的是,OpDefBuilder不是简单地生成一个OpDef,而是会生成一个结构体OpRegistrationData, 这个结构体包括两个成员:OpDef, OpShapeInferenceFn,其中OpShapeInferenceFn 是一个以InferenceContext 为输入的函数。用户会输入具体的tensor,输入的内容传给OpDef, OpShapeInferenceFn 完成op的定义以及形状的推断

借用 

HaoBBNuanMMicon-default.png?t=MBR7https://blog.csdn.net/HaoBBNuanMM

的一张图片就是 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156561.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WeLink的使用

我这里是注册的企业端 流程>手机号验证码 注册成功后登陆 进入首页面 按操作逐步完成信息需求 因个体使用情况不同 在角色分类和组织架构中可根据自己部门或单位的分工分类 【拉人】&#xff1a; 三种方式 主要就是网址超链接和企业码 前提需要用户先注册 【加入审核】是根…

Nginx——反向代理解决跨域问题(Windows)

这个破玩意是真麻烦&#xff0c;必须写一篇文章避避坑了。一、先看看大佬的解释&#xff0c;了解反向代理和跨域问题吧&#xff1a;Nginx反向代理什么是跨域问题二、OK&#xff0c;直接开工&#xff0c;装Nginx下载地址: http://nginx.org/en/download.html如图所示, 选择相应的…

Flink多流转换(Flink Stream Unoin、Flink Stream Connect、Flink Stream Window Join)

文章目录多流转换1、分流操作1.1、在flink 1.13版本中已弃用.split()进行分流1.2、使用&#xff08;process function&#xff09;的侧输出流&#xff08;side output&#xff09;进行分流2、基本合流操作2.1、联合&#xff08;Flink Stream Union&#xff09;2.2、连接&#x…

【Go】实操使用go连接clickhouse

前言 近段时间业务在一个局点测试clickhouse&#xff0c;用java写的代码在环境上一直连接不上clickhouse服务&#xff0c;报错信息也比较奇怪&#xff0c;No client available&#xff0c;研发查了一段时间没查出来&#xff0c;让运维这边继续查&#xff1a; 运维同学查了各种…

OAuth 2.0授权框架详解

简介 在现代的网站中&#xff0c;我们经常会遇到使用OAuth授权的情况&#xff0c;比如有一个比较小众的网站&#xff0c;需要用户登录&#xff0c;但是直接让用户注册就显得非常麻烦&#xff0c;用户可能因为这个原因而流失&#xff0c;那么该网站可以使用OAuth授权&#xff0…

FactoryBean和BeanFactory的区别

1. 前言 “BeanFactory和FactoryBean的区别是什么&#xff1f;&#xff1f;&#xff1f;” 这是Spring非常高频的一道面试题&#xff0c;BeanFactory是Spring bean容器的顶级接口&#xff0c;负责创建和维护容器内所有的bean对象。而FactoryBean是用来创建一类bean的接口&…

数字新基建之数据云

自2021年“新基建”概念火爆以来&#xff0c;相关的政策和技术都不断跟进和发展&#xff0c;由于“新基建”本质上是基础设施向数字化、智能化、网络化方向发展&#xff0c;因此更多的科技领域从业者和投资者都将其称为“数字新基建”。而数据库、数据仓库、大数据平台和数据云…

C语言:整数的存储方式

整数的存储方式 char类型在存储时是按照ASCII码值进行存储&#xff0c;存储方式与整型一致 有符号数与无符号数 char一个字节signed charunsigned char int四个字节signed intunsigned int 各种类型数据均分为有符号和无符号类型&#xff0c;当定义一个int类型或char类型的数…

备库为什么会延迟好几个小时?

在上一篇文章中,我和你介绍了几种可能导致备库延迟的原因。你会发现,这些场景里,不论是偶发性的查询压力,还是备份,对备库延迟的影响一般是分钟级的,而且在备库恢复正常以后都能够追上来。 但是,如果备库执行日志的速度持续低于主库生成日志的速度,那这个延迟就有可能…

百度搜索留痕推广资源整理如何收录排名的?

每日分享&#xff1a;百度对图文类内容的优质标准 &#xff08;1&#xff09;文字的字体、字号与间距需要适配网页&#xff0c;文档分段合理&#xff0c;结构有序&#xff0c;阅读体验舒适。 &#xff08;2&#xff09;在文章中使用小标题准确概括段意&#xff0c;通过加粗、…

vue3 setup语法糖父子组件传值,让女友看得明明白白

前言 最近在想做个cloud项目,gitee上找了个模板项目&#xff0c;前端使用到vue3 typeScript&#xff0c;最近使用到vue3 的父子组件之间的传值&#xff0c;顺便学习一下&#xff0c;在此总结一下&#xff0c;若有不足之处&#xff0c;望大佬们可以指出。 vue3官网&#xff1a…

栈--专题讲解

文章目录基本概念模拟栈数据结构-栈&#xff1a;stack头文件定义基本操作实例&#xff1a;火车进栈题目大意解题思路AC代码基本概念 栈的定义 栈(stack)是限定仅在表尾进行插入或者删除的线性表。对于栈来说&#xff0c;表尾端称为栈顶&#xff08;top&#xff09;&#xff0c…

web服务器----基于http协议搭建的静态网站详解

一&#xff0c;WWW的简介 1、什么是 www www 是 world wide web 的缩写&#xff0c;也就是全球信息广播的意思。通常说的上网就是使用 www 来查询用户所需要的信息。www 可以结合文字、图形、影像以及声音等多媒体&#xff0c;并通过可以让鼠标单击超链接的方式将信息以 Inter…

Docker容器搭建及基本使用

一、安装环境 操作系统&#xff1a;CentOS 7&#xff08;建议用7或以上&#xff0c;因为6版本有部分功能不兼容&#xff09; 二、Docker安装 1、卸载旧版本 yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrota…

linux修改密码报错‘Authentication token manipulation error‘

本次事故使用操作系统为centos7 1、报错起因&#xff1a; 利用chage设置root用户密码定期更换后&#xff0c;到期之后登录系统&#xff0c;输入密码之后&#xff0c;提示要改密码&#xff0c;输入新密码之后&#xff0c;报错 ‘Authentication token manipulation error’ &a…

【k8s系列】gvisor安装与containerd集成

文章目录安装与containerd集成下发runtimeclass资源修改containerd配置文件准备pod的yaml文件参考资料author: ningan123date: ‘2023-01-11 21:23’updated: ‘2023-01-11 21:31’安装 安装地址&#xff1a;Installation - gVisor ARCH$(uname -m)URLhttps://storage.googlea…

Gotify <2.2.3 存在反射型 XSS 漏洞(MPS-2023-0815)

漏洞描述 Gotify 是 Go 语言开发的开源组件&#xff0c;用作于发送和接收消息的服务器。 由于 2.2.3 之前版本的 Gotify 使用具有反射型 XSS 漏洞版本的 swagger-ui 生成文档&#xff0c;当用户访问 Gotify /docs 页面时存在反射型 XSS 漏洞。 攻击者可诱导 Gotify 用户点击…

【学习笔记】【Pytorch】四、torchvision中的数据集使用

【学习笔记】【Pytorch】四、torchvision中的数据集使用学习地址主要内容一、datasets模块介绍二、datasets.CIFAR10类的使用1.使用说明2.代码实现学习地址 PyTorch深度学习快速入门教程【小土堆】. 主要内容 一、datasets模块介绍 介绍&#xff1a;一些加载数据的函数及常用…

P6:DataLoader的使用

1、准备数据集&#xff08;测试集&#xff09; import torchvisiontest_data torchvision.datasets.CIFAR10(./dataset, trainFalse, transformtorchvision.transforms.ToTensor()) 注意数据集中的图片是PIL的格式&#xff0c;需要格式转换。 2、使用DataLoader from torch…

HBase数据库总结(一)

1、 HBase的特点是什么&#xff1f;HBase是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统&#xff0c;HBase不同于一般的关系数据库&#xff0c;它是一个适合于非结构化数据存储的数据库。1&#xff09;大&#xff1a;一个表可以有数十亿行&#xff0c;上百万列2&…