用户自定义函数
UDF函数
在操作关系型数据库时,Spark支持大部分常用SQL函数,而有些函数Spark官方并没有支持,需要根据业务自行创建。这些函数成为用户自定义函数(user defined function, UDF)。
接受一个参数,返回一个结果。即一进一出的函数。
实例
实现一个UDF,将name列中的用户名称全部转换为大写字母。
spark.udf.register("toUpperCaseUDF", (column : String) => column.toUpperCase)
spark.sql("SELECT toUpperCaseUDF(name), age FROM t_user").show
UDAF函数
用户自定义聚合函数(user defined aggregation function, UDAF),该类型函数可以接受并处理多个参数(某一列多个行中的值),之后返回一个值,属于多进一出的函数。
开发者可以通过继承UserDefinedAggregateFunction抽象类来实现UDAF。继承该类需要覆写8个抽象方法。
object AverageUDAF extends UserDefindAggregationFunction {}
def inputSchema : StructType
def bufferSchema : StructType
def dataType : DataType
def deterministic : Boolean
def initialize(buffer : MutableAggregationBuffer) : Unit
def update(buffer : MutableAggregationBuffer, input : Row) : Unit
def merge(buffer1 : MutableAggregationBuffer, buffer2 : Row) : Unit
def evaluate(buffer : Row) : Any
在聚合过程中,用于存放累加数据的容器是MutableAggregationBuffer类型的实例,该类型继承自Row类型。整个聚合过程就是将原始表的某一列的多个Row实例取出,将对应列中所有待聚合的值累加到缓冲区的Row实例中。
实例
求每个性别的平均年龄
//inputSchema来指定调用avgUDAF函数时传入的参数类型
override def inputSchema: StructType = {
StructType(
List(
StructField("numInput", DoubleType, nullable = true)
)
)
}
//bufferSchema设置UDAF在聚合过程中的缓冲区保存数据的类型,一个参数是年龄总和,一个参数是累加人数
override def bufferSchema: StructType = {
StructType(
List(
StructField("buffer1", DoubleType, nullable = true)
StructField("buffer2", LongType, nullable = true)
)
)
}
//dataType设置UDAF运算结束时返回的数据类型
override def dataType: DataType = DoubleType
//deterministic判断UDAF可接收的参数类型与返回的结果类型是否一致
override def deteministic: Boolean = true
//initialize初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0.0
buffer(1) = 0L
}
//update用于控制具体的聚合逻辑,通过update方法,将每行参与运算的列累加到聚合缓冲区的Row实例中
//每访问一行,都会调用一次update方法。
override def update(buffer: MutableAggregation, input: Row): Unit = {
buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
buffer.update(1, buffer.getLong(1) + 1)
}
//merge用于合并每个分区聚合缓冲区的值
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}
//evaluate方法用于对聚合缓冲区的数据进行最后一次运算
override def evaluate(buffer: Row): Any = {
buffer.getDouble(0) / buffer.getLong(1)
}
在创建完AverageUDAF类后,要注册UDAF
spark.udf.register("toDouble", (column: Any) => column.toString.toDouble)
spark.udf.register("avgUDAF", AverageUDAF)
spark.sql("SELECT sex, avgUDAF(toDOUble(age)) as avgAge FROM t_user GROUP BY sex").show
UDTF函数
用户自定义表生成函数。该类型函数可以将一行中的某一列数据展开,变为基于这一列展开后的多行数据。可以通过DataFrame执行flatMap函数来实现“列转行”。一进多出。
实例
val tableArray = df1.flatMap(row => {
val listTuple = new scala.collection.mutable.ListBuffer[(String, String)] ()
val categoryArray = row.getString(1).split(",")
for(c <- categoryArray) {
listTuple.append((row.getString(0), c))
}
listTuple
}).collect()
val df2 = spark.createDataFrame(tableArray).toDF("movie", "category")
df.show