Spark SQL UDF和UDAF
/**
* scala代码
*/
package com.tom.spark.sql
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
/**
* UDF:User Defined Function, 用户自定义的函数,函数的输入是一条具体的数据记录,实现上讲就是普通的scala函数;
* UDAF:User Defined Aggregation Function, 用户自定义的聚合函数,函数本身作用于数据集合,能够在聚合操作的基础上进行自定义操作;
* 实质上讲,例如说UDF会被Spark SQL中的catalyst封装成为expression,最终会通过eval方法来计算输入的输入Row,此处的Row和DataFrame
* 中的Row没有任何关系
*/
object SparkSQLUDFUDAF {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[4]").setAppName("SparkSQLUDFUDAF")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
//模拟实际使用的数据
val bigData = Array("Spark", "Spark", "Hadoop", "spark", "Hadoop", "spark", "Hadoop", "Hadoop", "spark", "spark")
/**
* 基于提供的数据创建DataFrame
*/
val bigDataRdd = sc.parallelize(bigData)
val bigDataRDDRow = bigDataRdd.map(item => {Row(item)})
val structType = StructType(Array(
new StructField("word", StringType, true)
))
val bigDataDF = sqlContext.createDataFrame(bigDataRDDRow, structType)
bigDataDF.registerTempTable("bigDataTable") //注册成为临时表
/**
* 通过SQLContext注册UDF,在Scala 2.10.x版本UDF函数最多可以接收22个输入参数
*/
sqlContext.udf.register("computeLength", (input: String) => input.length)
//直接在sql中使用udf,就像使用SQL自带的内部函数一样
sqlContext.sql("select word, computeLength(word) as length from bigDataTable").show
sqlContext.udf.register("wordcount", new MyUDAF)
sqlContext.sql("select word, wordcount(word) as count,computeLength(word) as length " +
"from bigDataTable group by word").show
// while(true){}
}
}
/**
* 按照模板实现UDAF
*/
class MyUDAF extends UserDefinedAggregateFunction {
/**
* 该方法指定具体输入数据的类型
* @return
*/
override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))
/**
* 在进行聚合操作的时候所要处理的数据的结果的类型
* @return
*/
override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))
/**
* 指定UDAF函数计算后返回的结果类型
* @return
*/
override def dataType: DataType = IntegerType
/**
* 确保一致性,一般都用true
* @return
*/
override def deterministic: Boolean = true
/**
* 在Aggregate之前每组数据的初始化结果
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0 }
/**
* 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
* 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Int](0) + 1
}
/**
* 最后在分布式节点进行Local Reduce完成后需要进行全局级别的Merge操作
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
}
/**
* 返回UDAF最后的计算结果
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117