Spark Sql 4/5

news2025/1/16 9:03:12

4. 用户自定义函数

通过spark.udf功能用户可以自定义函数。

4.1用户自定义UDF函数

Shell
scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

scala> df.show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+


scala> spark.udf.register("addName", (x:String)=> "Name:"+x)
res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("Select addName(name), age from people").show()
+-----------------+----+
|UDF:addName(name)| age|
+-----------------+----+
|     Name:Michael|null|
|        Name:Andy|  30|
|      Name:Justin|  19|
+-----------------+----+

UDF案例2

需求,有如下数据

Plain Text
id,name,age,height,weight,yanzhi,score
1,a,18,172,120,98,68.8
2,b,28,175,120,97,68.8
3,c,30,180,130,94,88.8
4,d,18,168,110,98,68.8
5,e,26,165,120,98,68.8
6,f,27,182,135,95,89.8
7,g,19,171,122,99,68.8

需要计算每一个人和其他人之间的余弦相似度(特征向量之间的余弦相似度)

 

代码实现:

Scala
package cn.doitedu.sparksql.udf

import cn.doitedu.sparksql.dataframe.SparkUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction

import scala.collection.mutable


/**
  * UDF 案例2 : 用一个自定义函数实现两个向量之间的余弦相似度计算
  */

case class Human(id: Int, name: String, features: Array[Double])

object CosinSimilarity {

  def main(args: Array[String]): Unit = {


    val spark = SparkUtil.getSpark()
    import spark.implicits._
    import spark.sql
    // 加载用户特征数据
    val df = spark.read.option("inferSchema", true).option("header", true).csv("data/features.csv")
    df.show()



    // id,name,age,height,weight,yanzhi,score
    // 将用户特征数据组成一个向量(数组)
    // 方式1:
    df.rdd.map(row => {
      val id = row.getAs[Int]("id")
      val name = row.getAs[String]("name")
      val age = row.getAs[Double]("age")
      val height = row.getAs[Double]("height")
      val weight = row.getAs[Double]("weight")
      val yanzhi = row.getAs[Double]("yanzhi")
      val score = row.getAs[Double]("score")

      (id, name, Array(age, height, weight, yanzhi, score))
    }).toDF("id", "name", "features")

    // 方式2:
    df.rdd.map({
      case Row(id: Int, name: String, age: Double, height: Double, weight: Double, yanzhi: Double, score: Double)
      => (id, name, Array(age, height, weight, yanzhi, score))
    })
      .toDF("id", "name", "features")


    // 方式3: 直接利用sql中的函数array来生成一个数组
    df.selectExpr("id", "name", "array(age,height,weight,yanzhi,score) as features")
    import org.apache.spark.sql.functions._
    df.select('id, 'name, array('age, 'height, 'weight, 'yanzhi, 'score) as "features")

    // 方式4:返回case class
    val features = df.rdd.map({
      case Row(id: Int, name: String, age: Double, height: Double, weight: Double, yanzhi: Double, score: Double)
      => Human(id, name, Array(age, height, weight, yanzhi, score))
    })
      .toDF()

    // 将表自己和自己join,得到每个人和其他所有人的连接行
    val joined = features.join(features.toDF("bid","bname","bfeatures"),'id < 'bid)
    joined.show(100,false)

    // 定义一个计算余弦相似度的函数
    // val cosinSim = (f1:Array[Double],f2:Array[Double])=>{ /* 余弦相似度 */ }
    // 开根号的api:  Math.pow(4.0,0.5)
    val cosinSim = (f1:mutable.WrappedArray[Double], f2:mutable.WrappedArray[Double])=>{

      val fenmu1 = Math.pow(f1.map(Math.pow(_,2)).sum,0.5)
      val fenmu2 = Math.pow(f2.map(Math.pow(_,2)).sum,0.5)

      val fenzi = f1.zip(f2).map(tp=>tp._1*tp._2).sum

      fenzi/(fenmu1*fenmu2)
    }

    // 注册到sql引擎:  spark.udf.register("cosin_sim",consinSim)
    spark.udf.register("cos_sim",cosinSim)
    joined.createTempView("temp")

    // 然后在这个表上计算两人之间的余弦相似度
    sql("select id,bid,cos_sim(features,bfeatures) as cos_similary from temp").show()

    // 可以自定义函数简单包装一下,就成为一个能生成column结果的dsl风格函数了
    val cossim2: UserDefinedFunction = udf(cosinSim)
    joined.select('id,'bid,cossim2('features,'bfeatures) as "cos_sim").show()

    spark.close()
  }
}

4.2用户自定义聚合函数UDAF

弱类型的DataFrame和强类型的Dataset都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。

除此之外,用户可以设定自己的自定义UDAF聚合函数。

UDAF的编程模板:

/**

 * @date: 2019/10/12

 * @site: www.doitedu.cn

 * @author: hunter.d 涛哥

 * @qq: 657270652

 * @description:

  *    用户自定义UDAF入门示例:求薪资的平均值

 */

object MyAvgUDAF extends UserDefinedAggregateFunction{

  // 函数输入的字段schema(字段名-字段类型)

  override def inputSchema: StructType = ???

  // 聚合过程中,用于存储局部聚合结果的schema

  // 比如求平均薪资,中间缓存(局部数据薪资总和,局部数据人数总和)

  override def bufferSchema: StructType = ???

  // 函数的最终返回结果数据类型

  override def dataType: DataType = ???

  // 你这个函数是否是稳定一致的?(对一组相同的输入,永远返回相同的结果),只要是确定的,就写true

  override def deterministic: Boolean = true

  // 对局部聚合缓存的初始化方法

  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  // 聚合逻辑所在方法,框架会不断地传入一个新的输入row,来更新你的聚合缓存数据

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  // 全局聚合:将多个局部缓存中的数据,聚合成一个缓存

  // 比如:薪资和薪资累加,人数和人数累加

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  // 最终输出

  // 比如:从全局缓存中取薪资总和/人数总和

  override def evaluate(buffer: Row): Any = ???

}

核心要义:

聚合是分步骤进行: 先局部聚合,再全局聚合

局部聚合(update)的结果是保存在一个局部buffer中的

全局聚合(merge)就是将多个局部buffer再聚合成一个buffer

最后通过evaluate将全局聚合的buffer中的数据做一个运算得出你要的结果

如下图所示:

 

4.2.1弱类型用户自定义聚合函数UDAF

(1)需求说明

示例数据:

+---+----------------+------+---------+------+----------+

| id|    name        | sales|discount |state |  saleDate|

+---+----------------+------+---------+------+----------+

|  1|       Widget Co|1000.0|      0.0|    AZ|2014-01-01|

|  2|   Acme Widgets |2000.0|    500.0|    CA|2014-02-01|

|  3|        Widgetry|1000.0|    200.0|    CA|2015-01-11|

|  4|   Widgets R Us |2000.0|      0.0|    CA|2015-02-19|

|  5|Ye Olde Widgete |3000.0|      0.0|    MA|2015-02-28|

+---+---------------+------+--------+-----+-------------+

需求:计算x年份的同比上一年份的总销售增长率;比如2015 vs 2014的同比增长

显然,没有任何一个内置聚合函数可以完成上述需求;

可以多写一些sql逻辑来实现,但如果能自定义一个聚合函数,当然更方便高效!

Select yearOnyear(saleDate,sales)  from t

(2)自定义UDAF实现销售额同比计算

通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。

自定义UDAF的代码骨架如下:

class UdfMy extends UserDefinedAggregateFunction{

  override def inputSchema: StructType = ???

  override def bufferSchema: StructType = ???

  override def dataType: DataType = ???

  override def deterministic: Boolean = ???

  override def initialize(buffer: MutableAggregationBuffer): Unit = ???

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???

  override def evaluate(buffer: Row): Any = ???

}

完整实现代码如下:

/**

  * 工具类

  * @param startDate

  * @param endDate

  */

case class DateRange(startDate: Timestamp, endDate: Timestamp) {

  def contain(targetDate: Date): Boolean = {

    targetDate.before(endDate) && targetDate.after(startDate)

  }

}

/**

 * @date: 2019/10/10

 * @site: www.doitedu.cn

 * @author: hunter.d 涛哥

 * @qq: 657270652

 * @description: 自定义UDAF实现年份销售额同比增长计算

 */

class YearOnYearBasis(current: DateRange) extends  UserDefinedAggregateFunction{

  // 聚合函数输入参数的数据类型

  override def inputSchema: StructType = {

    StructType(StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)

  }

  // 聚合缓冲区中值得数据类型

  override def bufferSchema: StructType = {

    StructType(StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)

  }

  // 返回值的数据类型

  override def dataType: DataType = DoubleType

  // 对于相同的输入是否一直返回相同的输出。

  override def deterministic: Boolean = true

  // 初始化

  override def initialize(buffer: MutableAggregationBuffer): Unit = {

    buffer.update(0, 0.0)

    buffer.update(1, 0.0)

  }

  // 相同Execute间的数据合并。

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =  {

    if (current.contain(input.getAs[Date](1))) {

      buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)

    }

    val previous = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))

    if (previous.contain(input.getAs[Date](1))) {

      buffer(1) = buffer.getAs[Double](0) + input.getAs[Double](0)

    }

  }

  // 不同Execute间的数据合并

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)

    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)

  }

  // 计算最终结果

  override def evaluate(buffer: Row): Any = {

    if (buffer.getDouble(1) == 0.0)

      0.0

    else

      (buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100

  }

  def subtractOneYear(d:Timestamp):Timestamp={

    Timestamp.valueOf(d.toLocalDateTime.minusYears(1))

  }

}

(3)补充示例:自定义UDAF实现平均薪资计算

下面展示一个求平均工资的自定义聚合函数。

Scala
package cn.doitedu.sparksql.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

/**
  * @description:
  * 用户自定义UDAF入门示例:求薪资的平均值
  */
object MyAvgUDAF extends UserDefinedAggregateFunction {

  // 函数输入的字段schema(字段名-字段类型)
  override def inputSchema: StructType = StructType(Seq(StructField("salary", DataTypes.DoubleType)))

  // 聚合过程中,用于存储局部聚合结果的schema
  // 比如求平均薪资,中间缓存(局部数据薪资总和,局部数据人数总和)
  override def bufferSchema: StructType = StructType(Seq(
    StructField("sum", DataTypes.DoubleType),
    StructField("cnts", DataTypes.LongType)

  ))

  // 函数的最终返回结果数据类型
  override def dataType: DataType = DataTypes.DoubleType

  // 你这个函数是否是稳定一致的?(对一组相同的输入,永远返回相同的结果),只要是确定的,就写true
  override def deterministic: Boolean = true

  // 对局部聚合缓存的初始化方法
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0L)
  }

  // 聚合逻辑所在方法,框架会不断地传入一个新的输入row,来更新你的聚合缓存数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    // 从输入中获取那个人的薪资,加到buffer的第一个字段上
    buffer.update(0, buffer.getDouble(0) + input.getDouble(0))

    // 给buffer的第2个字段加1
    buffer.update(1, buffer.getLong(1) + 1)

  }

  // 全局聚合:将多个局部缓存中的数据,聚合成一个缓存
  // 比如:薪资和薪资累加,人数和人数累加
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    // 把两个buffer的字段1(薪资和)累加到一起,并更新回buffer1
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))

    // 更新人数
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))

  }

  // 最终输出
  // 比如:从全局缓存中取薪资总和/人数总和
  override def evaluate(buffer: Row): Any = {

    if (buffer.getLong(1) != 0)
      buffer.getDouble(0) / buffer.getLong(1)
    else
      0.0

  }
}

4.2.2强类型用户自定义聚合函数

通过继承Aggregator来实现强类型自定义聚合函数,同样是求平均工资

Scala
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession

// 既然是强类型,可能有case类
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
    // 定义一个数据结构,保存工资总数和工资总个数,初始都为0
    def zero: Average = Average(0L, 0L)
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
    }
    // 聚合不同execute的结果
    def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
    }
    // 计算输出
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // 设定之间值类型的编码器,要转换成case类
    // Encoders.product是进行scala元组和case类转换的编码器
    def bufferEncoder: Encoder[Average] = Encoders.product
    // 设定最终输出值的编码器
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    import spark.implicits._
​    
    val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
    ds.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+
​    
    // Convert the function to a `TypedColumn` and give it a name
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val result = ds.select(averageSalary)
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
}

5. Spark SQL 的运行原理

正常的 SQL 执行先会经过 SQL Parser 解析 SQL,然后经过 Catalyst 优化器处理,最后到 Spark 执行。而 Catalyst 的过程又分为很多个过程,其中包括:

  • Analysis:主要利用 Catalog 信息将 Unresolved Logical Plan 解析成 Analyzed logical plan;
  • Logical Optimizations:利用一些 Rule (规则)将 Analyzed logical plan 解析成 Optimized Logical Plan;
  • Physical Planning:前面的 logical plan 不能被 Spark 执行,而这个过程是把 logical plan 转换成多个 physical plans,然后利用代价模型(cost model)选择最佳的 physical plan;
  • Code Generation:这个过程会把 SQL逻辑生成Java字节码。

所以整个 SQL 的执行过程可以使用下图表示:

 

其中蓝色部分就是 Catalyst 优化器处理的部分,也是本章主要讲解的内容。

5.1 元数据管理SessionCatalog

SessionCatalog 主要用于各种函数资源信息和元数据信息(数据库、数据表、数据视图、数据分区与函数等)的统一管理。

创建临时表或者视图,其实是往SessionCatalog注册;

Analyzer在进行逻辑计划元数据绑定时,也是从catalog中获取元数据;

5.2 SQL解析成逻辑执行计划

当调用SparkSession的sql或者SQLContext的sql方法,就会使用SparkSqlParser进行SQL解析。

Spark 2.0.0开始引入了第三方语法解析器工具 ANTLR,对 SQL 进行词法分析并构建语法树。

(Antlr 是一款强大的语法生成器工具,可用于读取、处理、执行和翻译结构化的文本或二进制文件,是当前 Java 语言中使用最为广泛的语法生成器工具,我们常见的大数据 SQL 解析都用到了这个工具,包括 Hive、Cassandra、Phoenix、Pig 以及 presto 等)目前最新版本的 Spark 使用的是 ANTLR4)

它分为2个步骤来生成Unresolved LogicalPlan:

  • 词法分析(SqlBaseLexer):Lexical Analysis,负责将token分组成符号类
  • 语法分析(SqlBaseParser):构建一棵分析树(parse tree)或者抽象语法树AST(abstract syntax tree)

Scala
/**
 * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
 * TableIdentifier.
 */
class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
  import ParserUtils._

  def this() = this(new SQLConf())

  protected def typedVisit[T](ctx: ParseTree): T = {
  
  ...
  }
}

具体来说,Spark 基于presto的语法文件定义了Spark SQL语法文件SqlBase.g4

(路径 spark-2.4.3\sql\catalyst\src\main\antlr4\org\apache\spark\sql\catalyst\parser\SqlBase.g4)

这个文件定义了 Spark SQL 支持的 SQL 语法。

 

如果我们需要自定义新的语法,需要在这个文件定义好相关语法。然后使用 ANTLR4 对 SqlBase.g4 文件自动解析生成几个 Java 类,其中就包含重要的词法分析器 SqlBaseLexer.java 和语法分析器SqlBaseParser.java。运行上面的 SQL 会使用 SqlBaseLexer 来解析关键词以及各种标识符等;然后使用 SqlBaseParser 来构建语法树。

下面以一条简单的 SQL 为例进行分析

SQL
SELECT sum(v)
    FROM (
      SELECT
        t1.id,
    1 + 2 + t1.value AS v
    FROM t1 JOIN t2
      WHERE
    t1.id = t2.id AND
    t1.cid = 1 AND
    t1.did = t1.cid + 1 AND
      t2.id > 5) o

整个过程就类似于下图。

 

生成语法树之后,使用 AstBuilder 将语法树转换成 LogicalPlan,这个 LogicalPlan 也被称为 Unresolved LogicalPlan。解析后的逻辑计划如下:

Plain Text
== Parsed Logical Plan ==
'Project [unresolvedalias('sum('v), None)]
+- 'SubqueryAlias `doitedu_stu`
   +- 'Project ['t1.id, ((1 + 2) + 't1.value) AS v#16]
      +- 'Filter ((('t1.id = 't2.id) && ('t1.cid = 1)) && (('t1.did = ('t1.cid + 1)) && ('t2.id > 5)))
         +- 'Join Inner
            :- 'UnresolvedRelation `t1`
            +- 'UnresolvedRelation `t2`

图片表示如下:

 

Unresolved LogicalPlan 是从下往上看的,t1 和 t2 两张表被生成了 UnresolvedRelation,过滤的条件、选择的列以及聚合字段都知道了。

Unresolved LogicalPlan 仅仅是一种数据结构,不包含任何数据信息,比如不知道数据源、数据类型,不同的列来自于哪张表等。

5.3 Analyzer绑定逻辑计划

Analyzer 阶段会使用事先定义好的 Rule 以及 SessionCatalog 等信息对 Unresolved LogicalPlan 进行元数据绑定。

Scala
/**
 * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
 * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
 */
class Analyzer(
    catalog: SessionCatalog,
    conf: SQLConf,
    maxIterations: Int)
  extends RuleExecutor[LogicalPlan] with CheckAnalysis {


class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) {
  val astBuilder = new SparkSqlAstBuilder(conf)


   override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    astBuilder.visitSingleStatement(parser.singleStatement()) match {
      case plan: LogicalPlan => plan
      case _ =>
        val position = Origin(None, None)
        throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
    }
  }

Rule 是定义在 Analyzer 里面的,具体如下:
lazy val batches: Seq[Batch] = Seq(
    Batch("Hints", fixedPoint,
      new ResolveHints.ResolveBroadcastHints(conf),
      ResolveHints.ResolveCoalesceHints,
      ResolveHints.RemoveAllHints),
    Batch("Simple Sanity Check", Once,
      LookupFunctions),
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution,
      EliminateUnions,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      ResolveTableValuedFunctions ::                    //解析表的函数
      ResolveRelations ::                               //解析表或视图
      ResolveReferences ::                              //解析列
      ResolveCreateNamedStruct ::
      ResolveDeserializer ::                            //解析反序列化操作类
      ResolveNewInstance ::
      ResolveUpCast ::                                  //解析类型转换
      ResolveGroupingAnalytics ::
      ResolvePivot ::
      ResolveOrdinalInOrderByAndGroupBy ::
      ResolveAggAliasInGroupBy ::
      ResolveMissingReferences ::
      ExtractGenerator ::
      ResolveGenerate ::
      ResolveFunctions ::                               //解析函数
      ResolveAliases ::                                 //解析表别名
      ResolveSubquery ::                                //解析子查询
      ResolveSubqueryColumnAliases ::
      ResolveWindowOrder ::
      ResolveWindowFrame ::
      ResolveNaturalAndUsingJoin ::
      ResolveOutputRelation ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      TimeWindowing ::
      ResolveInlineTables(conf) ::
      ResolveHigherOrderFunctions(catalog) ::
      ResolveLambdaVariables(conf) ::
      ResolveTimeZone(conf) ::
      ResolveRandomSeed ::
      TypeCoercion.typeCoercionRules(conf) ++
      extendedResolutionRules : _*),
    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
    Batch("View", Once,
      AliasViewChild(conf)),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("FixNullability", Once,
      FixNullability),
    Batch("Subquery", Once,
      UpdateOuterReferences),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
)

从上面代码可以看出,多个性质类似的 Rule 组成一个 Batch;而多个 Batch 构成一个 batches。这些 batches 会由 RuleExecutor 执行,先按一个一个 Batch 顺序执行,然后对 Batch 里面的每个 Rule 顺序执行。每个 Batch 会执行一次(Once)或多次(FixedPoint,由spark.sql.optimizer.maxIterations 参数决定),执行过程如下:

 

5.4 Optimizer优化逻辑计划

优化器也是会定义一套Rules,利用这些Rule对逻辑计划和Exepression进行迭代处理,从而使得树的节点进行合并和优化

 

 

在前文的绑定逻辑计划阶段对 Unresolved LogicalPlan 进行相关 transform 操作得到了 Analyzed Logical Plan,这个 Analyzed Logical Plan 是可以直接转换成 Physical Plan 然后在spark中执行。但是如果直接这么弄的话,得到的 Physical Plan 很可能不是最优的,因为在实际应用中,很多低效的写法会带来执行效率的问题,需要进一步对Analyzed Logical Plan 进行处理,得到更优的逻辑算子树。于是,针对SQL 逻辑算子树的优化器 Optimizer 应运而生。

这个阶段的优化器主要是基于规则的(Rule-based Optimizer,简称 RBO),而绝大部分的规则都是启发式规则,也就是基于直观或经验而得出的规则,比如列裁剪(过滤掉查询不需要使用到的列)、谓词下推(将过滤尽可能地下沉到数据源端)、常量累加(比如 1 + 2 这种事先计算好) 以及常量替换(比如 SELECT * FROM table WHERE i = 5 AND j = i + 3 可以转换成 SELECT * FROM table WHERE i = 5 AND j = 8)等等。

与绑定逻辑计划阶段类似,这个阶段所有的规则也是实现 Rule 抽象类,多个规则组成一个 Batch,多个 Batch 组成一个 batches,同样也是在 RuleExecutor 中进行执行。

核心源码骨架如下列截图所示:      

 

 

 

 

 

那么针对前文的 SQL 语句,这个过程都会执行哪些优化呢?下文举例说明。

5.4.1谓词下推

谓词下推在 Spark SQL 是由 PushDownPredicate 实现的,这个过程主要将过滤条件尽可能地下推到底层,最好是数据源。上面介绍的 SQL,使用谓词下推优化得到的逻辑计划如下:

 

从上图可以看出,谓词下推将 Filter 算子直接下推到 Join 之前了(注意,上图是从下往上看的)。也就是在扫描 t1 表的时候会先使用 ((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 50000)) && isnotnull(id#0) 过滤条件过滤出满足条件的数据;同时在扫描 t2 表的时候会先使用 isnotnull(id#8) && (id#8 > 50000) 过滤条件过滤出满足条件的数据。经过这样的操作,可以大大减少 Join 算子处理的数据量,从而加快计算速度。

5.4.2列裁剪

列裁剪在 Spark SQL 是由 ColumnPruning 实现的。因为我们查询的表可能有很多个字段,但是每次查询我们很大可能不需要扫描出所有的字段,这个时候利用列裁剪可以把那些查询不需要的字段过滤掉,使得扫描的数据量减少。所以针对我们上面介绍的 SQL,使用列裁剪优化得到的逻辑计划如下:

 

从上图可以看出,经过列裁剪后,t1 表只需要查询 id 和 value 两个字段;t2 表只需要查询 id 字段。这样减少了数据的传输,而且如果底层的文件格式为列存(比如 Parquet),可以大大提高数据的扫描速度的。

 

5.4.3常量替换

常量替换在 Spark SQL 是由 ConstantPropagation 实现的。也就是将变量替换成常量,比如 SELECT * FROM table WHERE i = 5 AND j = i + 3 可以转换成 SELECT * FROM table WHERE i = 5 AND j = 8。这个看起来好像没什么的,但是如果扫描的行数非常多可以减少很多的计算时间的开销的。经过这个优化,得到的逻辑计划如下:

我们的查询中有 t1.cid = 1 AND t1.did = t1.cid + 1 查询语句,从里面可以看出 t1.cid 其实已经是确定的值了,所以我们完全可以使用它计算出 t1.did。

5.4.4常量累加

常量累加在 Spark SQL 是由 ConstantFolding 实现的。这个和常量替换类似,也是在这个阶段把一些常量表达式事先计算好。这个看起来改动的不大,但是在数据量非常大的时候可以减少大量的计算,减少 CPU 等资源的使用。经过这个优化,得到的逻辑计划如下:

 

经过上面四个步骤的优化之后,得到的优化之后的逻辑计划为:

Plain Text
== Optimized Logical Plan ==
Aggregate [sum(cast(v#16 as bigint)) AS sum(v)#22L]
+- Project [(3 + value#1) AS v#16]
   +- Join Inner, (id#0 = id#8)
      :- Project [id#0, value#1]
      :  +- Filter (((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 5)) && isnotnull(id#0))
      :     +- Relation[id#0,value#1,cid#2,did#3] csv
      +- Project [id#8]
         +- Filter (isnotnull(id#8) && (id#8 > 5))
            +- Relation[id#8,value#9,cid#10,did#11] csv

对应的图如下:

 

到这里,优化逻辑计划阶段就算完成了。另外,Spark 内置提供了多达70个优化 Rule,详情请参见

https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L59

5.5使用SparkPlanner生成物理计划

SparkSpanner使用Planning Strategies,对优化后的逻辑计划进行转换,生成可以执行的物理计划SparkPlan.

Scala
/**
 * 将逻辑计划转成物理计划的抽象类.
 * 各实现类通过各种GenericStrategy来生成各种可行的待选物理计划.
 * 如一个策略无法对逻辑计划树的所有操作转换,则会调用[GenericStrategy#planLater planLater]], 来获得       一个“占位符”对象暂时填充;之后由[[collectPlaceholders collected]]收集并使用其他策略进行转换

 * TODO: 目前为止,永远只生成一个物理计划
 *       后续迭代中会对“多计划”予以实现
 */
abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
  /** A list of execution strategies that can be used by the planner */
  def strategies: Seq[GenericStrategy[PhysicalPlan]]

  def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = {
    // 显然,此处还有大量工作需要做,可依然...

    // 收集所有可选的物理计划.
    val candidates = strategies.iterator.flatMap(_(plan))

abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  self: SparkPlanner =>

  /**
   * Plans special cases of limit operators.
   */
  object SpecialLimits extends Strategy {
      
   
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val experimentalMethods: ExperimentalMethods)
  extends SparkStrategies {      

逻辑计划翻译成物理计划时,使用的是策略(Strategy);

前面介绍的逻辑计划绑定和优化经过 Transformations 动作之后,树的类型并没有改变,

Logical Plan转化成物理计划后,树的类型改变了,由 Logical Plan 转换成 Physical Plan 了。

一个逻辑计划(Logical Plan)经过一系列的策略处理之后,得到多个物理计划(Physical Plans),物理计划在 Spark 是由 SparkPlan 实现的。

多个物理计划经过代价模型(Cost Model)得到选择后的物理计划(Selected Physical Plan),整个过程如下所示:

 

Cost Model 对应的就是基于代价的优化(Cost-based Optimizations,CBO,主要由华为的大佬们实现的,详见 SPARK-16026 ),核心思想是计算每个物理计划的代价,然后得到最优的物理计划。目前,这一部分并没有实现,直接返回多个物理计划列表的第一个作为最优的物理计划,如下:

Scala
lazy val sparkPlan: SparkPlan = {
    SparkSession.setActiveSession(sparkSession)
    // TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
    //       but we will implement to choose the best plan.
    planner.plan(ReturnAnswer(optimizedPlan)).next()
}

而 SPARK-16026 引入的 CBO 优化主要是在前面介绍的优化逻辑计划阶段 - Optimizer 阶段进行的,对应的 Rule 为 CostBasedJoinReorder,并且默认是关闭的,需要通过 spark.sql.cbo.enabled 或 spark.sql.cbo.joinReorder.enabled 参数开启。

所以到了这个节点,最后得到的物理计划如下:

Plain Text
== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[sum(cast(v#16 as bigint))], output=[sum(v)#22L])
+- Exchange SinglePartition
   +- *(2) HashAggregate(keys=[], functions=[partial_sum(cast(v#16 as bigint))], output=[sum#24L])
      +- *(2) Project [(3 + value#1) AS v#16]
         +- *(2) BroadcastHashJoin [id#0], [id#8], Inner, BuildRight
            :- *(2) Project [id#0, value#1]
            :  +- *(2) Filter (((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 5)) && isnotnull(id#0))
            :     +- *(2) FileScan csv [id#0,value#1,cid#2,did#3] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/iteblog/t1.csv], PartitionFilters: [], PushedFilters: [IsNotNull(cid), IsNotNull(did), EqualTo(cid,1), EqualTo(did,2), GreaterThan(id,5), IsNotNull(id)], ReadSchema: struct<id:int,value:int,cid:int,did:int>
            +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)))
               +- *(1) Project [id#8]
                  +- *(1) Filter (isnotnull(id#8) && (id#8 > 5))
                     +- *(1) FileScan csv [id#8] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/iteblog/t2.csv], PartitionFilters: [], PushedFilters: [IsNotNull(id), GreaterThan(id,5)], ReadSchema: struct<id:int>

从上面的结果可以看出,物理计划阶段已经知道数据源是从 csv 文件里面读取了,也知道文件的路径,数据类型等。而且在读取文件的时候,直接将过滤条件(PushedFilters)加进去了。

同时,这个 Join 变成了 BroadcastHashJoin,也就是将 t2 表的数据 Broadcast 到 t1 表所在的节点。图表示如下:

 

到这里, Physical Plan 就完全生成了。

5.6从物理执行计划获取inputRdd执行

从物理计划上,获取inputRdd

从物理计划上,生成全阶段代码,并编译反射出迭代器newBiIterator的Clazz

 [真名:BufferedRowIterator]

然后将inputRDD做一个transformation得到最终要执行的rdd

Scala
inputRdd.mapPartitionsWithIndex((index,iter)=>{
   new newBiIterator(){
     hasNext(){
         iter.hasNext
}
     next(){
         processNext(iter.next())
}
}
})

然后,对最后返回的rdd,执行你所需要的行动算子
rdd.collect().foreach(println)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/719969.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PSI算法经典论文算法概述

文章目录 什么是隐私求交PSIPSI协议分类PSI算法的分类基于哈希函数的PSI算法基于不经意传输&#xff08;OT&#xff09;的 PSI算法基于GC的PSI算法基于公钥加密的PSI算法基于DH的PSI算法基于RSA盲签名的PSI算法基于同态加密的PSI算法 基于差分隐私的PSI算法 总结参考文献 什么是…

wails+vue3实现一个简单Monitor

介绍 本来呢最近是在学Rust,顺便看看Tauri相关的内容.然后刷评论区突然看到有人提到go生态中也有类似的框架—Wails,所以下午花了点时间来动手玩一下. 首先看一下最终的运行效果,前端样式懒得调整所以界面很丑只是实现一下功能 开始 这次的目标就是做一个功能类似于nvidia-s…

C#基础学习_字段与属性的比较

C#基础学习_字段与属性的比较 字段: 字段主要是为类的内部做数据交互使用,字段一般是private修饰; 字段可以赋值也可以读取; 当需要为外部提供数据的时候,请将字段封装为属性,而不是使用公有字段,这是面对对象编程所提倡的。 //字段:学号private int studentID;属性: …

语义分割大模型RSPrompter论文阅读

论文链接 RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model 开源代码链接 RSPrompter 论文阅读 摘要 Abstract—Leveraging vast training data (SA-1B), the foundation Segment Anything Model (SAM) propo…

vue动态组件component详解

附上代码 <template><div class"export-full-data-manage"><div class"main"><div class"left"><ul><li v-for"item in menus" :key"item.value" :class"[item.valuecurrent?curre…

【UE5 Cesium】11-Cesium for Unreal 切换Dynamic Pawn为其它Pawn

前言 我们知道在Cesium for Unreal中默认使用的是DynamicPawn来浏览地图场景。DynamicPawn适用全球浏览&#xff0c;可以按自定义曲线进行飞行。但是DynamicPawn是使用的是地理参考坐标系&#xff0c;并不是标准的UE坐标系&#xff0c;当我们全球浏览结束后&#xff0c;可能需要…

2023年6月榜单丨飞瓜数据B站UP主排行榜(哔哩哔哩)发布!

飞瓜轻数发布2023年6月飞瓜数据UP主排行榜&#xff08;B站平台&#xff09;&#xff0c;通过充电数、涨粉数、成长指数三个维度来体现UP主账号成长的情况&#xff0c;为用户提供B站号综合价值的数据参考&#xff0c;根据UP主成长情况用户能够快速找到运营能力强的B站UP主。 飞…

工作是EXCEL的天下

文章目录 EXCEL单元格内的换行筛选某一列的重复值批量删除重复值以某一列为联结&#xff0c;合并两个表格中的内容 本篇博文记录了笔者在工作中常用的EXCEL操作方法&#xff0c;持续更新中…… EXCEL单元格内的换行 AltEnter 筛选某一列的重复值 选中需要查找重复值的一列→…

如何在Microsoft Word中制作组织架构图

如果要说明公司或组织中的报告关系,可以创建一个使用组织结构图布局的 SmartArt 图形,如组织结构图。 注意:绘制组织结构图的另一种方法是使用 Microsoft 绘图应用程序 Visio。 使用 SmartArt 图形在 Excel、Outlook、PowerPoint 或 Word 中创建组织结构图,以显示组织中的…

磁盘镜像软件

什么是磁盘镜像 磁盘镜像是存储在计算机磁盘中的数据的副本或副本。磁盘镜像将包含数据存储设备的内容&#xff0c;并复制此类设备的结构。它还将包含操作系统分区。 磁盘镜像本质上是一种从主系统复制操作系统和存储在磁盘中的数据以将其分发到其他目标计算机的方法。自动化…

4.40ue4:样条线(轨迹)

1.创建样条线&#xff08;样条组建&#xff09; spline 多出一个点&#xff0c;按住alt拖住断点可以再生成一个点 可以在场景中拖动点编辑样条线 如果想要直角&#xff0c;可以点击细节面板找到spline组件&#xff0c;修改类型为linear&#xff0c;前后两点都需要改为绝对值&a…

3d渲染画面变形怎么办?

在用3dmax渲染图片时有时会遇到画面变形的情况&#xff0c;这个是什么原因呢&#xff1f;今天我们就来看看吧。 首先我们来看下变形的具体情况&#xff0c;再分析原因。可以看到整个画面都畸变了&#xff0c;呈现出上下拉伸的情况&#xff0c;能造成这个效果的&#xff0c;只有…

集合面试题详解

算法复杂度分析 List ArrayList 数据结构-数组 ArrayList源码分析 成员变量 构造方法 关键方法 ArrayList底层实现原理 如何实现数组和ArrayList之间的转换 LinkedList 单向链表 双向链表 ArrayList和LinkedList的区别 HashMap 二叉树 红黑树 散列表 HashMap的实现原理 Ha…

《C++ Primer》--学习12

动态内存 动态内存与智能指针 除了静态内存和栈内存&#xff0c;每个程序还拥有一个内存池。这部分内存被称为自由空间 或 堆&#xff0c;程序用堆来存储动态分配的对象 动态内存和智能指针 智能指针负责自动释放所指向的对象 shared_ptr 类 智能指针也是模板

一篇读懂React、vue框架的生命周期函数

当涉及到前端框架时&#xff0c;React 和 Vue.js 是两个非常受欢迎的选择。它们都提供了强大的工具和功能&#xff0c;帮助开发者构建交互式的、可扩展的应用程序。在这两个框架中&#xff0c;生命周期函数是一个重要的概念&#xff0c;它们允许我们在组件的不同阶段执行特定的…

车规芯片物理实现上的难度

在PR部分&#xff0c;车规级芯片的温度范围根据芯片工作位置不同也有所差别&#xff0c;最差的位置需要-40到150&#xff0c;除了DFM和功耗以及EM比较严格外&#xff0c;有些产品物理实现上对运算逻辑单元的位置也有特殊要求。 DFT部分才是车规级芯片的难点&#xff0c;感兴趣…

软件设计模式与体系结构-设计模式-行为型软件设计模式-迭代器模式

行为型软件设计模式 概述 行为型设计模式是软件设计模式中的一类&#xff0c;用于处理对象之间的交互和通信。这些模式关注的是对象之间的行为和职责分配。以下是几种常见的行为型设计模式&#xff1a; 观察者模式&#xff08;Observer Pattern&#xff09;&#xff1a;定义了…

【网络安全带你练爬虫-100练】第4练:添加异常处理代码

目录 一、异常处理代码&#xff1a; 二、执行结果&#xff1a; 三、完整代码&#xff1a; &#xff08;当代码越来越长的时候&#xff0c;异常处理代码有时候能起到很好的作用&#xff09; 一、异常处理代码&#xff1a; &#xff08;1&#xff09;try-except搭配&#xff…

Stable Difussion 解决绘图图片灰暗模糊,让图像色彩更丰富

在使用Stable Difussion进行AI绘画的时候&#xff0c;使用VAE能够产生许多有趣的效果。其滤镜功能可根据需要调整图像的色彩饱和度&#xff0c;使图像产生不同的视觉效果。 如下图所示&#xff1a; 但是如果不使用VAE那你的图像可能就会变得灰暗。 文章目录 VAE用途VAE的使…

cspm是什么?对比pmp怎么样?

一、国标项目管理&#xff08;项目管理专业人员能力评级&#xff09;证书是什么&#xff1f; 证书样式 《项目管理专业人员能力评价要求》&#xff08;GB/T 41831-2022&#xff09;是2022年10月12日开始实施的一项中国国家标准&#xff0c;归口于全国项目管理标准化技术委员会。…