• Spark 自定义函数(udf,udaf)


    Spark 版本 2.3

    文中测试数据(json)

    {"name":"lillcol", "age":24,"ip":"192.168.0.8"}
    {"name":"adson", "age":100,"ip":"192.168.255.1"}
    {"name":"wuli", "age":39,"ip":"192.143.255.1"}
    {"name":"gu", "age":20,"ip":"192.168.255.1"}
    {"name":"ason", "age":15,"ip":"243.168.255.9"}
    {"name":"tianba", "age":1,"ip":"108.168.255.1"}
    {"name":"clearlove", "age":25,"ip":"222.168.255.110"}
    {"name":"clearlove", "age":30,"ip":"222.168.255.110"}
    

    用户自定义udf

    自定义udf的方式有两种

    1. SQLContext.udf.register()
    2. 通过 functions.udf() 创建 UserDefinedFunction

    这两种方式的使用范围不一样:通过 register 注册的函数在 SQL 语句中按注册名调用;通过 udf() 创建的 UserDefinedFunction 在 DSL(Dataset API)中直接调用

    package com.test.spark
    
    import org.apache.spark.sql.expressions.UserDefinedFunction
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    /**
      * Demonstrates the two ways of defining a Spark UDF:
      *   1. registering a Scala function by name with `udf.register`, so SQL text can call it;
      *   2. wrapping a function in a `UserDefinedFunction` with `functions.udf`, used from the DSL.
      *
      * @author Administrator
      *         2019/7/22-14:04
      */
    object TestUdf {

      /** Default location of the sample JSON file (backslashes escaped so the literal compiles). */
      val DefaultInputPath: String = "D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson"

      val spark = SparkSession
        .builder()
        .appName("TestCreateDataset")
        .config("spark.some.config.option", "some-value")
        .master("local")
        .enableHiveSupport()
        .getOrCreate()
      val sQLContext = spark.sqlContext

      import spark.implicits._


      /**
        * Entry point.
        *
        * @param args optional; `args(0)` overrides the input JSON path
        */
      def main(args: Array[String]): Unit = {
        try {
          testudf(args.headOption.getOrElse(DefaultInputPath))
        } finally {
          spark.stop() // release the SparkSession created by this object
        }
      }

      /**
        * Runs the UDF examples against the JSON file at `path` and prints each result.
        *
        * @param path input JSON file (one JSON object per line, as in the sample data)
        */
      def testudf(path: String = DefaultInputPath): Unit = {
        val iptoLong: UserDefinedFunction = getIpToLong()
        val ds: Dataset[Row] = spark.read.json(path)
        ds.createOrReplaceTempView("table1")
        // Registered by name: "addName" is called from SQL text. The Scala value `addName`
        // defined below is a separate UserDefinedFunction used from the DSL.
        sQLContext.udf.register("addName", sqlUdf(_: String))
        // 1. SQL
        sQLContext.sql("select *,addName(name) as nameAddName  from table1")
          .show()
        // 2. DSL
        val addName: UserDefinedFunction = udf((str: String) => ("ip: " + str))
        ds.select($"*", addName($"ip").as("ipAddName"))
          .show()

        // A more complex UDF can be built in its own method, e.g. getIpToLong
        ds.select($"*", iptoLong($"ip").as("iptoLong"))
          .show()
      }

      /** Scala function registered as the SQL UDF "addName": prefixes the value with "name:". */
      def sqlUdf(name: String): String = {
        "name:" + name
      }

      /**
        * Builds a UDF that converts a dotted-quad IPv4 string to its 32-bit numeric value,
        * e.g. "192.168.0.8" -> 3232235528.
        *
        * Spaces and double quotes are stripped first. The UDF yields -1 in these cases:
        *   - the input is null;
        *   - the input does not have exactly four dot-separated parts;
        *   - any part is not a decimal number in 0..255.
        *
        * @return the UDF, of type String => Long
        */
      def getIpToLong(): UserDefinedFunction = {
        val ipToLong: UserDefinedFunction = udf((ip: String) => {
          val parts: Array[String] =
            if (ip == null) Array.empty[String]
            else ip.replace(" ", "").replace("\"", "").split("\\.") // split on a literal dot
          val valid = parts.length == 4 && parts.forall { part =>
            part.nonEmpty && part.length <= 3 &&
              part.forall(c => c >= '0' && c <= '9') && part.toInt <= 255
          }
          if (valid) {
            // Each octet occupies 8 bits, most significant octet first.
            parts.foldLeft(0L)((acc, part) => (acc << 8) | part.toLong)
          } else {
            -1L
          }
        })
        ipToLong
      }


    }
    
    输出结果
    +---+---------------+---------+--------------+
    |age|             ip|     name|   nameAddName|
    +---+---------------+---------+--------------+
    | 24|    192.168.0.8|  lillcol|  name:lillcol|
    |100|  192.168.255.1|    adson|    name:adson|
    | 39|  192.143.255.1|     wuli|     name:wuli|
    | 20|  192.168.255.1|       gu|       name:gu|
    | 15|  243.168.255.9|     ason|     name:ason|
    |  1|  108.168.255.1|   tianba|   name:tianba|
    | 25|222.168.255.110|clearlove|name:clearlove|
    | 30|222.168.255.110|clearlove|name:clearlove|
    +---+---------------+---------+--------------+
    
    +---+---------------+---------+-------------------+
    |age|             ip|     name|          ipAddName|
    +---+---------------+---------+-------------------+
    | 24|    192.168.0.8|  lillcol|    ip: 192.168.0.8|
    |100|  192.168.255.1|    adson|  ip: 192.168.255.1|
    | 39|  192.143.255.1|     wuli|  ip: 192.143.255.1|
    | 20|  192.168.255.1|       gu|  ip: 192.168.255.1|
    | 15|  243.168.255.9|     ason|  ip: 243.168.255.9|
    |  1|  108.168.255.1|   tianba|  ip: 108.168.255.1|
    | 25|222.168.255.110|clearlove|ip: 222.168.255.110|
    | 30|222.168.255.110|clearlove|ip: 222.168.255.110|
    +---+---------------+---------+-------------------+
    
    +---+---------------+---------+----------+
    |age|             ip|     name|  iptoLong|
    +---+---------------+---------+----------+
    | 24|    192.168.0.8|  lillcol|3232235528|
    |100|  192.168.255.1|    adson|3232300801|
    | 39|  192.143.255.1|     wuli|3230662401|
    | 20|  192.168.255.1|       gu|3232300801|
    | 15|  243.168.255.9|     ason|4087938825|
    |  1|  108.168.255.1|   tianba|1823014657|
    | 25|222.168.255.110|clearlove|3735617390|
    | 30|222.168.255.110|clearlove|3735617390|
    +---+---------------+---------+----------+
    

    用户自定义 UDAF 函数(即聚合函数)

    弱类型用户自定义聚合函数

    通过继承UserDefinedAggregateFunction

    package com.test.spark
    
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    /**
      * Untyped user-defined aggregate function (UDAF) computing the average of an integer column.
      *
      * Null inputs are skipped. A group with no non-null input yields null rather than NaN.
      *
      * @author lillcol
      *         2019/7/22-15:09
      */
    object TestUDAF extends UserDefinedAggregateFunction {
      // Data type of the aggregate's input argument(s).
      // `::` prepends an element to a List, producing a new List; Nil is the empty List (List[Nothing]).
      override def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)

      // Equivalent to:
      //  override def inputSchema: StructType = new StructType().add("age", IntegerType)

      // Data types of the values held in the aggregation buffer.
      // Both are LongType so that summing many Int ages cannot overflow.
      override def bufferSchema: StructType = {
        StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
      }

      // Data type of the value this function returns.
      override def dataType: DataType = DoubleType

      // Whether this function is deterministic, i.e. always returns the same output for the same input.
      override def deterministic: Boolean = true

      // Initializes the aggregation buffer to its zero value.
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L // sum of ages
        buffer(1) = 0L // number of non-null ages seen
      }

      // Folds one input row into the buffer; called once per input row (within one partition).
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        // Skip null ages: Row.getInt throws on a null value.
        if (!input.isNullAt(0)) {
          buffer(0) = buffer.getLong(0) + input.getInt(0) // accumulate age
          buffer(1) = buffer.getLong(1) + 1L // accumulate count
        }
      }

      // Merges two partially aggregated buffers and stores the result back into buffer1.
      // Called when combining results from different partitions.
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // accumulate age
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // accumulate count
      }

      // Final result: sum / count. Returns null when nothing was counted, instead of 0.0 / 0 = NaN.
      override def evaluate(buffer: Row): Any = {
        val count = buffer.getLong(1)
        if (count == 0L) null else buffer.getLong(0).toDouble / count
      }

      /** Default location of the sample JSON file (backslashes escaped so the literal compiles). */
      val DefaultInputPath: String = "D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson"

      val spark = SparkSession
        .builder()
        .appName("Spark SQL basic example")
        // .config("spark.some.config.option", "some-value")
        .master("local[*]") // local testing
        .getOrCreate()

      import spark.implicits._

      /**
        * Entry point: runs the UDAF via SQL and via the DSL.
        *
        * @param args optional; `args(0)` overrides the input JSON path
        */
      def main(args: Array[String]): Unit = {
        try {
          spark.udf.register("myAvg", TestUDAF)
          val ds: Dataset[Row] = spark.read.json(args.headOption.getOrElse(DefaultInputPath))
          ds.createOrReplaceTempView("table1")
          // SQL: call the UDAF by its registered name
          spark.sql("select myAvg(age) as avgAge from table1")
            .show()

          // DSL: apply the UDAF object directly to a column
          ds.select(TestUDAF($"age").as("avgAge"))
            .show()
        } finally {
          spark.stop()
        }
      }
    }
    
    输出结果:
    +------+
    |avgAge|
    +------+
    | 31.75|
    +------+
    
    +------+
    |avgAge|
    +------+
    | 31.75|
    +------+
    

    强类型用户自定义聚合函数

    通过继承 Aggregator 实现(注意是 org.apache.spark.sql.expressions 包下的 Aggregator,不要引错包)

    package com.test.spark
    
    import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
    import org.apache.spark.sql.expressions._
    
    /**
      * @author Administrator
      *         2019/7/22-16:07
      *
      */
    // A strongly-typed aggregation works on a case class: the input record type for MyAverage.
    // Field names match the keys of the sample JSON.
    final case class Person(name: String, age: Double, ip: String)
    
    // Aggregation buffer for MyAverage: running sum of ages and number of records seen.
    // The fields are vars because MyAverage.reduce/merge update the buffer in place.
    final case class Average(var sum: Double, var count: Double)
    
    /** Strongly-typed aggregator computing the average `age` of a Dataset[Person]. */
    object MyAverage extends Aggregator[Person, Average, Double] {
      // Zero value of this aggregation; must satisfy b + zero = b for any b.
      // The buffer holds the sum of ages and the number of people, both starting at 0.
      override def zero: Average = {
        Average(0, 0)
      }

      // Combines the buffer with one input value. For performance, `b` is modified and returned
      // instead of constructing a new object.
      // Combines records within the same partition.
      override def reduce(b: Average, a: Person): Average = {
        b.sum += a.age
        b.count += 1
        b
      }

      // Merges two intermediate buffers (results from different partitions).
      override def merge(b1: Average, b2: Average): Average = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
      }

      // Computes the final result: the mean age over the full Double sum (no Int truncation).
      // An empty input yields 0.0 / 0.0 = NaN, since a primitive Double cannot be null.
      override def finish(reduction: Average): Double = {
        reduction.sum / reduction.count
      }

      // Encoder for the intermediate buffer type.
      override def bufferEncoder: Encoder[Average] = Encoders.product

      // Encoder for the final output type.
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

      /** Default location of the sample JSON file (backslashes escaped so the literal compiles). */
      val DefaultInputPath: String = "D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson"

      val spark = SparkSession
        .builder()
        .appName("Spark SQL basic example")
        // .config("spark.some.config.option", "some-value")
        .master("local[*]") // local testing
        .getOrCreate()

      import spark.implicits._

      /**
        * Entry point: prints the input Dataset and then the average age.
        *
        * @param args optional; `args(0)` overrides the input JSON path
        */
      def main(args: Array[String]): Unit = {
        try {
          val ds: Dataset[Person] =
            spark.read.json(args.headOption.getOrElse(DefaultInputPath)).as[Person]
          ds.show()

          // Name the column with TypedColumn.name, which keeps it a TypedColumn. The author reports
          // that using .as("columnName") instead led to an AnalysisException on select.
          val avgAge = MyAverage.toColumn.name("avgAge")
          ds.select(avgAge)
            .show()
        } finally {
          spark.stop()
        }
      }
    }
    
    输出结果:
    +---+---------------+---------+
    |age|             ip|     name|
    +---+---------------+---------+
    | 24|    192.168.0.8|  lillcol|
    |100|  192.168.255.1|    adson|
    | 39|  192.143.255.1|     wuli|
    | 20|  192.168.255.1|       gu|
    | 15|  243.168.255.9|     ason|
    |  1|  108.168.255.1|   tianba|
    | 25|222.168.255.110|clearlove|
    | 30|222.168.255.110|clearlove|
    +---+---------------+---------+
    
    +------+
    |avgAge|
    +------+
    | 31.75|
    +------+
    

    本文为原创文章,转载请注明出处!!!

  • 相关阅读:
    win10安装nodejs,修改全局依赖位置和环境变量配置
    JavaScript判断两个对象内容是否相等
    JS判断是否是数组
    Js判断值是否是NaN
    typeof方法重写(区分数组对象)
    JS实现图片懒加载
    输入url到展示页面过程发生了什么?
    html如何在服务端跑起来
    nuxt怎么打包
    如果scss引用了字体图标文件该怎么打包
  • 原文地址:https://www.cnblogs.com/lillcol/p/11229044.html
Copyright © 2020-2023  润新知