在spark ml pipeline的特征提取和转换阶段,有一种transformer可以将机器学习训练数据中常见的字符串列(例如表示各种分类)转换为数值索引列,以便于计算机处理。它就是StringIndexer。它支持的索引范围为[0, numLabels)(不支持的会编码为numLabels),并且支持四种排序方式,frequencyDesc(频率最高的索引赋值为0),frequencyAsc,alphabetDesc,alphabetAsc。
假设我们有dataframe
id | category ----|---------- 0 | a 1 | b 2 | c 3 | a 4 | a 5 | c
应用索引器 将category作为input 将categoryIndex作为output
id | category | categoryIndex ----|----------|--------------- 0 | a | 0.0 1 | b | 2.0 2 | c | 1.0 3 | a | 0.0 4 | a | 0.0 5 | c | 1.0
“a” gets index 0
because it is the most frequent, followed by “c” with index 1
and “b” with index 2
.
当StringIndexer遇到之前没有处理过的字符串时,对于新数据有三种处理策略
- 抛出异常 (默认)
- 跳过当前行
- 放置未知标签
如果我们使用之前生成的StringIndexer 应用于以下数据
id | category ----|---------- 0 | a 1 | b 2 | c 3 | d 4 | e
如果没有设置未知策略,或者设置为error策略,都会抛出异常,然后如果设置过setHandleInvalid("skip") 将会跳过d,e所在行
id | category | categoryIndex ----|----------|--------------- 0 | a | 0.0 1 | b | 2.0 2 | c | 1.0
如果调用setHandleInvalid("keep") 将会生成如下数据
id | category | categoryIndex ----|----------|--------------- 0 | a | 0.0 1 | b | 2.0 2 | c | 1.0 3 | d | 3.0 4 | e | 3.0
注意: “d” or “e” 所在行 都被映射为索引 “3.0”, keep设置了未知编码,而不是继续编码
scala代码示例:
import org.apache.spark.ml.feature.StringIndexer
//创建表 val df = spark.createDataFrame( Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) ).toDF("id", "category")
//创建新列索引器 val indexer = new StringIndexer() .setInputCol("category") .setOutputCol("categoryIndex")
//先fit让indexer编码df索引 然后对某一个表进行转换 这里还是本身 不会抛出异常 或者跳过 val indexed = indexer.fit(df).transform(df) indexed.show()
详细API见文档 https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.feature.StringIndexer
ref: https://spark.apache.org/docs/latest/ml-features.html#stringindexer