• SparkSQL 如何自定义函数


    1. SparkSql如何自定义函数

    2. 示例:Average

    3. 类型安全的自定义函数

    1. SparkSql如何自定义函数?

      spark中我们定义一个函数,需要继承 UserDefinedAggregateFunction这个抽象类,实现这个抽象类中所定义的方法,这是一个模板设计模式? 我只要实现抽象类的中方法,具体的所有的计算步骤由内部完成。而我们可以看一下UserDefinedAggregateFunction这个抽象类。

    package org.apache.spark.sql.expressions
    @org.apache.spark.annotation.InterfaceStability.Stable
    abstract class UserDefinedAggregateFunction() extends scala.AnyRef with scala.Serializable { def inputSchema : org.apache.spark.sql.types.StructType def bufferSchema : org.apache.spark.sql.types.StructType def dataType : org.apache.spark.sql.types.DataType def deterministic : scala.Boolean def initialize(buffer : org.apache.spark.sql.expressions.MutableAggregationBuffer) : scala.Unit def update(buffer : org.apache.spark.sql.expressions.MutableAggregationBuffer, input : org.apache.spark.sql.Row) : scala.Unit def merge(buffer1 : org.apache.spark.sql.expressions.MutableAggregationBuffer, buffer2 : org.apache.spark.sql.Row) : scala.Unit def evaluate(buffer : org.apache.spark.sql.Row) : scala.Any @scala.annotation.varargs def apply(exprs : org.apache.spark.sql.Column*) : org.apache.spark.sql.Column = { /* compiled code */ } @scala.annotation.varargs def distinct(exprs : org.apache.spark.sql.Column*) : org.apache.spark.sql.Column = { /* compiled code */ } }

      也就是说对于这几个函数,我们只要依次实现他们的功能,其余的交给spark就可以了。

      

    2. 自定义Average函数

      首先新建一个Object类MyAvage类,继承UserDefinedAggregateFunction。下面对每一个函数的实现进行解释。

      def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
    

      这个规定了输入数据的数据结构

    def bufferSchema: StructType = {
        StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
      }

      这个规定了缓存区的数据结构

      def dataType: DataType = DoubleType
    

      这个规定了返回值的数据类型

    def deterministic: Boolean = true
    def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = 0L
      }  

    进行初始化,这里要说明一下,官网中提到:

    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
      // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
      // the opportunity to update its values. Note that arrays and maps inside the buffer are still
      // immutable.

    这里翻译一下:

    我们为我们的缓冲区设置初始值,我们不仅可以设置数字,还可以使用index getBoolen等去改变他的值,但是我们需要知道的是,在这个缓冲区中,数组和map依然是不可变的。

    其实最后一句我也是不太明白,等我以后如果能研究并理解这句话,再回来补充吧。

    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
          buffer(0) = buffer.getLong(0) + input.getLong(0)
          buffer(1) = buffer.getLong(1) + 1
        }
      }
    

      这个是重要的update函数,对于平均值,我们可以不断迭代输入的值进行累加。buffer(0)统计总和,buffer(1)统计长度。

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }
    

      在做完update后spark 需要将结果进行merge到我们的区域,因此有一个merge 进行覆盖buffer

      def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
    

      这是将最终的结果进行计算。

    在写完这个类以后我们在我们的sparksession里面进行编写测试案例。

    spark.sparkContext.textFile("file:///Users/4pa/Desktop/people.txt")
          .map(_.split(","))
          .map(agg=>Person(agg(0),agg(1).trim.toInt))
          .toDF().createOrReplaceTempView("people")
    spark.udf.register("myAverage",Myaverage)
    val udfRes = spark.sql("select name,myAverage(age) as avgAge from people group by name")
    udfRes.show()
    

      

    3. 类型安全的自定义函数

    从上面我们可以看出来,这种自定义函数不是类型安全的,因此能否实现一个安全的自定义函数呢?

    个人觉得最好的例子还是官网给的例子,具体的解释都已经给了出来,思路其实和上面是一样的,只不过定义了两个caseclass,用于类型的验证。

    case class Employee(name: String, salary: Long)
    case class Average(var sum: Long, var count: Long)
    
    object MyAverage extends Aggregator[Employee, Average, Double] {
      // 初始化
      def zero: Average = Average(0L, 0L)
      // 这个其实有点map-reduce的意思,只不过是对一个类的reduce,第一个值是和,第二个是总数
      def reduce(buffer: Average, employee: Employee): Average = {
        buffer.sum += employee.salary
        buffer.count += 1
        buffer
      }
      // 实现缓冲区的一个覆盖
      def merge(b1: Average, b2: Average): Average = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
      }
      // 计算最终数值
      def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
      // Specifies the Encoder for the intermediate value type
      def bufferEncoder: Encoder[Average] = Encoders.product
      // 指定返回类型
      def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    

      

     

  • 相关阅读:
    单例模式
    mysql之group_concat函数详解
    json中如何将key中的引号去掉
    show status,修改mysql用户密码 使用
    ThinkPHP连贯查询之子查询
    输入1-53周,输出1-53周的开始时间和结束时间
    Java编辑环境搭建
    Java语言简介
    html中iframe根据子页面内容动态修改高度
    JavaScript---通过正则表达式验证表单输入
  • 原文地址:https://www.cnblogs.com/tjpeng/p/12261901.html
Copyright © 2020-2023  润新知