45. Spark SQL UDF & UDAF


    I. UDF

    1. UDF

    UDF: User Defined Function. A UDF is a function you write yourself and register with Spark SQL so that it can be called from SQL statements.


    2. Scala example

    package cn.spark.study.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.types.StructField
    import org.apache.spark.sql.types.StringType
    
    object UDF {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local").setAppName("UDF")
        val sc = new SparkContext(conf)
        val sqlContext = new SQLContext(sc)
        
        // Construct some mock data
        val names = Array("Leo", "Marry", "Jack", "Tom")
        val namesRDD = sc.parallelize(names, 5)
        val namesRowRDD = namesRDD.map(name => Row(name))
        val structType = StructType(Array(StructField("name", StringType, true)))
        val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)
        
        // Register a temporary table named "names"
        namesDF.registerTempTable("names")
        
        // Define and register the custom function
        // Define: write the anonymous function yourself
        // Register: SQLContext.udf.register()
        // UDF name: strLen; function body (anonymous function): (str: String) => str.length()
        sqlContext.udf.register("strLen", (str: String) => str.length())
        
        // Use the custom function
        sqlContext.sql("select name, strLen(name) from names")
          .collect()
          .foreach(println)
        
      }
    }
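
    Run locally, this should print one Row per name, pairing each name with its length:

    [Leo,3]
    [Marry,5]
    [Jack,4]
    [Tom,3]

    register() also accepts functions of higher arity. As a minimal sketch (strConcat is a hypothetical name, not part of the example above), a two-argument UDF is registered the same way:

    sqlContext.udf.register("strConcat", (s1: String, s2: String) => s1 + s2)
    sqlContext.sql("select strConcat(name, name) from names").collect().foreach(println)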


    3. Java example

    package cn.spark.study.sql;
    
    import java.util.ArrayList;
    import java.util.List;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.api.java.UDF1;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    public class UDF {
        public static void main(String[] args) {
            SparkConf conf = new SparkConf().setAppName("UDFJava").setMaster("local");
            JavaSparkContext sparkContext = new JavaSparkContext(conf);
            SQLContext sqlContext = new SQLContext(sparkContext);
            
            // Construct some mock data
            List<String> stringList = new ArrayList<String>();
            stringList.add("Leo");
            stringList.add("Marry");
            stringList.add("Jack");
            stringList.add("Tom");
            JavaRDD<String> rdd = sparkContext.parallelize(stringList);
            JavaRDD<Row> nameRDD = rdd.map(new Function<String, Row>() {
    
                private static final long serialVersionUID = 1L;
    
                @Override
                public Row call(String v1) throws Exception {
                    return RowFactory.create(v1);
                }
            });
            
            // Build the schema and create a DataFrame from the Row RDD
            List<StructField> fieldList = new ArrayList<StructField>();
            fieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
            StructType structType = DataTypes.createStructType(fieldList);
            DataFrame dataFrame = sqlContext.createDataFrame(nameRDD, structType);
            
            // Register a temporary table named "name"
            dataFrame.registerTempTable("name");
            // Register the UDF; UDF1<String, Integer> takes one String argument and returns an Integer
            sqlContext.udf().register("strLen", new UDF1<String, Integer>() {
                
                private static final long serialVersionUID = 1L;
    
                @Override
                public Integer call(String s) throws Exception {
                    return s.length();
                }
                
            }, DataTypes.IntegerType);
            
            sqlContext.sql("select name, strLen(name) from name").javaRDD().
            foreach(new VoidFunction<Row>() {
    
                private static final long serialVersionUID = 1L;
    
                @Override
                public void call(Row row) throws Exception {
                    System.out.println(row);            
                }
            });
            
            
        }
    }
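
    The output matches the Scala version: each printed Row pairs a name with its length, e.g. [Leo,3]. Note that, unlike the Scala API, the Java API cannot infer the UDF's return type from the function, which is why DataTypes.IntegerType is passed explicitly to register().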


    II. UDAF

    1. Overview

    UDAF: User Defined Aggregate Function. User-defined aggregate functions were introduced in Spark 1.5.x.
    
    A UDF operates on a single row of input and returns one output value. A UDAF, by contrast, aggregates over a group of (multiple) input rows and returns one output value, which makes it considerably more powerful.
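
    Concretely, with the strLen UDF from the previous section and the strCount UDAF defined below, the difference shows up in SQL like this:

    -- UDF: one output value per input row
    select name, strLen(name) from names

    -- UDAF: one output value per group of rows
    select name, strCount(name) from names group by name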


    Usage:
    
    1. Define a class that extends UserDefinedAggregateFunction and implement each stage's method.
    
    2. Register the UDAF with Spark and bind it to a name.
    
    3. Call it from SQL statements using the name bound above.


    2. Scala example

    An example that counts how many times each string occurs. First, define a class that extends UserDefinedAggregateFunction:

    package cn.spark.study.sql
    
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.types.DataType
    import org.apache.spark.sql.expressions.MutableAggregationBuffer
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.StructField
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.sql.types.IntegerType
    
    class StringCount extends UserDefinedAggregateFunction {  
      
      // inputSchema: the type of the input data
      def inputSchema: StructType = {
        StructType(Array(StructField("str", StringType, true)))   
      }
      
      // bufferSchema: the types of the intermediate data processed during aggregation
      def bufferSchema: StructType = {
        StructType(Array(StructField("count", IntegerType, true)))   
      }
      
      // dataType: the return type of the function
      def dataType: DataType = {
        IntegerType
      }
      
      // deterministic: whether the function always returns the same output for the same input
      def deterministic: Boolean = {
        true
      }
    
      // Perform initialization for each group's aggregation buffer
      def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0
      }
      
      // update: how the group's aggregate value is computed whenever a new value arrives for that group
      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getAs[Int](0) + 1
      }
      
      // Because Spark is distributed, one group's data may be partially aggregated (via update) on different nodes;
      // in the end the partial aggregates for that group must be merged, i.e. combined, across nodes
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)  
      }
      
      // Finally, evaluate: how a group's final aggregate value is produced from the intermediate buffered value
      def evaluate(buffer: Row): Any = {
        buffer.getAs[Int](0)    
      }
      
    }
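
    To see how initialize, update, merge and evaluate cooperate, consider a hypothetical trace for a group whose value "Tom" appears three times, split across two partitions:

    partition 1: initialize -> buffer = 0; update("Tom"); update("Tom") -> buffer = 2
    partition 2: initialize -> buffer = 0; update("Tom")                -> buffer = 1
    merge(buffer1 = 2, buffer2 = 1)                                     -> buffer1 = 3
    evaluate(buffer = 3)                                                -> 3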


    Then register and use it:

    package cn.spark.study.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.types.StructField
    import org.apache.spark.sql.types.StringType
    
    object UDAF {
      
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
            .setMaster("local") 
            .setAppName("UDAF")
        val sc = new SparkContext(conf)
        val sqlContext = new SQLContext(sc)
      
        // Construct some mock data
        val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")  
        val namesRDD = sc.parallelize(names, 5) 
        val namesRowRDD = namesRDD.map { name => Row(name) }
        val structType = StructType(Array(StructField("name", StringType, true)))  
        val namesDF = sqlContext.createDataFrame(namesRowRDD, structType) 
        
        // Register a temporary table named "names"
        namesDF.registerTempTable("names")  
        
        // Define and register the custom aggregate function
        // Here the "function" is the StringCount class defined above
        // Register it via SQLContext.udf.register()
        sqlContext.udf.register("strCount", new StringCount)
        
        // Use the custom function
        sqlContext.sql("select name, strCount(name) from names group by name")
            .collect()
            .foreach(println)  
      }
      
    }
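
    Run locally, this should print one Row per distinct name with its occurrence count (group order may vary):

    [Tom,3]
    [Leo,2]
    [Marry,1]
    [Jack,1]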