• Using UDAFs in Spark SQL


    1. Create a class that extends UserDefinedAggregateFunction.

    ---------------------------------------------------------------------

    package cn.piesat.test

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StructType}

    class CountUDAF extends UserDefinedAggregateFunction {

      /**
       * Input type of the aggregate function.
       */
      override def inputSchema: StructType = {
        new StructType().add("ageType", IntegerType)
      }

      /**
       * Type of the aggregation buffer.
       */
      override def bufferSchema: StructType = {
        new StructType().add("bufferAgeType", IntegerType)
      }

      /**
       * Return type of the UDAF.
       */
      override def dataType: DataType = {
        DataTypes.StringType
      }

      /**
       * Whether the function is deterministic; returning true is fine in most cases.
       */
      override def deterministic: Boolean = true

      /**
       * Initializes the aggregation buffer for each group.
       */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0
      }

      /**
       * Called once per input row of a group: folds the new value into the
       * group's aggregation buffer. Note that despite the class name, this adds
       * the input value itself, so the aggregate is a sum rather than a count.
       */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val num = input.getAs[Int](0)
        buffer(0) = buffer.getAs[Int](0) + num
      }

      /**
       * Merges two aggregation buffers when partial results from different
       * partitions are combined.
       */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
      }

      /**
       * Produces the final result from the buffer.
       */
      override def evaluate(buffer: Row): Any = {
        buffer.getAs[Int](0).toString
      }
    }
    --------------------------------------------------------------
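
    A note on invocation: besides SQL registration (step 2 below), the UDAF can also be called directly through the DataFrame API via the apply method that UserDefinedAggregateFunction provides. Below is a minimal sketch, assuming the workDS Dataset with an age column from step 2. (Also note that in Spark 3.0+ this API is deprecated in favor of org.apache.spark.sql.expressions.Aggregator registered with functions.udaf.)
    ---------------------------------------------------------------------
    import org.apache.spark.sql.functions.col

    val countUDAF = new CountUDAF()

    // Equivalent to: select countUDF(age) from worker
    workDS.agg(countUDAF(col("age"))).show()
    ---------------------------------------------------------------------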


    2. Example usage in a main function.
    ---------------------------------------------------------------
    package cn.piesat.test

    import org.apache.spark.sql.SparkSession

    import scala.collection.mutable.ArrayBuffer

    // Assumed layout of each line in Workers.txt: name,age,sex
    case class Worker(name: String, age: Int, sex: String)

    object SparkSQLTest {

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("sparkSql").master("local[4]")
          .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
          .getOrCreate()
        val sc = spark.sparkContext
        val workerRDD = sc.textFile("F://Workers.txt").mapPartitions(itor => {
          val array = new ArrayBuffer[Worker]()
          while (itor.hasNext) {
            val splited = itor.next().split(",")
            array.append(Worker(splited(0), splited(1).toInt, splited(2)))
          }
          array.toIterator
        })
        import spark.implicits._
        // Register the UDAF so it can be referenced from SQL
        spark.udf.register("countUDF", new CountUDAF())
        val workDS = workerRDD.toDS()
        workDS.createOrReplaceTempView("worker")
        val resultDF = spark.sql("select countUDF(age) from worker")
        resultDF.show()

        spark.stop()
      }
    }
    -----------------------------------------------------------------------------------------------
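
    For reference, a Workers.txt matching the parsing above (the name,age,sex field order is an assumption; adjust the split indices to your actual file) might look like this:
    ---------------------------------------------------------------
    zhangsan,25,male
    lisi,30,female
    wangwu,28,male
    ---------------------------------------------------------------
    With these three rows, select countUDF(age) from worker returns the single string "83" (25 + 30 + 28), since the UDAF sums the ages and evaluate converts the total to a string.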
  • Original post: https://www.cnblogs.com/runnerjack/p/10662338.html