• Spark MLlib 之 StringIndexer、IndexToString使用说明以及源码剖析


    最近在用Spark MLlib进行特征处理时,对于StringIndexer和IndexToString遇到了点问题,查阅官方文档也没有解决疑惑。无奈之下翻看源码才明白其中一二...这就给大家娓娓道来。

    更多内容参考我的大数据学习之路

    文档说明

    StringIndexer 字符串转索引

    StringIndexer可以把字符串的列按照出现频率进行排序,出现次数最高的对应的Index为0。比如下面的列表进行StringIndexer

    id category
    0 a
    1 b
    2 c
    3 a
    4 a
    5 c

    就可以得到如下:

    id category categoryIndex
    0 a 0.0
    1 b 2.0
    2 c 1.0
    3 a 0.0
    4 a 0.0
    5 c 1.0

    可以看到出现次数最多的"a",索引为0;次数最少的"b"索引为2。

    针对训练集中没有出现的字符串值,spark提供了几种处理的方法:

    • error,直接抛出异常
    • skip,跳过该样本数据
    • keep,使用一个新的最大索引,来表示所有未出现的值

    下面是基于Spark MLlib 2.2.0的代码样例:

    package xingoo.ml.features.tranformer
    
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.ml.feature.StringIndexer
    
    object StringIndexerTest {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
    
        val df = spark.createDataFrame(
          Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
        ).toDF("id", "category")
    
        val df1 = spark.createDataFrame(
          Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
        ).toDF("id", "category")
    
        val indexer = new StringIndexer()
          .setInputCol("category")
          .setOutputCol("categoryIndex")
          .setHandleInvalid("keep") //skip keep error
    
        val model = indexer.fit(df)
    
        val indexed = model.transform(df1)
        indexed.show(false)
      }
    }
    

    得到的结果为:

    +---+--------+-------------+
    |id |category|categoryIndex|
    +---+--------+-------------+
    |0  |a       |0.0          |
    |1  |b       |2.0          |
    |2  |c       |1.0          |
    |3  |a       |0.0          |
    |4  |e       |3.0          |
    |5  |f       |3.0          |
    +---+--------+-------------+
    

    IndexToString 索引转字符串

    这个索引转回字符串要搭配前面的StringIndexer一起使用才行:

    package xingoo.ml.features.tranformer
    
    import org.apache.spark.ml.attribute.Attribute
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
    import org.apache.spark.sql.SparkSession
    
    object IndexToString2 {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
    
        val df = spark.createDataFrame(Seq(
          (0, "a"),
          (1, "b"),
          (2, "c"),
          (3, "a"),
          (4, "a"),
          (5, "c")
        )).toDF("id", "category")
    
        val indexer = new StringIndexer()
          .setInputCol("category")
          .setOutputCol("categoryIndex")
          .fit(df)
        val indexed = indexer.transform(df)
    
        println(s"Transformed string column '${indexer.getInputCol}' " +
          s"to indexed column '${indexer.getOutputCol}'")
        indexed.show()
    
        val inputColSchema = indexed.schema(indexer.getOutputCol)
        println(s"StringIndexer will store labels in output column metadata: " +
          s"${Attribute.fromStructField(inputColSchema).toString}
    ")
    
        val converter = new IndexToString()
          .setInputCol("categoryIndex")
          .setOutputCol("originalCategory")
    
        val converted = converter.transform(indexed)
    
        println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
          s"column '${converter.getOutputCol}' using labels in metadata")
        converted.select("id", "categoryIndex", "originalCategory").show()
      }
    }
    

    得到的结果如下:

    Transformed string column 'category' to indexed column 'categoryIndex'
    +---+--------+-------------+
    | id|category|categoryIndex|
    +---+--------+-------------+
    |  0|       a|          0.0|
    |  1|       b|          2.0|
    |  2|       c|          1.0|
    |  3|       a|          0.0|
    |  4|       a|          0.0|
    |  5|       c|          1.0|
    +---+--------+-------------+
    
    StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}
    
    Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
    +---+-------------+----------------+
    | id|categoryIndex|originalCategory|
    +---+-------------+----------------+
    |  0|          0.0|               a|
    |  1|          2.0|               b|
    |  2|          1.0|               c|
    |  3|          0.0|               a|
    |  4|          0.0|               a|
    |  5|          1.0|               c|
    +---+-------------+----------------+
    

    使用问题

    假如处理的过程很复杂,重新生成了一个DataFrame,此时想要把这个DataFrame基于IndexToString转回原来的字符串怎么办呢? 先来试试看:

    package xingoo.ml.features.tranformer
    
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
    import org.apache.spark.sql.SparkSession
    
    object IndexToString3 {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
    
        val df = spark.createDataFrame(Seq(
          (0, "a"),
          (1, "b"),
          (2, "c"),
          (3, "a"),
          (4, "a"),
          (5, "c")
        )).toDF("id", "category")
    
        val df2 = spark.createDataFrame(Seq(
          (0, 2.0),
          (1, 1.0),
          (2, 1.0),
          (3, 0.0)
        )).toDF("id", "index")
    
        val indexer = new StringIndexer()
          .setInputCol("category")
          .setOutputCol("categoryIndex")
          .fit(df)
        val indexed = indexer.transform(df)
    
        val converter = new IndexToString()
          .setInputCol("categoryIndex")
          .setOutputCol("originalCategory")
    
        val converted = converter.transform(df2)
        converted.show()
      }
    }
    

    运行后发现异常:

    18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint
    Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.
    	at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
    	at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
    	at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
    	at scala.collection.AbstractMap.getOrElse(Map.scala:59)
    	at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
    	at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)
    	at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
    	at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)
    	at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)
    	at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)
    

    这是为什么呢?跟随源码来看吧!

    源码剖析

    首先我们创建一个DataFrame,获得原始数据:

    val df = spark.createDataFrame(Seq(
          (0, "a"),
          (1, "b"),
          (2, "c"),
          (3, "a"),
          (4, "a"),
          (5, "c")
        )).toDF("id", "category")
    

    然后创建对应的StringIndexer:

    val indexer = new StringIndexer()
          .setInputCol("category")
          .setOutputCol("categoryIndex")
          .setHandleInvalid("skip")
          .fit(df)
    

    这里面的fit就是在训练转换器了,进入fit():

    override def fit(dataset: Dataset[_]): StringIndexerModel = {
        transformSchema(dataset.schema, logging = true)
        // 这里针对需要转换的列先强制转换成字符串,然后遍历统计每个字符串出现的次数
        val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
          .rdd
          .map(_.getString(0))
          .countByValue()
        // counts是一个map,里面的内容为{a->3, b->1, c->2}
        val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
        // 按照个数大小排序,返回数组,[a, c, b]
        // 把这个label保存起来,并返回对应的model(mllib里边的模型都是这个套路,跟sklearn学的)
        copyValues(new StringIndexerModel(uid, labels).setParent(this))
      }
    

    这样就得到了一个列表,列表里面的内容是[a, c, b],然后执行transform来进行转换:

    val indexed = indexer.transform(df)
    

    这个transform可想而知就是用这个数组对每一行的该列进行转换,但是它其实还做了其他的事情:

    override def transform(dataset: Dataset[_]): DataFrame = {
        ...
        // --------
    	// 通过label生成一个Metadata,这个很关键!!!
    	// metadata其实是一个map,内容为:
    	// {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
    	// --------
        val metadata = NominalAttribute.defaultAttr
          .withName($(outputCol)).withValues(filteredLabels).toMetadata()
        
        // 如果是skip则过滤一些数据
        ...
        
        // 下面是针对不同的情况处理转换的列,逻辑很简单
        val indexer = udf { label: String =>
          ...
          if (labelToIndex.contains(label)) {
              labelToIndex(label) //如果正常,就进行转换
            } else if (keepInvalid) {
              labels.length // 如果是keep,就返回索引的最大值(即数组的长度)
            } else {
              ... // 如果是error,就抛出异常
            }
        }
    
    	// 保留之前所有的列,新增一个字段,并设置字段的StructField中的Metadata!!!!
    	// 并设置字段的StructField中的Metadata!!!!
    	// 并设置字段的StructField中的Metadata!!!!
    	// 并设置字段的StructField中的Metadata!!!!
    	
        filteredDataset.select(col("*"),
          indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
      }
    

    看到了吗!关键的地方在这里,给新增加的字段的类型StructField设置了一个Metadata。这个Metadata正常都是空的{},但是这里设置了metadata之后,里面包含了label数组的信息。

    接下来看看IndexToString是怎么用的,由于IndexToString是一个Transformer,因此只有一个trasform方法:

    override def transform(dataset: Dataset[_]): DataFrame = {
        transformSchema(dataset.schema, logging = true)
        val inputColSchema = dataset.schema($(inputCol))
        
        // If the labels array is empty use column metadata
        // 关键是这里:
        // 如果IndexToString设置了labels数组,就直接返回;
        // 否则,就读取了传入的DataFrame的StructField中的Metadata
        val values = if (!isDefined(labels) || $(labels).isEmpty) {
          Attribute.fromStructField(inputColSchema)
            .asInstanceOf[NominalAttribute].values.get
        } else {
          $(labels)
        }
    
    	// 基于这个values把index转成对应的值
        val indexer = udf { index: Double =>
          val idx = index.toInt
          if (0 <= idx && idx < values.length) {
            values(idx)
          } else {
            throw new SparkException(s"Unseen index: $index ??")
          }
        }
        val outputColName = $(outputCol)
        dataset.select(col("*"),
          indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
      }
    

    了解StringIndexer和IndexToString的原理机制后,就可以作出如下的应对策略了。

    1 增加StructField的MetaData信息

     val df2 = spark.createDataFrame(Seq(
          (0, 2.0),
          (1, 1.0),
          (2, 1.0),
          (3, 0.0)
        )).toDF("id", "index").select(col("*"),col("index").as("formated_index", indexed.schema("categoryIndex").metadata))
    
        val converter = new IndexToString()
          .setInputCol("formated_index")
          .setOutputCol("origin_col")
    
        val converted = converter.transform(df2)
        converted.show(false)
    
    +---+-----+--------------+----------+
    |id |index|formated_index|origin_col|
    +---+-----+--------------+----------+
    |0  |2.0  |2.0           |b         |
    |1  |1.0  |1.0           |c         |
    |2  |1.0  |1.0           |c         |
    |3  |0.0  |0.0           |a         |
    +---+-----+--------------+----------+
    

    2 获取之前StringIndexer后的DataFrame中的Label信息

        val df3 = spark.createDataFrame(Seq(
          (0, 2.0),
          (1, 1.0),
          (2, 1.0),
          (3, 0.0)
        )).toDF("id", "index")
    
        val converter2 = new IndexToString()
          .setInputCol("index")
          .setOutputCol("origin_col")
          .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))
    
        val converted2 = converter2.transform(df3)
        converted2.show(false)
    
    +---+-----+----------+
    |id |index|origin_col|
    +---+-----+----------+
    |0  |2.0  |b         |
    |1  |1.0  |c         |
    |2  |1.0  |c         |
    |3  |0.0  |a         |
    +---+-----+----------+
    

    两种方法都能得到正确的输出。

    完整的代码可以参考github链接:

    https://github.com/xinghalo/spark-in-action/blob/master/src/xingoo/ml/features/tranformer/IndexToStringTest.scala

    最终还是推荐详细阅读官方文档,不过官方文档真心有些粗糙,想要了解其中的原理,还是得静下心来看看源码。

  • 相关阅读:
    Redis基础
    MySQL基础
    MySQL基础
    MySQL基础
    MySQL基础
    Hello 博客园
    Linux | 常用命令
    JVM | 性能调优
    JVM | 垃圾回收
    学习笔记 | 分布式技术
  • 原文地址:https://www.cnblogs.com/xing901022/p/9270485.html
Copyright © 2020-2023  润新知