在自定义的算子时,经常遇到一些函数和宏,这里介绍一下常见的函数和宏
REGISTER_OP
首先我们来思考REGISTER_OP的作用是什么?当我们定义一个tensorflow的算子,首先我们需要tensorflow知道这个算子,也就是说我们要把这个算子注册到tensorflow的算子库中。这个注册过程包括3部分的工作:
1. 我们提供注册的算子的基本信息,其中包括:
算子的名字肯定是要提供的,字符串格式
算子的输入输出的名称,输入输出的类型。字符串格式
提供输出的形状,这个取决于输入的形状,我们需要提供一个输出形状的推断方法,这个方法根据输入的形状计算输出的形状,所以这个推断方法必然是一个接口或者函数。
2. 我们提供的信息是抽象的,都是字符串或者函数,还不能用于生成op。好在tensorflow已经提供好了接口,我们只需要调用这个接口,输入我们提供的信息,即可创建一个对象,这个对象使得用户可以生成一个op。这个对象就是opdefbuilder。这也是为什么我们在自定义算子的时候要include很多tensorflow的代码,就是为了把这些接口include进来。
3. 把生成opdefbuilder 注册到tensorflow中,使其能够被调用。
其中第一部分工作,肯定是我们完成,REGISTER_OP 就是为了完成第2部分的工作。我们来看REGISTER_OP 是怎么完成第二部分工作的。
REGISTER_OP的最常用的方法是
REGISTER_OP("ZeroOutKsz")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
上面的字符串ZeroOutKsz, to_zero: int32 以及SetShapeFn 后面的lambda函数 其实就是我们提供的信息,即第一部分工作
我们来看第一行REGISTER_OP("ZeroOutKsz") ,REGISTER_OP是一个宏,定义在tensorflow/core/framework/op.h。这个宏经过几次其他宏的转换,最终是为了创建OpDefBuilderWrapper。OpDefBuilderWrapper的代码如下
namespace register_op {
class OpDefBuilderWrapper {
public:
explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
OpDefBuilderWrapper& Attr(std::string spec) {
builder_.Attr(std::move(spec));
return *this;
}
OpDefBuilderWrapper& Input(std::string spec) {
builder_.Input(std::move(spec));
return *this;
}
OpDefBuilderWrapper& Output(std::string spec) {
builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper& SetIsCommutative() {
builder_.SetIsCommutative();
return *this;
}
OpDefBuilderWrapper& SetIsAggregate() {
builder_.SetIsAggregate();
return *this;
}
OpDefBuilderWrapper& SetIsStateful() {
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper& SetDoNotOptimize() {
// We don't have a separate flag to disable optimizations such as constant
// folding and CSE so we reuse the stateful flag.
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper& SetAllowsUninitializedInput() {
builder_.SetAllowsUninitializedInput();
return *this;
}
OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
builder_.Deprecated(version, std::move(explanation));
return *this;
}
OpDefBuilderWrapper& Doc(std::string text) {
builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
builder_.SetShapeFn(std::move(fn));
return *this;
}
OpDefBuilderWrapper& SetIsDistributedCommunication() {
builder_.SetIsDistributedCommunication();
return *this;
}
OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
builder_.SetTypeConstructor(std::move(fn));
return *this;
}
OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) {
builder_.SetForwardTypeFn(std::move(fn));
return *this;
}
const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
InitOnStartupMarker operator()();
private:
mutable ::tensorflow::OpDefBuilder builder_;
};
}
在构造函数中,入参是字符串name,构造函数创建一个OpDefBuilder类builder_, 然后后面所有的类函数都是直接传给builder_, 然后类函数返回OpDefBuilderWrapper 本身,以方面用上面的链式定义。
主要用到的类函数包括
Attr:attr输入为一个字符串,attr的用法是
REGISTER_OP("opName")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.Attr("XlaCompile: bool=true")
"Attr"
是一个允许自定义的值, 比如XLA引擎就根据自身需求提供了"XlaCompile"
, 如果一个Op将该值设置为true, 就会强制XLA引擎将其编译. 当然, 也可以设置一些无用的值, 就像函数声明里有一个并没有实际使用的参数, 除了浪费存储空间没有其他用途.
Input: 输入为一个字符串,input用于设置算子的输入
REGISTER_OP("ZeroOutKsz")
.Input("input1: int32")
.Input("input2: int32)
);
Output: 输入为一个字符串,用于设置算子的输出
REGISTER_OP("ZeroOutKsz")
.Output("zeroed: int32")
);
SetShapeFn: 输入是一个接口或者函数,用于设置输出的形状,例如下面的例子就是输入了一个lambda 函数为了在创建图的时候就能够实现tensor形状的自洽。改接口经常以shape_inference::InferenceContext 为输入, InferenceContext后面会单独讲
用法是
REGISTER_OP("ZeroOutKsz")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
);
从上面的分析可以看到REGISTER_OP 的主要过程就是通过几层宏,生成了一个OpDefBuilderWrapper,而OpDefBuilderWrapper的构造函数和几个主要函数都是为了生成与修改OpDefBuilder。我们光看名字就知道OpDefBuilder 可以用来生成一个op。至此REGISTER_OP 就完成了OpDefBuilder的生成。
那么第3部分工作,即把这个算子注册到tensorflow是什么时候完成的呢?这部分工作是在我们load动态链接库的时候tensorflow 自动完成的,有一个opDefBuilderReceiver 会读取.so中的OpDefBuilder ,然后注册到算子库中,这个过程不需要我们研究的太深入。
InferenceContext
InferenceContext用于在注册op时提前做形状推断时的输入,这里有一个概念非常容易混淆,一定要弄清楚。不同于OpDefBuilder, OpDefBuilderWrapper 这些在算子注册时候的类,InferenceContext是一个在算子调用的时候生成的类。也就是说InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!!重要的事说三遍
InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!
InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!
InferenceContext 不在算子注册的时候生成,而是用户在调用OpDefBuilder 的时候生成!
这也好理解,SetShapeFn 本身就是一个接口,肯定是有具体数据传进来的时候才有意义。理解了这一点,后面才能理解。
InferenceContext和很多已经定义好的接口都属作用域shape_inference,shape_inference中包含了很多接口和函数,例如已经定义好的shape_inference::MatMulShape,这些也包含了形状推断用到的所有信息。要想更好的了解InferenceContext 必须要要了解shape_inference 中的一些对象。
shape_inference 的具体代码可以见:tensorflow/core/framework/shape_inference.h, 其中包含6个主要对象: shapehandle,Shape,DimenionHandle,Dimension ; InferenceContext, shapemanager。
其中shapehandle,Shape,DimenionHandle,Dimension 四个对象的包含关系如下:
ShapeHandle:只有一个主要的属性就ptr,就是 Shape的指针
Shape 有两个主要属性,rank和dims, rank 是int类型表示tesnor的有多少个维度,dim表示类型是vector,表示这三个维度的宽度对象DimensionHandle
DimensionHandle 也只有一个主要属性,ptr,是指向Dimension的指针
Dimension : 则只有一个主要属性是int,表示宽度。
这样层层嵌套的关系有点像tensorflow的中的featurecolumn的结构。
举例来说 一个tensor的维度是[3,4,5] 那么 这个tensor的shape类的rank 就是3, DimensionHandle 本质就是三个指针,这三个指针指向的值就是3,4,5。而整个shape的指针就是shapehandle。
上面铺垫完了,正式讲InferenceContext,InferenceContext是一个对象,构造函数中,最重要的输入是
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
其中input_shapes表示输入的形状,input_tensors 表示输入的tensor。
最重要的几个属性如下
std::vector<ShapeHandle> inputs_;
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
std::vector<ShapeHandle> outputs_;
// Can have fewer elements than inputs_.
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
inputs_和outputs_ 是最核心的两个属性,inputs_在构造函数中被赋值,就是上面传入的input_shapes。 outputs_一开始为空,整个SetShapeFn 就是为了给outputs_ 赋值,赋值以后形状推断即结束,outputs_ 就是输出的形状。这个形状会流向下一个节点,由下一个节点判断形状是否有问题。
算子生成和算子注册
我们自定义一个算子的本质是为了,利用这个算子生成一个op,所以必须要知道我们自定义的算子是怎么生成op的。
对于前面的过程我们已经理清楚了,过程是这样的:
调用REGISTER_OP 生成一个 OpDefBuilderWrapper 类,给OpDefBuilderWrapper 传入我们算子的信息:函数名,输入输出名称、格式,形状推断方法。OpDefBuilderWrapper 回生成一个OpDefBuilder,在load动态链接库时自动注册到tensorflow算子库中。
然后用户在Python侧调用这个函数的时候,就等于在调用相应的OpDefBuilder,这里需要注意的是,OpDefBuilder不是简单地生成一个OpDef,而是会生成一个结构体OpRegistrationData, 这个结构体包括两个成员:OpDef, OpShapeInferenceFn,其中OpShapeInferenceFn 是一个以InferenceContext 为输入的函数。用户会输入具体的tensor,输入的内容传给OpDef, OpShapeInferenceFn 完成op的定义以及形状的推断
借用
HaoBBNuanMMhttps://blog.csdn.net/HaoBBNuanMM
的一张图片就是