InternalRow体系
学习TreeNode之前,我们先了解下InternalRow。
对于我们一般接触到的数据库关系表来说,我们对于数据库中的数据操作都是按照“行”为单位的。在spark sql内部实现中,InternalRow是用来表示这一行行数据的类。看下源码中的解释,InternalRow作为一个抽象类,包numFields 和 update 方法,以及各列数据对应的 get 与 set 方法,但具体的实现逻辑体现在不同的子类中
/**
* An abstract class for row used internally in Spark SQL, which only contains the columns as
* internal types.
一个抽象类,用于表示spark SQL内部行,只包含内部类型的多个列(其实就是表示一行行数据的类)
*/
详细代码这里就不贴了,整理下一些重要接口的功能含义好了,注意InternalRow中都是根据下标来访问和操作列元素的 。
InternalRow实现类包括,BaseGenericinternalRow、UnsafeRow 和 JoinedRow 3 个直接子类
- BaseGenericinternalRow:也是一个抽象类,实现了SpecializedGetters类中定义的所有GET方法,但是最终还是调用genericGet方法实现最终逻辑,genericGet方法在BaseGenericinternalRow内中只是定义了一个接口,最终实现在BaseGenericinternalRow的子类中。
- JoinedRow:该类主要用于join操作,两个InternalRow放在一起形成新的InternalRow,在sparksql 聚合和join相关操作中,会用的比较多
- UnsafeRow:不采用 Java 对象存储的方式,避免了 JVM 中垃圾回收( GC )的代价 。 此外,UnsafeRow 对行数据进行了特定的编码,使得存储更加高效 。
TreeNode体系
接下来正式开始进行TreeNode的学习
TreeNode是Spark SQL中所有树结构的基类,定义了一系列通用的集合操作和树遍历的操作接口。我们先看下TreeNode的代码
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with TreePatternBits {
}
首先TreeNode是一个抽象类,一个泛型类;这里TreeNode[BaseType <: TreeNode[BaseType]]这种书写方式,不知道大家会不会很陌生,反正我一开始看的时候,觉得不知道咋回事,那么我们来一起理解写,这个具体是什么含义:
- 首先,我们很明确这个TreeNode是个泛型,我们把[]中的看作一个T,其实就是TreeNode[T],这个没问题
- 接下里,我们要理解下“<:”这个符号的含义,这属于scala泛型中的知识,上边界和下边界。上边界是“<:”,下边界是“>:”;上边界,拿代码中的定义的含义解释就是BaseType必须是TreeNode[BaseType]的子类。也就是说TreeNode的泛型类型用BaseType表示,泛型类型比如是TreeNode类的子类
另外,TreeNode还继承了Product接口,对于product接口相关使用介绍,请看这篇文章(scala之product特质理解_大家都叫我船长的博客-CSDN博客),看完应该就明白了。
接下来,开始详细看看TreeNode一些重要方法:
- 返回子节点,只定义了接口,具体实现在之类中
/**
* Returns a Seq of the children of this node.
* Children should not change. Immutability required for containsChild optimization
*/
def children: Seq[BaseType]
- 返回子节点的set集合
lazy val containsChild: Set[TreeNode[_]] = children.toSet
- 比较两个TreeNode是否相等
def fastEquals(other: TreeNode[_]): Boolean = {
this.eq(other) || this == other
}
- 查找第一个符合f条件的TreeNode
def find(f: BaseType => Boolean): Option[BaseType] = if (f(this)) {
Some(this)
} else {
children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
}
- 将函数f 递归 应用于TreeNode节点以及所有子节点(先应用于parent,后应用child)
def foreach(f: BaseType => Unit): Unit = {
f(this)
children.foreach(_.foreach(f))
}
- 函数f 递归 应用于TreeNode节点以及所有子节点(先应用于child,后应用parent)
def foreachUp(f: BaseType => Unit): Unit = {
children.foreach(_.foreachUp(f))
f(this)
}
- 通过前序遍历的方式,将函数f递归应用于当前节点以及所有子节点,返回seq
def map[A](f: BaseType => A): Seq[A] = {
val ret = new collection.mutable.ArrayBuffer[A]()
foreach(ret += f(_))
ret.toSeq
}
- flatmap和上面的map整体一致,但是这里的函数f的返回值必须是集合类型,这里需要注意
def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = {
val ret = new collection.mutable.ArrayBuffer[A]()
foreach(ret ++= f(_)) //f返回的结果必须是一个集合
ret.toSeq
}
- collect ,这里使用到了scala的偏函数的使用(可以参考scala之偏函数学习_大家都叫我船长的博客-CSDN博客),对pf函数作用的所有节点返回为Some(B)都add到ret集合中,最终以seq的形式返回
def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = {
val ret = new collection.mutable.ArrayBuffer[B]()
val lifted = pf.lift
foreach(node => lifted(node).foreach(ret.+=))
ret.toSeq
}
- 返回当前节点的所有子节点
def collectLeaves(): Seq[BaseType] = {
this.collect { case p if p.children.isEmpty => p }
}
- 先序的方式访问所有节点,且返回第一个pf作用后结果不为None的节点
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {
val lifted = pf.lift
lifted(this).orElse {
children.foldLeft(Option.empty[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }
}
}
-
mapProductIterator其实功能和productIterator.map(f).toArray一致
protected def mapProductIterator[B: ClassTag](f: Any => B): Array[B] = {
val arr = Array.ofDim[B](productArity)
var i = 0
while (i < arr.length) {
arr(i) = f(productElement(i))
i += 1
}
arr
}
- 将当前节点的子节点替换为新的子节点
inal def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
val childrenIndexedSeq = asIndexedSeq(children)
val newChildrenIndexedSeq = asIndexedSeq(newChildren)
assert(newChildrenIndexedSeq.size == childrenIndexedSeq.size, "Incorrect number of children")
if (childrenIndexedSeq.isEmpty ||
childrenFastEquals(newChildrenIndexedSeq, childrenIndexedSeq)) {
this
} else {
CurrentOrigin.withOrigin(origin) {
val res = withNewChildrenInternal(newChildrenIndexedSeq)
res.copyTagsFrom(this)
res
}
}
}
- transfrom,调用transformDown,传入一个rule偏函数
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {
transformDown(rule)
}
- transformDown,调用transformDownWithPruning,先序的方式使用rule作用于每个子节点,使用新的节点替换之前的,对节点不影响的,保留原来的节点
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
transformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}
def transformDownWithPruning(cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
: BaseType = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return this
}
val afterRule = CurrentOrigin.withOrigin(origin) {
// 如果 this 是 BaseType 或其子类,则对 this 应用 rule 再返回应用 rule 后的结果,否则返回 this
rule.applyOrElse(this, identity[BaseType])
}
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
// 如果应用了 rule 后节点无变化,则递归将 rule 应用于 children
val rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
if (this eq rewritten_plan) {
markRuleAsIneffective(ruleId)
this
} else {
rewritten_plan
}
} else {
// If the transform function replaces this node with a new one, carry over the tags.
// 如果应用了 rule 后节点有变化,则本节点换成变化后的节点(children 不变),再将 rule 递归应用于子节点。也就是从根节点往下来应用 rule 替换节点
afterRule.copyTagsFrom(this)
afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
}
}
-
transformWithPruning,底层调用transformDownWithPruning(功能是返回此节点的副本,其中“规则”已递归应用于树。当“规则”不适用于给定节点时,它将保持不变)
def transformWithPruning(cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
: BaseType = {
transformDownWithPruning(cond, ruleId)(rule)
}
def transformDownWithPruning(cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
: BaseType = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return this
}
val afterRule = CurrentOrigin.withOrigin(origin) {
// 如果 this 是 BaseType 或其子类,则对 this 应用 rule 再返回应用 rule 后的结果,否则返回 this
rule.applyOrElse(this, identity[BaseType])
}
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
// 如果应用了 rule 后节点无变化,则递归将 rule 应用于 children
val rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
if (this eq rewritten_plan) {
markRuleAsIneffective(ruleId)
this
} else {
rewritten_plan
}
} else {
// If the transform function replaces this node with a new one, carry over the tags.
// 如果应用了 rule 后节点有变化,则本节点换成变化后的节点(children 不变),再将 rule 递归应用于子节点。也就是从根节点往下来应用 rule 替换节点
afterRule.copyTagsFrom(this)
afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
}
}
-
transformUp 用后序遍历方式将规则作用于所有节点,调用transformUpWithPruning
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
transformUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}
def transformUpWithPruning(cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
: BaseType = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
return this
}
val afterRuleOnChildren = mapChildren(_.transformUpWithPruning(cond, ruleId)(rule))
val newNode = if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
} else {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
}
}
if (this eq newNode) {
markRuleAsIneffective(ruleId)
this
} else {
// If the transform function replaces this node with a new one, carry over the tags.
newNode.copyTagsFrom(this)
newNode
}
}
- mapChildren 返回
f
应用于所有子节点后该节点的 copy。
def mapChildren(f: BaseType => BaseType): BaseType = {
if (containsChild.nonEmpty) {
withNewChildren(children.map(f))
} else {
this
}
}
上面罗列的方法,基本就是TreeNode常用的,还有一些不常用的非核心的,这里就不一一介绍了,大家有兴趣的可以自己去看看源码。
另外TreeNode有两个子类,分别是Expression和QueryPlan,这篇文章我们就先讲到这里,后面会对这两个子类也会进行一一介绍的。
有兴趣的可以关注我,后面一起学习sparkSql源码,另外文章中有错误的地方,感谢指出哈。