1.自定义弱类型UDAF
1.1 弱类型UDAF定义
弱类型UDAF继承实现 UserDefinedAggregateFunction 抽象类
override def inputSchema: StructType = 输入schema
override def bufferSchema: StructType = 聚合过程schema
override def dataType: DataType = 返回值类型
override def deterministic: Boolean = 是否为确定性函数(相同输入是否总是返回相同结果,并非指"返回值类型是否固定")
override def initialize(buffer: MutableAggregationBuffer): Unit = 初始化函数,用来初始化基准值
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = 分区内元素如何聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 分区之间如何聚合
override def evaluate(buffer: Row): Any = 聚合结果计算
整个UDAF处理过程,非常类似RDD的aggregate算子
aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U
一个自定义求平均数UDAF例子
// Weak-typed UDAF example: custom average over a Double column.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0;
// new code should prefer Aggregator + functions.udaf.
object UDAFApp extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("UDAP-App")
    .getOrCreate()
  import spark.implicits._

  // BUGFIX: the original path literal "D:\data\employees.json" used the
  // illegal Scala escapes \d and \e (a compile error); backslashes escaped.
  val df = spark.read.format("json").load("D:\\data\\employees.json")

  // Register for SQL use — only UserDefinedAggregateFunction instances can
  // be registered this way for SQL queries.
  spark.udf.register("cusAvg", MyAvgUDAF)

  // Expose the DataFrame as a temp view and invoke the UDAF from SQL.
  df.createTempView("employees_view")
  spark.sql("select cusAvg(salary) as salary from employees_view").show()

  // Equivalent DataFrame-API invocation.
  df.select(MyAvgUDAF.apply($"salary")).show()

  spark.close()
}

// Weak-typed UDAF computing an average; the buffer holds (sum, count).
// The lifecycle mirrors RDD's aggregate(zeroValue)(seqOp, combOp).
object MyAvgUDAF extends UserDefinedAggregateFunction {
  // Input schema: a single Double column.
  override def inputSchema: StructType =
    StructType(StructField("input", DoubleType) :: Nil)

  // Aggregation-buffer schema: running sum and element count.
  override def bufferSchema: StructType =
    StructType(StructField("Sum", DoubleType) :: StructField("Count", LongType) :: Nil)

  // Result type of evaluate().
  override def dataType: DataType = DoubleType

  // true: the same input always produces the same output
  // (NOT "whether the return type is fixed", as the notes originally said).
  override def deterministic: Boolean = true

  // Initialize the buffer — the zeroValue part of aggregate().
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D // sum = 0
    buffer(1) = 0L // count = 0
  }

  // Fold one input row into a partition-local buffer — the seqOp part.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip null inputs (Row column 0 may be null in the JSON source).
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge two partition buffers — the combOp part.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result: sum / count. Yields NaN when no non-null rows were seen
  // (0.0 / 0L) — acceptable for this teaching example.
  override def evaluate(buffer: Row): Any =
    buffer.getDouble(0) / buffer.getLong(1)
}
2.自定义强类型UDAF
自定义强类型UDAF 基础实现类 Aggregator
Aggregator 是强类型接口,早期 Spark 版本中不能通过 spark.udf.register 注册、也不能直接用在 SQL 中;Spark 3.0 起可以用 functions.udaf 包装后注册给 SQL 使用
一个强类型UDAF定义如下:
// Strongly-typed UDAF example: average salary via Aggregator[IN, BUF, OUT].
object UDAFApp extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("UDAP-App")
    .getOrCreate()
  import spark.implicits._

  // BUGFIX: the original path literal "D:\data\employees.json" used the
  // illegal Scala escapes \d and \e (a compile error); backslashes escaped.
  val ds = spark.read.format("json").load("D:\\data\\employees.json").as[Employee]

  // Dataset-API invocation through a TypedColumn.
  ds.select(MyAverage.toColumn.name("salary")).show()

  spark.close()
}

// Input record type (case-class params are vals by default; explicit `val`
// in the original was redundant).
case class Employee(name: String, salary: Long)

// Mutable aggregation buffer: running sum and element count.
case class Average(var sum: Long, var count: Long)

// Aggregator[Employee, Average, Double]: IN = row type, BUF = buffer, OUT = result.
object MyAverage extends Aggregator[Employee, Average, Double] {
  // Zero value of the aggregation.
  override def zero: Average = Average(0, 0)

  // Fold one record into a partition-local buffer (mutates and returns b).
  override def reduce(b: Average, a: Employee): Average = {
    b.sum += a.salary
    b.count += 1
    b
  }

  // Merge two partition buffers (mutates and returns b1).
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Final value: mean salary. The println is debug output kept from the notes.
  override def finish(reduction: Average): Double = {
    println(reduction.sum + " " + reduction.count)
    reduction.sum.toDouble / reduction.count
  }

  // Encoders Spark uses to serialize the buffer and the output.
  override def bufferEncoder: Encoder[Average] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}