• [Spark]-结构化数据查询之自定义UDAF


    1.自定义弱类型UDAF

      1.1 弱类型UDAF定义

        弱类型UDAF继承实现 UserDefinedAggregateFunction 抽象类

        override def inputSchema: StructType = 输入schema

        override def bufferSchema: StructType = 聚合过程schema

        override def dataType: DataType = 返回值类型

        override def deterministic: Boolean = 是否为确定性函数(即相同输入总是产生相同输出),而不是返回值类型是否固定

        override def initialize(buffer: MutableAggregationBuffer): Unit = 初始化函数,用来初始化基准值

        override def update(buffer: MutableAggregationBuffer, input: Row): Unit = 分区内元素如何聚合

        override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 分区之间如何聚合

        override def evaluate(buffer: Row): Any = 聚合结果计算

        整个UDAF处理过程,非常类似RDD的aggregate算子

          aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U

        一个自定义求平均数UDAF例子

                /** Demo entry point for the untyped (weak-typed) UDAF. */
                object UDAFApp {

                  // Explicit main instead of the App trait (App has
                  // initialization-order pitfalls for non-trivial programs).
                  def main(args: Array[String]): Unit = {
                    val spark = SparkSession.builder()
                      .master("local[2]")
                      .appName("UDAF-App") // fixed typo: was "UDAP-App"
                      .getOrCreate()
                    import spark.implicits._

                    // Backslashes in a Scala string literal must be escaped:
                    // "D:\data" contains the invalid escape \d and fails to compile.
                    val df = spark.read.format("json").load("D:\\data\\employees.json")

                    // Register the UDAF for SQL use; only a
                    // UserDefinedAggregateFunction can be registered this way.
                    spark.udf.register("cusAvg", MyAvgUDAF)
                    // Expose the DataFrame as a temporary view so SQL can query it.
                    df.createTempView("employees_view")
                    spark.sql("select cusAvg(salary) as salary from employees_view").show()

                    // Equivalent DataFrame-API usage.
                    df.select(MyAvgUDAF.apply($"salary")).show()
                    spark.close()
                  }
                }
                
                /**
                 * Untyped UDAF computing the arithmetic mean of a Double column.
                 * Buffer layout: (running sum: Double, element count: Long).
                 * The life cycle mirrors RDD.aggregate(zeroValue)(seqOp, combOp).
                 */
                object MyAvgUDAF extends UserDefinedAggregateFunction {

                  // Schema of the single Double input column.
                  override def inputSchema: StructType =
                    StructType(Array(StructField("input", DoubleType)))

                  // Schema of the intermediate aggregation buffer.
                  override def bufferSchema: StructType =
                    StructType(Array(StructField("Sum", DoubleType), StructField("Count", LongType)))

                  // Type of the final aggregated value.
                  override def dataType: DataType = DoubleType

                  // Deterministic: identical inputs always yield the same result.
                  override def deterministic: Boolean = true

                  // Set the buffer's zero value (the zeroValue of RDD.aggregate).
                  override def initialize(buffer: MutableAggregationBuffer): Unit = {
                    buffer(0) = 0D // running sum
                    buffer(1) = 0L // element count
                  }

                  // Fold one input row into the buffer (the seqOp of RDD.aggregate).
                  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                    if (!input.isNullAt(0)) { // null values are skipped entirely
                      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
                      buffer(1) = buffer.getLong(1) + 1L
                    }
                  }

                  // Combine two partial buffers across partitions (the combOp part).
                  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
                    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
                    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
                  }

                  // Final result: sum / count.
                  // NOTE(review): when count == 0 this is 0.0 / 0 == NaN — confirm
                  // callers accept that for an all-null/empty group.
                  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
                }

    2.自定义强类型UDAF

      自定义强类型UDAF 基础实现类 Aggregator 

      由于 Aggregator 与具体的输入类型(样例类)强绑定,所以这种定义方式不能通过 spark.udf.register 注册,也不能直接用在 SQL 中(只能以 DataSet API 方式调用)

      一个强类型UDAF定义如下:    

                /** Demo entry point for the type-safe (strong-typed) UDAF. */
                object UDAFApp {

                  // Explicit main instead of the App trait (App has
                  // initialization-order pitfalls for non-trivial programs).
                  def main(args: Array[String]): Unit = {
                    val spark = SparkSession.builder()
                      .master("local[2]")
                      .appName("UDAF-App") // fixed typo: was "UDAP-App"
                      .getOrCreate()
                    import spark.implicits._

                    // Backslashes in a Scala string literal must be escaped:
                    // "D:\data" contains the invalid escape \d and fails to compile.
                    val ds = spark.read.format("json").load("D:\\data\\employees.json").as[Employee]

                    // Dataset-API usage: a typed Aggregator is applied via toColumn.
                    ds.select(MyAverage.toColumn.name("salary")).show()
                    spark.close()
                  }
                }
                
                // Input record type (case-class params are vals by default;
                // the explicit `val` was redundant).
                final case class Employee(name: String, salary: Long)
                // Mutable aggregation buffer type: running sum and element count.
                // `var` fields let the Aggregator update the buffer in place.
                case class Average(var sum: Long, var count: Long)
                /**
                 * Type-safe UDAF computing the average salary.
                 * Type params: IN = Employee, BUF = Average, OUT = Double.
                 */
                object MyAverage extends Aggregator[Employee, Average, Double] {

                  // Zero value of the aggregation: empty sum and count.
                  override def zero: Average = Average(0, 0)

                  // Fold one Employee into the running buffer. Mutating the buffer
                  // in place (rather than allocating a new one per row) is the
                  // pattern Spark recommends for Aggregator buffers.
                  override def reduce(b: Average, a: Employee): Average = {
                    b.sum += a.salary
                    b.count += 1
                    b
                  }

                  // Combine two partial buffers from different partitions.
                  override def merge(b1: Average, b2: Average): Average = {
                    b1.sum += b2.sum
                    b1.count += b2.count
                    b1
                  }

                  // Final result (leftover debug println removed).
                  // NOTE(review): yields NaN when count == 0 (empty input) — confirm
                  // that is acceptable for callers.
                  override def finish(reduction: Average): Double =
                    reduction.sum.toDouble / reduction.count

                  // Encoder for the intermediate buffer (Average is a Product type).
                  override def bufferEncoder: Encoder[Average] = Encoders.product

                  // Encoder for the Double output value.
                  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
                }
  • 相关阅读:
    LeetCode 1110. Delete Nodes And Return Forest
    LeetCode 473. Matchsticks to Square
    LeetCode 886. Possible Bipartition
    LeetCode 737. Sentence Similarity II
    LeetCode 734. Sentence Similarity
    LeetCode 491. Increasing Subsequences
    LeetCode 1020. Number of Enclaves
    LeetCode 531. Lonely Pixel I
    LeetCode 1091. Shortest Path in Binary Matrix
    LeetCode 590. N-ary Tree Postorder Traversal
  • 原文地址:https://www.cnblogs.com/NightPxy/p/9269171.html
Copyright © 2020-2023  润新知