• Custom Spark UDAF


    Spark provides two ways to define custom aggregate functions:

    Untyped User-Defined Aggregate Functions

      Untyped custom aggregate functions, mainly used with DataFrame

    Type-Safe User-Defined Aggregate Functions

      Typed (type-safe) custom aggregate functions, mainly used with Dataset


    Sample code for an untyped user-defined aggregate function:

    import java.util.ArrayList;
    import java.util.List;
    
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
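    // Note: as in the Spark docs example, MyAverage is a nested static class inside an
    // enclosing application class; as a top-level class, the `static` modifier would be dropped.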
    
    public static class MyAverage extends UserDefinedAggregateFunction {
    
      private StructType inputSchema;
      private StructType bufferSchema;
    
      public MyAverage() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
        inputSchema = DataTypes.createStructType(inputFields);
    
        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
      }
      // Data types of input arguments of this aggregate function
      public StructType inputSchema() {
        return inputSchema;
      }
      // Data types of values in the aggregation buffer
      public StructType bufferSchema() {
        return bufferSchema;
      }
      // The data type of the returned value
      public DataType dataType() {
        return DataTypes.DoubleType;
      }
      // Whether this function always returns the same output on identical input
      public boolean deterministic() {
        return true;
      }
      // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
      // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
      // the opportunity to update its values. Note that arrays and maps inside the buffer are still
      // immutable.
      public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L);
        buffer.update(1, 0L);
      }
      // Updates the given aggregation buffer `buffer` with new input data from `input`
      public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
          long updatedSum = buffer.getLong(0) + input.getLong(0);
          long updatedCount = buffer.getLong(1) + 1;
          buffer.update(0, updatedSum);
          buffer.update(1, updatedCount);
        }
      }
      // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
      public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
        long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
      }
      // Calculates the final result
      public Double evaluate(Row buffer) {
        return ((double) buffer.getLong(0)) / buffer.getLong(1);
      }
    }
    
    // Register the function to access it
    spark.udf().register("myAverage", new MyAverage());
    
    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
    df.createOrReplaceTempView("employees");
    df.show();
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+
    
    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
    result.show();
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+

    Sample code 2:

    import java.util.Arrays;
    
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * Group-wise concatenation with de-duplication (group_concat_distinct())
     */
    public class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {
    
        private static final long serialVersionUID = -2510776241322950505L;
        
        // Specifies the fields and types of the input data
        //    i.e. the concrete type of the input arguments
        //       * Field names are arbitrary: users can choose names to identify the input arguments - here it could be "name" or any other string
        private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
                DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));  
        
        // Specifies the fields and types of the buffer data
        //    i.e. the type of the intermediate results handled during aggregation
        private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
                DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)));  
    
        // Specifies the return type
        private DataType dataType = DataTypes.StringType;
        
        // Specifies whether the function is deterministic
        /* whether, given the same input,
         * it always returns the same output
         * true: yes */
        private boolean deterministic = true;
        
        @Override
        public StructType inputSchema() {
            return inputSchema;
        }
        
        @Override
        public StructType bufferSchema() {
            return bufferSchema;
        }
    
        @Override
        public DataType dataType() {
            return dataType;
        }
    
        @Override
        public boolean deterministic() {
            return deterministic;
        }
        
        /**
         * Initialization
         * Think of it as specifying an initial value yourself internally
         * Initializes the given aggregation buffer
         */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            buffer.update(0, "");  
        }
        
        /**
         * Update
         * Think of it as the field values within a group being passed in one at a time;
         * this is where the concatenation logic is implemented.
         * 
         * During aggregation, defines how to update the grouped aggregate whenever a new value arrives.
         * This is the node-local aggregation step, comparable to the Combiner in the Hadoop MapReduce model.
         */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            // the city-info string already concatenated in the buffer
            String bufferCityInfo = buffer.getString(0);
            // the city info that has just been passed in
            String cityInfo = input.getString(0);
            
            // implement the de-duplication logic here
            // only if this city info has not been concatenated before may the new one be appended
            // (note: contains() is a substring match, so e.g. "1:X" also matches inside "11:X";
            // exact matching would require a split-based check)
            if(!bufferCityInfo.contains(cityInfo)) {
                if("".equals(bufferCityInfo)) {
                    bufferCityInfo += cityInfo;
                } else {
                    // e.g. buffer holds 1:Beijing
                    // after appending: 1:Beijing,2:Shanghai
                    bufferCityInfo += "," + cityInfo;
                }
                
                buffer.update(0, bufferCityInfo);  
            }
        }
        
        /**
         * Merge
         * update() may run on only part of a group's data, on a single node,
         * but one group's data may be spread across multiple nodes for processing.
         * merge() is then used to combine the strings concatenated on each node.
         */
        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            String bufferCityInfo1 = buffer1.getString(0);
            String bufferCityInfo2 = buffer2.getString(0);
            
            for(String cityInfo : bufferCityInfo2.split(",")) {
                if(!bufferCityInfo1.contains(cityInfo)) {
                    if("".equals(bufferCityInfo1)) {
                        bufferCityInfo1 += cityInfo;
                    } else {
                        bufferCityInfo1 += "," + cityInfo;
                    }
                 }
            }
            
            buffer1.update(0, bufferCityInfo1);  
        }
        
        @Override
        public Object evaluate(Row row) {  
            return row.getString(0);  
        }
    
    }
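
    The original post does not show this UDAF in use. A minimal usage sketch follows; the SQL function name matches the comment above, while the table name "city_click", the grouping column "area", the argument column "city_info", and the input path are illustrative assumptions, not from the original:

    // Register the UDAF, then call it like a built-in aggregate in Spark SQL
    // (assumed table/column names - see the note above)
    spark.udf().register("group_concat_distinct", new GroupConcatDistinctUDAF());
    
    Dataset<Row> clicks = spark.read().json("examples/src/main/resources/city_clicks.json");
    clicks.createOrReplaceTempView("city_click");
    
    Dataset<Row> grouped = spark.sql(
        "SELECT area, group_concat_distinct(city_info) AS city_infos " +
        "FROM city_click GROUP BY area");
    grouped.show();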

     Sample code for a type-safe (typed) user-defined aggregate function:

    import java.io.Serializable;
    
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.TypedColumn;
    import org.apache.spark.sql.expressions.Aggregator;
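    // Note: as in the Spark docs example, the classes below are nested static classes inside
    // an enclosing application class; as top-level classes, the `static` modifier would be dropped.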
    
    public static class Employee implements Serializable {
      private String name;
      private long salary;
    
      // Constructors, getters, setters...
    
    }
    
    public static class Average implements Serializable  {
      private long sum;
      private long count;
    
      // Constructors, getters, setters...
    
    }
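    // For Encoders.bean(Average.class) to work, Average needs a public no-arg constructor
    // plus getters/setters. A minimal sketch of the boilerplate elided above:
    //
    //   public Average() {}
    //   public Average(long sum, long count) { this.sum = sum; this.count = count; }
    //   public long getSum() { return sum; }
    //   public void setSum(long sum) { this.sum = sum; }
    //   public long getCount() { return count; }
    //   public void setCount(long count) { this.count = count; }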
    
    public static class MyAverage extends Aggregator<Employee, Average, Double> {
      // A zero value for this aggregation. Should satisfy the property that any b + zero = b
      public Average zero() {
        return new Average(0L, 0L);
      }
      // Combine two values to produce a new value. For performance, the function may modify `buffer`
      // and return it instead of constructing a new object
      public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
      }
      // Merge two intermediate values
      public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
      }
      // Transform the output of the reduction
      public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
      }
      // Specifies the Encoder for the intermediate value type
      public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
      }
      // Specifies the Encoder for the final output value type
      public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
      }
    }
    
    Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
    String path = "examples/src/main/resources/employees.json";
    Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
    ds.show();
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+
    
    MyAverage myAverage = new MyAverage();
    // Convert the function to a `TypedColumn` and give it a name
    TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
    Dataset<Double> result = ds.select(averageSalary);
    result.show();
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+

     Related API

    http://spark.apache.org/docs/2.3.4/sql-programming-guide.html#type-safe-user-defined-aggregate-functions
