自定义用户函数有两种方式,区别:是否使用强类型,参考demo:https://github.com/asker124143222/spark-demo
1、不使用强类型,继承UserDefinedAggregateFunction
package com.home.spark import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ object Ex_sparkUDAF { def main(args: Array[String]): Unit = { val conf = new SparkConf(true).setAppName("spark udf").setMaster("local[*]") val spark = SparkSession.builder().config(conf).getOrCreate() //自定义聚合函数 //创建聚合函数对象 val myUdaf = new MyAgeAvgFunc //注册自定义函数 spark.udf.register("ageAvg",myUdaf) //使用聚合函数 val frame: DataFrame = spark.read.json("input/userinfo.json") frame.createOrReplaceTempView("userinfo") spark.sql("select ageAvg(age) from userinfo").show() spark.stop() } } //声明自定义函数 //实现对年龄的平均,数据如:{ "name": "tom", "age" : 20} class MyAgeAvgFunc extends UserDefinedAggregateFunction { //函数输入的数据结构,本例中只有年龄是输入数据 override def inputSchema: StructType = { new StructType().add("age", LongType) } //计算时的数据结构(缓冲区) // 本例中有要计算年龄平均值,必须有两个计算结构,一个是年龄总计(sum),一个是年龄个数(count) override def bufferSchema: StructType = { new StructType().add("sum", LongType).add("count", LongType) } //函数返回的数据类型 override def dataType: DataType = DoubleType //函数是否稳定 override def deterministic: Boolean = true //计算前缓冲区的初始化,结构类似数组,这里缓冲区与之前定义的bufferSchema顺序一致 override def initialize(buffer: MutableAggregationBuffer): Unit = { //sum buffer(0) = 0L //count buffer(1) = 0L } //根据查询结果更新缓冲区数据,input是每次进入的数据,其数据结构与之前定义的inputSchema相同 //本例中每次输入的数据只有一个就是年龄 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { if(input.isNullAt(0)) return //sum buffer(0) = buffer.getLong(0) + input.getLong(0) //count,每次来一个数据加1 buffer(1) = buffer.getLong(1) + 1 } //将多个节点的缓冲区合并到一起(因为spark是分布式的) override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { //sum buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) //count buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) } //计算最终结果,本例中就是(sum / count) override def evaluate(buffer: Row): Any = { buffer.getLong(0).toDouble / buffer.getLong(1) } }
2、使用强类型,
package com.home.spark import org.apache.spark.SparkConf import org.apache.spark.sql._ import org.apache.spark.sql.expressions.Aggregator object Ex_sparkUDAF2 { def main(args: Array[String]): Unit = { val conf = new SparkConf(true).setAppName("spark udf class").setMaster("local[*]") val spark = SparkSession.builder().config(conf).getOrCreate() //rdd转换成df或者ds需要SparkSession实例的隐式转换 //导入隐式转换,注意这里的spark不是包名,而是SparkSession的对象名 import spark.implicits._ //创建聚合函数对象 val myAvgFunc = new MyAgeAvgClassFunc val avgCol: TypedColumn[UserBean, Double] = myAvgFunc.toColumn.name("avgAge") val frame = spark.read.json("input/userinfo.json") val userDS: Dataset[UserBean] = frame.as[UserBean] //应用函数 userDS.select(avgCol).show() spark.stop() } } case class UserBean(name: String, age: BigInt) case class AvgBuffer(var sum: BigInt, var count: Int) //声明用户自定义函数(强类型方式) //继承Aggregator,设定泛型 //实现方法 class MyAgeAvgClassFunc extends Aggregator[UserBean, AvgBuffer, Double] { //初始化缓冲区 override def zero: AvgBuffer = { AvgBuffer(0, 0) } //聚合数据 override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = { if(a.age == null) return b b.sum = b.sum + a.age b.count = b.count + 1 b } //缓冲区合并操作 override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum = b1.sum + b2.sum b1.count = b1.count + b2.count b1 } //完成计算 override def finish(reduction: AvgBuffer): Double = { reduction.sum.toDouble / reduction.count } override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product override def outputEncoder: Encoder[Double] = Encoders.scalaDouble }
继承Aggregator