• Alink漫谈(十九) :源码解析 之 分位点离散化Quantile


    Alink漫谈(十九) :源码解析 之 分位点离散化Quantile

    0x00 摘要

    Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 Quantile 的实现。

    因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

    本文缘由是因为想分析GBDT,发现GBDT涉及到Quantile的使用,所以只能先分析Quantile 。

    0x01 背景概念

    1.1 离散化

    离散化:就是把无限空间中有限的个体映射到有限的空间中(分箱处理)。数据离散化操作大多是针对连续数据进行的,处理之后的数据值域分布将从连续属性变为离散属性。

    离散化方式会影响后续数据建模和应用效果:

    • 使用决策树往往倾向于少量的离散化区间,过多的离散化将使得规则过多受到碎片区间的影响。
    • 关联规则需要对所有特征一起离散化,关联规则关注的是所有特征的关联关系,如果对每个列单独离散化将失去整体规则性。

    连续数据的离散化结果可以分为两类:

    • 一类是将连续数据划分为特定区间的集合,例如{(0,10], (10,20], (20,50],(50,100]};
    • 一类是将连续数据划分为特定类,例如类1、类2、类3;

    1.2 分位数

    分位数(Quantile),亦称分位点,是指将一个随机变量的概率分布范围分为几个等份的数值点,常用的有中位数(即二分位数)、四分位数、百分位数等。

    假如有1000个数字(正数),这些数字的5%, 30%, 50%, 70%, 99%分位数分别是 [3.0,5.0,6.0,9.0,12.0],这表明

    • 有5%的数字分布在0-3.0之间
    • 有25%的数字分布在3.0-5.0之间
    • 有20%的数字分布在5.0-6.0之间
    • 有20%的数字分布在6.0-9.0之间
    • 有29%的数字分布在9.0-12.0之间
    • 有1%的数字大于12.0

    这就是分位数的统计学理解。

    因此求解某一组数字中某个数的分位数,只需要将该组数字进行排序,然后再统计小于等于该数的个数,除以总的数字个数即可。

    确定p分位数位置的两种方法

    • position = (n+1)p
    • position = 1 + (n-1)p

    1.3 四分位数

    这里我们用四分位数做进一步说明。

    四分位数 概念:把给定的乱序数值由小到大排列并分成四等份,处于三个分割点位置的数值就是四分位数。

    第1四分位数 (Q1),又称“较小四分位数”,等于该样本中所有数值由小到大排列后第25%的数字。

    第2四分位数 (Q2),又称“中位数”,等于该样本中所有数值由小到大排列后第50%的数字。

    第3四分位数 (Q3),又称“较大四分位数”,等于该样本中所有数值由小到大排列后第75%的数字。

    四分位距(InterQuartile Range, IQR)= 第3四分位数与第1四分位数的差距。

    0x02 示例代码

    Alink中完成分位数功能的是QuantileDiscretizerQuantileDiscretizer输入连续的特征列,输出分箱的类别特征。

    • 分位点离散可以计算选定列的分位点,然后使用这些分位点进行离散化。生成选中列对应的q-quantile,其中可以所有列指定一个,也可以每一列对应一个。
    • 分箱数(所需离散的数目,即分为几段)是通过参数numBuckets(桶数目)来指定的。 箱的范围是通过使用近似算法来得到的。

    本文示例代码如下。

    public class QuantileDiscretizerExample {
        public static void main(String[] args) throws Exception {
            NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(1001, 2000, "col0"); // 就是把1001 ~ 2000 这个连续数值分段
    
            Pipeline pipeline = new Pipeline()
                    .add(new QuantileDiscretizer()
                            .setNumBuckets(6) // 指定分箱数数目
                            .setSelectedCols(new String[]{"col0"}));
    
            List<Row> result = pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect();
            System.out.println(result);
        }
    }
    

    输出

    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 
    .....
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
    .....
    5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
    

    0x03 总体逻辑

    我们首先给出总体逻辑图例

    -------------------------------- 准备阶段 --------------------------------
           │
           │
           │  
    ┌───────────────────┐ 
    │  getSelectedCols  │ 获取需要分位的列名字
    └───────────────────┘ 
           │
           │
           │
    ┌─────────────────────┐ 
    │     quantileNum     │ 获取分箱数
    └─────────────────────┘ 
           │
           │
           │
    ┌──────────────────────┐ 
    │ Preprocessing.select │ 从输入中根据列名字select出数据
    └──────────────────────┘ 
           │
           │
           │
    -------------------------------- 预处理阶段 --------------------------------
           │ 
           │
           │
    ┌──────────────────────┐ 
    │       quantile       │ 后续步骤 就是 计算分位数
    └──────────────────────┘ 
           │
           │
           │ 
    ┌────────────────────────────────┐ 
    │   countElementsPerPartition    │ 在每一个partition中获取该分区的所有元素个数
    └────────────────────────────────┘ 
           │ <task id, count in this task>
           │
           │
    ┌──────────────────────┐ 
    │       sum(1)         │ 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数
    └──────────────────────┘ 
           │  
           │
           │
    ┌──────────────────────┐ 
    │        map           │ 取出所有元素个数,cnt在后续会使用
    └──────────────────────┘ 
           │    
           │    
           │
           │    
    ┌──────────────────────┐ 
    │     missingCount     │ 分区查找应选的列中,有哪些数据没有被查到,比如zeroAsMissing, null, isNaN
    └──────────────────────┘ 
           │
           │
           │
    ┌────────────────┐ 
    │  mapPartition  │ 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来
    └────────────────┘ 
           │ <idx in row, item in row>, 即<row中第几个元素,元素>
           │
           │  
    ┌──────────────┐ 
    │    pSort     │ 将flatten数据进行排序
    └──────────────┘ 
           │ 返回的是二元组
           │ f0: dataset which is indexed by partition id
           │ f1: dataset which has partition id and count
           │ 
           │  
    -------------------------------- 计算阶段 --------------------------------
           │ 
           │
           │ 
    ┌─────────────────┐ 
    │  MultiQuantile  │ 后续都是具体计算步骤
    └─────────────────┘ 
           │
           │ 
           │
    ┌─────────────────┐ 
    │      open       │ 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)
    └─────────────────┘ 
           │
           │ 
           │
    ┌─────────────────┐ 
    │  mapPartition   │ 具体计算
    └─────────────────┘         
           │
           │ 
           │
    ┌─────────────────┐ 
    │    groupBy(0)   │ 依据 列idx 分组
    └─────────────────┘   
           │
           │ 
           │
    ┌─────────────────┐ 
    │   reduceGroup   │ 归并排序
    └─────────────────┘    
           │set(Tuple2<column idx, 真实数据值>)
           │ 
           │ 
    -------------------------------- 序列化模型 --------------------------------
           │ 
           │
           │    
    ┌──────────────┐ 
    │  reduceGroup │ 分组归并
    └──────────────┘ 
           │ 
           │
           │   
    ┌─────────────────┐ 
    │  SerializeModel │ 序列化模型
    └─────────────────┘ 
      
    

    下面图片是为了在手机上缩放适配展示。

    QuantileDiscretizerTrainBatchOp.linkFrom如下:

    public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
       BatchOperator<?> in = checkAndGetFirst(inputs);
    
       // 示例中设置了 .setSelectedCols(new String[]{"col0"}));, 所以这里 quantileColNames 的数值是"col0 
       String[] quantileColNames = getSelectedCols();
    
       int[] quantileNum = null;
    
       // 示例中设置了 .setNumBuckets(6),所以这里 quantileNum 是 quantileNum = {int[1]@2705} 0 = 6
       if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) {
          quantileNum = new int[quantileColNames.length];
          Arrays.fill(quantileNum, getNumBuckets());
       } else {
          quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray();
       }
    
       /* filter the selected column from input */
       // 获取了 选择的列 "col0"
       DataSet<Row> input = Preprocessing.select(in, quantileColNames).getDataSet();
    
       // 计算分位数
       DataSet<Row> quantile = quantile(
          input, quantileNum,
          getParams().get(HasRoundMode.ROUND_MODE),
          getParams().get(Preprocessing.ZERO_AS_MISSING)
       );
    
       // 序列化模型
       quantile = quantile.reduceGroup(
          new SerializeModel(
             getParams(),
             quantileColNames,
             TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
             BinTypes.BinDivideType.QUANTILE
          )
       );
    
       /* set output */
       setOutput(quantile, new QuantileDiscretizerModelDataConverter().getModelSchema());
    
       return this;
    }
    

    其总体逻辑如下:

    • 获取需要分位的列名字
    • 获取分箱数
    • 从输入中根据列名字select出数据
    • 调用 quantile 计算分位数
      • 调用 countElementsPerPartition 在每一个partition中获取该分区的所有元素个数,返回<task id, count in this task>,然后 对于元素个数进行累积 sum(1) ,即"count in this task"进行累积,得出所有元素的个数 cnt;
      • 分区查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing, null, isNaN这几种情况,然后依据 partition id 进行分组 groupBy(0) 累积求和,得到 missingCount;
      • 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了, 返回flatten = <idx in row, item in row>, 即<row中第几个元素,元素>;
      • 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类。pSort返回的是二元组sortedData,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count;
      • 调用 MultiQuantile ,对 sortedData.f0(f0: dataset which is indexed by partition id) 进行计算分位数;具体是分区计算 mapPartition:
        • 累积,得到当前 task 的起始位置,即 n 个输入数据中从哪个数据开始计算;
        • 根据 taskId 从 counts 中得到了本 task 应该处理哪些数据,即数据的start,end位置;
        • 把数据插入 allRows.add(value); value 可认为是 <partition id, 真实数据>;
        • 调用 QIndex 计算分位数元数据;quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6;
        • 遍历一直到分箱数,每次循环 调用 qIndex.genIndex(j) 获取每个分箱的index。然后依据这个分箱的index从输入数据中获取真实数据值,这个 真实数据值 就是 真实数据的index。比如连续区域是 1001 ~ 2000,分成 6 份,则第一份调用 qIndex.genIndex(j) 得到 167,则根据167,获取真实数据是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一个分位index 是 1168.
      • 依据 列idx 分组,得到 set(Tuple2<column idx, 真实数据值>);
    • 序列化模型

    0x04 训练

    4.1 quantile

    训练是通过 quantile 完成的,大致包含以下步骤。

    • 调用 countElementsPerPartition 在每一个partition中获取该分区的所有元素个数,返回<task id, count in this task>,然后 对于元素个数进行累积 sum(1) ,即"count in this task"进行累积,得出所有元素的个数 cnt;
    • 分区查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing, null, isNaN这几种情况,然后依据 partition id 进行分组 groupBy(0) 累积求和,得到 missingCount;
    • 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了,返回flatten = <idx in row, item in row>, 即<row中第几个元素,元素>;
    • 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类。pSort返回的是二元组sortedData,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count;
    • 调用 MultiQuantile ,对 sortedData.f0(f0: dataset which is indexed by partition id) 进行计算分位数。

    具体如下

    public static DataSet<Row> quantile(
       DataSet<Row> input,
       final int[] quantileNum,
       final HasRoundMode.RoundMode roundMode,
       final boolean zeroAsMissing) {
      
       /* instance count of dataset */
       // countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数,返回<task id, count in this task>。
       DataSet<Long> cnt = DataSetUtils
          .countElementsPerPartition(input)
          .sum(1) // 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数。
          .map(new MapFunction<Tuple2<Integer, Long>, Long>() {
             @Override
             public Long map(Tuple2<Integer, Long> value) throws Exception {
                return value.f1; // 取出所有元素个数
             }
          }); // cnt在后续会使用
    
       /* missing count of columns */
       // 会查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing, null, isNaN这几种情况
       DataSet<Tuple2<Integer, Long>> missingCount = input
          .mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() {
             public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Long>> out) {
                StreamSupport.stream(values.spliterator(), false)
                   .flatMap(x -> {
                      long[] counts = new long[x.getArity()];
    
                      Arrays.fill(counts, 0L);
       
                      // 如果发现有数据没有查到,就增加counts
                      for (int i = 0; i < x.getArity(); ++i) {
                         if (x.getField(i) == null
                         || (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0)
                         || Double.isNaN(((Number)x.getField(i)).doubleValue())) {
                            counts[i]++;
                         }
                      }
    
                      return IntStream.range(0, x.getArity())
                         .mapToObj(y -> Tuple2.of(y, counts[y]));
                   })
                   .collect(Collectors.groupingBy(
                      x -> x.f0,
                      Collectors.mapping(x -> x.f1, Collectors.reducing((a, b) -> a + b))
                      )
                   )
                   .entrySet()
                   .stream()
                   .map(x -> Tuple2.of(x.getKey(), x.getValue().get()))
                   .forEach(out::collect);
             }
          })
          .groupBy(0) //按第一个元素分组
          .reduce(new RichReduceFunction<Tuple2<Integer, Long>>() {
             @Override
             public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2) {
                return Tuple2.of(value1.f0, value1.f1 + value2.f1); //累积求和
             }
          });
    
       /* flatten dataset to 1d */
       // 把输入数据打散。
       DataSet<PairComparable> flatten = input
          .mapPartition(new RichMapPartitionFunction<Row, PairComparable>() {
             PairComparable pairBuff;
             public void mapPartition(Iterable<Row> values, Collector<PairComparable> out) {
                for (Row value : values) { // 遍历分区内所有输入元素
                   for (int i = 0; i < value.getArity(); ++i) { // 如果输入元素Row本身包含多个子元素
                      pairBuff.first = i; // 则对于这些子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了
                      if (value.getField(i) == null
                         || (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0)
                         || Double.isNaN(((Number)value.getField(i)).doubleValue())) {
                         pairBuff.second = null;
                      } else {
                         pairBuff.second = (Number) value.getField(i);
                      }
                      out.collect(pairBuff); // 返回<idx in row, item in row>, 即<row中第几个元素,元素>
                   }
                }
             }
          });
    
       /* sort data */
       // 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类
       // pSort返回的是二元组,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count.
       Tuple2<DataSet<PairComparable>, DataSet<Tuple2<Integer, Long>>> sortedData
          = SortUtilsNext.pSort(flatten);
    
       /* calculate quantile */
       return sortedData.f0 //f0: dataset which is indexed by partition id
          .mapPartition(new MultiQuantile(quantileNum, roundMode))
          .withBroadcastSet(sortedData.f1, "counts") //f1: dataset which has partition id and count
          .withBroadcastSet(cnt, "totalCnt")
          .withBroadcastSet(missingCount, "missingCounts")
          .groupBy(0) // 依据 列idx 分组
          .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, Number>, Row>() {
             @Override
             public void reduce(Iterable<Tuple2<Integer, Number>> values, Collector<Row> out) {
                TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() {
                   @Override
                   public int compare(Number o1, Number o2) {
                      return SortUtils.OBJECT_COMPARATOR.compare(o1, o2);
                   }
                });
    
                int id = -1;
                for (Tuple2<Integer, Number> val : values) {
                   // Tuple2<column idx, 数据>
                   id = val.f0;
                   set.add(val.f1); 
                }
    
    // runtime变量           
    set = {TreeSet@9379}  size = 5
     0 = {Long@9389} 167 // 就是第 0 列的第一段 idx
     1 = {Long@9392} 333 // 就是第 0 列的第二段 idx
     2 = {Long@9393} 500 
     3 = {Long@9394} 667
     4 = {Long@9382} 833
      
                out.collect(Row.of(id, set.toArray(new Number[0])));
             }
          });
    }
    

    下面会对几个重点函数做说明。

    4.2 countElementsPerPartition

    countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数。

    public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
       return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
          @Override
          public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
             long counter = 0;
             for (T value : values) {
                counter++; // 在每一个partition中获取该分区的所有元素个数
             }
             out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
          }
       });
    }
    

    4.3 MultiQuantile

    MultiQuantile用来计算具体的分位点。

    open函数中会从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)等等。

    mapPartition函数则做具体计算,大致步骤如下:

    • 累积,得到当前 task 的起始位置,即 n 个输入数据中从哪个数据开始计算;
    • 根据 taskId 从 counts 中得到了本 task 应该处理哪些数据,即数据的start,end位置;
    • 把数据插入 allRows.add(value); value 可认为是 <partition id, 真实数据>;
    • 调用 QIndex 计算分位数元数据;quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6;
    • 遍历一直到分箱数,每次循环 调用 qIndex.genIndex(j) 获取每个分箱的index。然后依据这个分箱的index从输入数据中获取真实数据值,这个 真实数据值 就是 真实数据的index。比如连续区域是 1001 ~ 2000,分成 6 份,则第一份调用 qIndex.genIndex(j) 得到 167,则根据167,获取真实数据是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一个分位index 是 1168;

    具体代码是:

    public static class MultiQuantile
       extends RichMapPartitionFunction<PairComparable, Tuple2<Integer, Number>> {
    		private List<Tuple2<Integer, Long>> counts;
    		private List<Tuple2<Integer, Long>> missingCounts;
    		private long totalCnt = 0;
    		private int[] quantileNum;
    		private HasRoundMode.RoundMode roundType;
    		private int taskId;
    
    		@Override
    		public void open(Configuration parameters) throws Exception {
          // 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)。
          // 之前设置广播变量.withBroadcastSet(sortedData.f1, "counts"),其中 f1 的格式是: dataset which has partition id and count,所以就是用 partition id来排序
    			this.counts = getRuntimeContext().getBroadcastVariableWithInitializer(
    				"counts",
    				new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
    					@Override
    					public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
    						Iterable<Tuple2<Integer, Long>> data) {
    						ArrayList<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
    						for (Tuple2<Integer, Long> datum : data) {
    							sortedData.add(datum);
    						}
                //排序
    						sortedData.sort(Comparator.comparing(o -> o.f0));
                
    // runtime的数据如下,本机有4核,所以数据分为4个 partition,每个partition的数据分别为251,250,250,250        
    sortedData = {ArrayList@9347}  size = 4
     0 = {Tuple2@9350} "(0,251)" // partition 0, 数据个数是251
     1 = {Tuple2@9351} "(1,250)"
     2 = {Tuple2@9352} "(2,250)"
     3 = {Tuple2@9353} "(3,250)"         
                
    						return sortedData;
    					}
    				});
    
    			this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",
    				new BroadcastVariableInitializer<Long, Long>() {
    					@Override
    					public Long initializeBroadcastVariable(Iterable<Long> data) {
    						return data.iterator().next();
    					}
    				});
    
    			this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer(
    				"missingCounts",
    				new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
    					@Override
    					public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
    						Iterable<Tuple2<Integer, Long>> data) {
    						return StreamSupport.stream(data.spliterator(), false)
    							.sorted(Comparator.comparing(o -> o.f0))
    							.collect(Collectors.toList());
    					}
    				}
    			);
    
    			taskId = getRuntimeContext().getIndexOfThisSubtask();
          
    // runtime的数据如下        
    this = {QuantileDiscretizerTrainBatchOp$MultiQuantile@9348} 
     counts = {ArrayList@9347}  size = 4
      0 = {Tuple2@9350} "(0,251)"
      1 = {Tuple2@9351} "(1,250)"
      2 = {Tuple2@9352} "(2,250)"
      3 = {Tuple2@9353} "(3,250)"
     missingCounts = {ArrayList@9375}  size = 1
      0 = {Tuple2@9381} "(0,0)"
     totalCnt = 1001
     quantileNum = {int[1]@9376} 
      0 = 6
     roundType = {HasRoundMode$RoundMode@9377} "ROUND"
     taskId = 2
    		}
    
    		@Override
    		public void mapPartition(Iterable<PairComparable> values, Collector<Tuple2<Integer, Number>> out) throws Exception {
    
    			long start = 0;
    			long end;
    
    			int curListIndex = -1;
    			int size = counts.size(); // 分成4份,所以这里是4
    
    			for (int i = 0; i < size; ++i) {
    				int curId = counts.get(i).f0; // 取出输入元素中的 partition id
    
    				if (curId == taskId) {
    					curListIndex = i; // 当前 task 对应哪个 partition id
    					break; // 到了当前task,就可以跳出了
    				}
    
    				start += counts.get(i).f1; // 累积,得到当前 task 的起始位置,即1000个数据中从哪个数据开始计算
    			}
    
          // 根据 taskId 从counts中得到了本 task 应该处理哪些数据,即数据的start,end位置
          // 本 partition 是 0,其中有251个数据
    			end = start + counts.get(curListIndex).f1; // end = 起始位置 + 此partition的数据个数 
    
    			ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start));
    
    			for (PairComparable value : values) {
    				allRows.add(value); // value 可认为是 <partition id, 真实数据>
    			}
    
    			allRows.sort(Comparator.naturalOrder());
    
    // runtime变量
    start = 0
    curListIndex = 0
    size = 4
    end = 251
    allRows = {ArrayList@9406}  size = 251
     0 = {PairComparable@9408} 
      first = {Integer@9397} 0
      second = {Long@9434} 0
     1 = {PairComparable@9409} 
      first = {Integer@9397} 0
      second = {Long@9435} 1
     2 = {PairComparable@9410} 
      first = {Integer@9397} 0
      second = {Long@9439} 2
     ......
          
          // size = ((251 - 1) / 1001 - 0 / 1001) + 1 = 1
    			size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1;
    
    			int localStart = 0;
    			for (int i = 0; i < size; ++i) {
    				int fIdx = (int) (start / totalCnt + i);
    				int subStart = 0;
    				int subEnd = (int) totalCnt;
    
    				if (i == 0) {
    					subStart = (int) (start % totalCnt); // 0
    				}
    
    				if (i == size - 1) {
    					subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt); // 251
    				}
    
    				if (totalCnt - missingCounts.get(fIdx).f1 == 0) {
    					localStart += subEnd - subStart;
    					continue;
    				}
    
    				QIndex qIndex = new QIndex(
    					totalCnt - missingCounts.get(fIdx).f1, quantileNum[fIdx], roundType);
    
    // runtime变量
    qIndex = {QuantileDiscretizerTrainBatchOp$QIndex@9548} 
     totalCount = 1001.0
     q1 = 0.16666666666666666
     roundMode = {HasRoundMode$RoundMode@9377} "ROUND"      
            
            // 遍历,一直到分箱数。
    				for (int j = 1; j < quantileNum[fIdx]; ++j) {
              // 获取每个分箱的index 
    					long index = qIndex.genIndex(j); // j = 1 ---> index = 167,就是把 1001 个分为6段,第一段终点是167
              //对应本 task = 0,subStart = 0,subEnd = 251。则index = 167,直接从allRows获取第167个,数值是 1168。因为连续区域是 1001 ~ 2000,所以第167个对应数值就是1168
              //如果本 task = 1,subStart = 251,subEnd = 501。则index = 333,直接从allRows获取第 (333 + 0 - 251)= 第 82 个,获取其中的数值。这里因为数值区域是 1001 ~ 2000, 所以数值是1334。
    					if (index >= subStart && index < subEnd) { // idx刚刚好在本分区的数据中
    						PairComparable pairComparable = allRows.get(
    							(int) (index + localStart - subStart)); // 
                
                  
    // runtime变量            
    pairComparable = {PairComparable@9581} 
     first = {Integer@9507} 0 // first是column idx
     second = {Long@9584} 167 // 真实数据     
       
    						out.collect(Tuple2.of(pairComparable.first, pairComparable.second));
    					}
    				}
    
    				localStart += subEnd - subStart;
    			}
    		}
    	}
    

    4.4 QIndex

    其中 QIndex 是本文关键所在,就是具体计算分位数。

    • 构造函数中会得倒所有元素个数,每段大小;
    • genIndex函数中会具体计算,比如假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167
    public static class QIndex {
       private double totalCount;
       private double q1;
       private HasRoundMode.RoundMode roundMode;
    
       public QIndex(double totalCount, int quantileNum, HasRoundMode.RoundMode type) {
          this.totalCount = totalCount; // 1001,所有元素的个数
          this.q1 = 1.0 / (double) quantileNum; // 1.0 / 6 = 16666666666666666。quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6
          this.roundMode = type;
       }
    
       public long genIndex(int k) {
          // 假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167
          return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k);
       }
    }
    

    0x05 输出模型

    输出模型是通过 reduceGroup 调用 SerializeModel 来完成。

    具体逻辑是:

    • 先构建分箱点元数据信息;
    • 然后序列化成模型;
    // 序列化模型
    quantile = quantile.reduceGroup(
          new SerializeModel(
             getParams(),
             quantileColNames,
             TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
             BinTypes.BinDivideType.QUANTILE
          )
    );
    

    SerializeModel 的具体实现是:

    public static class SerializeModel implements GroupReduceFunction<Row, Row> {
       private Params meta;
       private String[] colNames;
       private TypeInformation<?>[] colTypes;
       private BinTypes.BinDivideType binDivideType;
    
       @Override
       public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
          Map<String, FeatureBorder> m = new HashMap<>();
          for (Row val : values) {
             int index = (int) val.getField(0);
             Number[] splits = (Number[]) val.getField(1);
             m.put(
                colNames[index],
                QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
                   colNames[index],
                   colTypes[index],
                   splits,
                   meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
                   binDivideType
                )
             );
          }
    
          for (int i = 0; i < colNames.length; ++i) {
             if (m.containsKey(colNames[i])) {
                continue;
             }
    
             m.put(
                colNames[i],
                QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
                   colNames[i],
                   colTypes[i],
                   null,
                   meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
                   binDivideType
                )
             );
          }
    
          QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m, meta);
    
          model.save(model, out);
       }
    }
    

    这里用到了 FeatureBorder 类。

    数据分箱是按照某种规则将数据进行分类。就像可以将水果按照大小进行分类,售卖不同的价格一样。

    FeatureBorder 就是专门为了 Featureborder for binning, discrete Featureborder and continuous Featureborder。

    我们能够看出来,该分箱对应的列名,index,各个分割点。

    m = {HashMap@9380}  size = 1
     "col0" -> {FeatureBorder@9438} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"
    

    0x06 预测

    预测是在 QuantileDiscretizerModelMapper 中完成的。

    6.1 加载模型

    模型数据是

    model = {QuantileDiscretizerModelDataConverter@9582} 
     meta = {Params@9670} "Params {selectedCols=["col0"], version="v2", numBuckets=6}"
     data = {HashMap@9584}  size = 1
      "col0" -> {FeatureBorder@9676} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"
    

    loadModel会完成加载。

    @Override
    public void loadModel(List<Row> modelRows) {
       QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter();
       model.load(modelRows);
    
       for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) {
          FeatureBorder border = model.data.get(mapperBuilder.paramsBuilder.selectedCols[i]);
          List<Bin.BaseBin> norm = border.bin.normBins;
          int size = norm.size();
          Long maxIndex = norm.get(0).getIndex();
          Long lastIndex = norm.get(size - 1).getIndex();
          for (int j = 0; j < norm.size(); ++j) {
             if (maxIndex < norm.get(j).getIndex()) {
                maxIndex = norm.get(j).getIndex();
             }
          }
    
          long maxIndexWithNull = Math.max(maxIndex, border.bin.nullBin.getIndex());
    
          switch (mapperBuilder.paramsBuilder.handleInvalidStrategy) {
             case KEEP:
                mapperBuilder.vectorSize.put(i, maxIndexWithNull + 1);
                break;
             case SKIP:
             case ERROR:
                mapperBuilder.vectorSize.put(i, maxIndex + 1);
                break;
             default:
                throw new UnsupportedOperationException("Unsupported now.");
          }
    
          if (mapperBuilder.paramsBuilder.dropLast) {
             mapperBuilder.dropIndex.put(i, lastIndex);
          }
    
          mapperBuilder.discretizers[i] = createQuantileDiscretizer(border, model.meta);
       }
    
       mapperBuilder.setAssembledVectorSize();
    }
    

    加载中,最后调用 createQuantileDiscretizer 生成 LongQuantileDiscretizer。这就是针对Long类型的离散器。

    public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer {
       long[] bounds;
       boolean isLeftOpen;
       int[] boundIndex;
       int nullIndex;
       boolean zeroAsMissing;
    
       @Override
       public int findIndex(Object number) {
          if (number == null) {
             return nullIndex;
          }
    
          long lVal = ((Number) number).longValue();
    
          if (isMissing(lVal, zeroAsMissing)) {
             return nullIndex;
          }
    
          int hit = Arrays.binarySearch(bounds, lVal);
    
          if (isLeftOpen) {
             hit = hit >= 0 ? hit - 1 : -hit - 2;
          } else {
             hit = hit >= 0 ? hit : -hit - 2;
          }
    
          return boundIndex[hit];
       }
    }
    

    其数值如下:

    this = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} 
     bounds = {long[7]@9757} 
      0 = -9223372036854775807
      1 = 1168
      2 = 1334
      3 = 1501
      4 = 1667
      5 = 1834
      6 = 9223372036854775807
     isLeftOpen = true
     boundIndex = {int[7]@9743} 
      0 = 0 // -9223372036854775807 ~ 1168 之间对应的最终分箱离散值是 0 
      1 = 1
      2 = 2
      3 = 3
      4 = 4
      5 = 5
      6 = 5 // 1834 ~ 9223372036854775807 之间对应的最终分箱离散值是 5 
     nullIndex = 6
     zeroAsMissing = false
    

    6.2 预测

    预测 QuantileDiscretizerModelMapper 的 DiscretizerMapperBuilder 完成。

    Row map(Row row){
      
    // 这里的 row 举例是: row = {Row@9743} "1003"
       for (int i = 0; i < paramsBuilder.selectedCols.length; i++) {
          int colIdxInData = selectedColIndicesInData[i];
          Object val = row.getField(colIdxInData);
          int foundIndex = discretizers[i].findIndex(val); // 找到 1003对应的index,就是调用Discretizer完成,这里找到 foundIndex 是0
          predictIndices[i] = (long) foundIndex;
       }
    
       return paramsBuilder.outputColsHelper.getResultRow(
          row,
          setResultRow(
             predictIndices,
             paramsBuilder.encode,
             dropIndex,
             vectorSize,
             paramsBuilder.dropLast,
             assembledVectorSize) // 最后返回离散值是0
       );
    }
    
    this = {QuantileDiscretizerModelMapper$DiscretizerMapperBuilder@9744} 
     paramsBuilder = {QuantileDiscretizerModelMapper$DiscretizerParamsBuilder@9752} 
     selectedColIndicesInData = {int[1]@9754} 
     vectorSize = {HashMap@9758}  size = 1
     dropIndex = {HashMap@9759}  size = 1
     assembledVectorSize = {Integer@9760} 6
     discretizers = {QuantileDiscretizerModelMapper$NumericQuantileDiscretizer[1]@9761} 
      0 = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} 
       bounds = {long[7]@9776} 
       isLeftOpen = true
       boundIndex = {int[7]@9777} 
       nullIndex = 6
       zeroAsMissing = false
     predictIndices = {Long[1]@9763} 
    

    0xFF 参考

    QuantileDiscretizer的用法

    Spark QuantileDiscretizer 分位数离散器

    机器学习——数据离散化(时间离散,多值离散化,分位数,聚类法,频率区间,二值化)

    如何通俗地理解分位数?

    分位数通俗理解

    Python解释数学系列——分位数Quantile

    spark之QuantileDiscretizer源码解析

  • 相关阅读:
    消息队列RocketMQ版最佳实践订阅关系一致
    Java8 stream、List forEach 遍历对象 List 对某一字段重新赋值
    SQL的嵌套查询与连接查询
    Xshell7 个人可以申请免费使用正版
    @NotEmpty、@NotBlank、@NotNull 区别和使用
    List集合日常总结
    Time Zone(时区)
    Arrays.asList() 和Collections.singletonList()的区别
    GitBash生成SSH密钥
    Mysql中用SQL增加、删除、修改(包括字段长度/注释/字段名)总结
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/13531980.html
Copyright © 2020-2023  润新知