博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Spark MLlib 之 StringIndexer、IndexToString使用说明以及源码剖析
阅读量:6153 次
发布时间:2019-06-21

本文共 10704 字,大约阅读时间需要 35 分钟。

最近在用Spark MLlib进行特征处理时,对于StringIndexer和IndexToString遇到了点问题,查阅官方文档也没有解决疑惑。无奈之下翻看源码才明白其中一二...这就给大家娓娓道来。

更多内容参考

文档说明

StringIndexer 字符串转索引

StringIndexer可以把字符串的列按照出现频率进行排序,出现次数最高的对应的Index为0。比如下面的列表进行StringIndexer

id category
0 a
1 b
2 c
3 a
4 a
5 c

就可以得到如下:

id category categoryIndex
0 a 0.0
1 b 2.0
2 c 1.0
3 a 0.0
4 a 0.0
5 c 1.0

可以看到出现次数最多的"a",索引为0;次数最少的"b"索引为2。

针对训练集中没有出现的字符串值,spark提供了几种处理的方法:

  • error,直接抛出异常
  • skip,跳过该样本数据
  • keep,使用一个新的最大索引,来表示所有未出现的值

下面是基于Spark MLlib 2.2.0的代码样例:

package xingoo.ml.features.tranformerimport org.apache.spark.sql.SparkSessionimport org.apache.spark.ml.feature.StringIndexerobject StringIndexerTest {  def main(args: Array[String]): Unit = {    val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()    spark.sparkContext.setLogLevel("WARN")    val df = spark.createDataFrame(      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))    ).toDF("id", "category")    val df1 = spark.createDataFrame(      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))    ).toDF("id", "category")    val indexer = new StringIndexer()      .setInputCol("category")      .setOutputCol("categoryIndex")      .setHandleInvalid("keep") //skip keep error    val model = indexer.fit(df)    val indexed = model.transform(df1)    indexed.show(false)  }}

得到的结果为:

+---+--------+-------------+|id |category|categoryIndex|+---+--------+-------------+|0  |a       |0.0          ||1  |b       |2.0          ||2  |c       |1.0          ||3  |a       |0.0          ||4  |e       |3.0          ||5  |f       |3.0          |+---+--------+-------------+

IndexToString 索引转字符串

这个索引转回字符串要搭配前面的StringIndexer一起使用才行:

package xingoo.ml.features.tranformerimport org.apache.spark.ml.attribute.Attributeimport org.apache.spark.ml.feature.{IndexToString, StringIndexer}import org.apache.spark.sql.SparkSessionobject IndexToString2 {  def main(args: Array[String]): Unit = {    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()    spark.sparkContext.setLogLevel("WARN")    val df = spark.createDataFrame(Seq(      (0, "a"),      (1, "b"),      (2, "c"),      (3, "a"),      (4, "a"),      (5, "c")    )).toDF("id", "category")    val indexer = new StringIndexer()      .setInputCol("category")      .setOutputCol("categoryIndex")      .fit(df)    val indexed = indexer.transform(df)    println(s"Transformed string column '${indexer.getInputCol}' " +      s"to indexed column '${indexer.getOutputCol}'")    indexed.show()    val inputColSchema = indexed.schema(indexer.getOutputCol)    println(s"StringIndexer will store labels in output column metadata: " +      s"${Attribute.fromStructField(inputColSchema).toString}\n")    val converter = new IndexToString()      .setInputCol("categoryIndex")      .setOutputCol("originalCategory")    val converted = converter.transform(indexed)    println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +      s"column '${converter.getOutputCol}' using labels in metadata")    converted.select("id", "categoryIndex", "originalCategory").show()  }}

得到的结果如下:

Transformed string column 'category' to indexed column 'categoryIndex'+---+--------+-------------+| id|category|categoryIndex|+---+--------+-------------+|  0|       a|          0.0||  1|       b|          2.0||  2|       c|          1.0||  3|       a|          0.0||  4|       a|          0.0||  5|       c|          1.0|+---+--------+-------------+StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata+---+-------------+----------------+| id|categoryIndex|originalCategory|+---+-------------+----------------+|  0|          0.0|               a||  1|          2.0|               b||  2|          1.0|               c||  3|          0.0|               a||  4|          0.0|               a||  5|          1.0|               c|+---+-------------+----------------+

使用问题

假如处理的过程很复杂,重新生成了一个DataFrame,此时想要把这个DataFrame基于IndexToString转回原来的字符串怎么办呢? 先来试试看:

package xingoo.ml.features.tranformerimport org.apache.spark.ml.feature.{IndexToString, StringIndexer}import org.apache.spark.sql.SparkSessionobject IndexToString3 {  def main(args: Array[String]): Unit = {    val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()    spark.sparkContext.setLogLevel("WARN")    val df = spark.createDataFrame(Seq(      (0, "a"),      (1, "b"),      (2, "c"),      (3, "a"),      (4, "a"),      (5, "c")    )).toDF("id", "category")    val df2 = spark.createDataFrame(Seq(      (0, 2.0),      (1, 1.0),      (2, 1.0),      (3, 0.0)    )).toDF("id", "index")    val indexer = new StringIndexer()      .setInputCol("category")      .setOutputCol("categoryIndex")      .fit(df)    val indexed = indexer.transform(df)    val converter = new IndexToString()      .setInputCol("categoryIndex")      .setOutputCol("originalCategory")    val converted = converter.transform(df2)    converted.show()  }}

运行后发现异常:

18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpointException in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.    at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)    at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)    at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)    at scala.collection.AbstractMap.getOrElse(Map.scala:59)    at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)    at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)    at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)    at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)    at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)    at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)

这是为什么呢?跟随源码来看吧!

源码剖析

首先我们创建一个DataFrame,获得原始数据:

val df = spark.createDataFrame(Seq(      (0, "a"),      (1, "b"),      (2, "c"),      (3, "a"),      (4, "a"),      (5, "c")    )).toDF("id", "category")

然后创建对应的StringIndexer:

val indexer = new StringIndexer()      .setInputCol("category")      .setOutputCol("categoryIndex")      .setHandleInvalid("skip")      .fit(df)

这里面的fit就是在训练转换器了,进入fit():

override def fit(dataset: Dataset[_]): StringIndexerModel = {    transformSchema(dataset.schema, logging = true)    // 这里针对需要转换的列先强制转换成字符串,然后遍历统计每个字符串出现的次数    val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))      .rdd      .map(_.getString(0))      .countByValue()    // counts是一个map,里面的内容为{a->3, b->1, c->2}    val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray    // 按照个数大小排序,返回数组,[a, c, b]    // 把这个label保存起来,并返回对应的model(mllib里边的模型都是这个套路,跟sklearn学的)    copyValues(new StringIndexerModel(uid, labels).setParent(this))  }

这样就得到了一个列表,列表里面的内容是[a, c, b],然后执行transform来进行转换:

val indexed = indexer.transform(df)

这个transform可想而知就是用这个数组对每一行的该列进行转换,但是它其实还做了其他的事情:

override def transform(dataset: Dataset[_]): DataFrame = {    ...    // --------    // 通过label生成一个Metadata,这个很关键!!!    // metadata其实是一个map,内容为:    // {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}    // --------    val metadata = NominalAttribute.defaultAttr      .withName($(outputCol)).withValues(filteredLabels).toMetadata()        // 如果是skip则过滤一些数据    ...        // 下面是针对不同的情况处理转换的列,逻辑很简单    val indexer = udf { label: String =>      ...      if (labelToIndex.contains(label)) {          labelToIndex(label) //如果正常,就进行转换        } else if (keepInvalid) {          labels.length // 如果是keep,就返回索引的最大值(即数组的长度)        } else {          ... // 如果是error,就抛出异常        }    }    // 保留之前所有的列,新增一个字段,并设置字段的StructField中的Metadata!!!!    // 并设置字段的StructField中的Metadata!!!!    // 并设置字段的StructField中的Metadata!!!!    // 并设置字段的StructField中的Metadata!!!!        filteredDataset.select(col("*"),      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))  }

看到了吗!关键的地方在这里,给新增加的字段的类型StructField设置了一个Metadata。这个Metadata正常都是空的{},但是这里设置了metadata之后,里面包含了label数组的信息。

接下来看看IndexToString是怎么用的,由于IndexToString是一个Transformer,因此只有一个trasform方法:

override def transform(dataset: Dataset[_]): DataFrame = {    transformSchema(dataset.schema, logging = true)    val inputColSchema = dataset.schema($(inputCol))        // If the labels array is empty use column metadata    // 关键是这里:    // 如果IndexToString设置了labels数组,就直接返回;    // 否则,就读取了传入的DataFrame的StructField中的Metadata    val values = if (!isDefined(labels) || $(labels).isEmpty) {      Attribute.fromStructField(inputColSchema)        .asInstanceOf[NominalAttribute].values.get    } else {      $(labels)    }    // 基于这个values把index转成对应的值    val indexer = udf { index: Double =>      val idx = index.toInt      if (0 <= idx && idx < values.length) {        values(idx)      } else {        throw new SparkException(s"Unseen index: $index ??")      }    }    val outputColName = $(outputCol)    dataset.select(col("*"),      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))  }

了解StringIndexer和IndexToString的原理机制后,就可以作出如下的应对策略了。

1 增加StructField的MetaData信息

val df2 = spark.createDataFrame(Seq(      (0, 2.0),      (1, 1.0),      (2, 1.0),      (3, 0.0)    )).toDF("id", "index").select(col("*"),col("index").as("formated_index", indexed.schema("categoryIndex").metadata))    val converter = new IndexToString()      .setInputCol("formated_index")      .setOutputCol("origin_col")    val converted = converter.transform(df2)    converted.show(false)
+---+-----+--------------+----------+|id |index|formated_index|origin_col|+---+-----+--------------+----------+|0  |2.0  |2.0           |b         ||1  |1.0  |1.0           |c         ||2  |1.0  |1.0           |c         ||3  |0.0  |0.0           |a         |+---+-----+--------------+----------+

2 获取之前StringIndexer后的DataFrame中的Label信息

val df3 = spark.createDataFrame(Seq(      (0, 2.0),      (1, 1.0),      (2, 1.0),      (3, 0.0)    )).toDF("id", "index")    val converter2 = new IndexToString()      .setInputCol("index")      .setOutputCol("origin_col")      .setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))    val converted2 = converter2.transform(df3)    converted2.show(false)
+---+-----+----------+|id |index|origin_col|+---+-----+----------+|0  |2.0  |b         ||1  |1.0  |c         ||2  |1.0  |c         ||3  |0.0  |a         |+---+-----+----------+

两种方法都能得到正确的输出。

完整的代码可以参考github链接:

最终还是推荐详细阅读,不过官方文档真心有些粗糙,想要了解其中的原理,还是得静下心来看看源码。

转载地址:http://oywfa.baihongyu.com/

你可能感兴趣的文章
linux下项目开发加载动态库:ldconfig与 /etc/ld.so.conf
查看>>
Ubuntu server 搭建Git server
查看>>
搭建自己的OpenWrt开发环境
查看>>
shell的变量输入read讲解与实战
查看>>
Android源码资料
查看>>
在一个多模块的python项目中,如何在子模块中引用项目的根目录?
查看>>
Nginx
查看>>
FastDFS 分布式文件系统 搭建部署
查看>>
.NET(C#) Internals: 以一个数组填充的例子初步了解.NET 4.0中的并行(一)
查看>>
Oracle体系结构之SQL语句的执行过程
查看>>
Linux修改yum源为阿里云、网易、中国科技大学
查看>>
面向对象之绑定方法与非绑定方法(day7)
查看>>
Docker数据卷备份恢复、桥接网络设置
查看>>
awk命令
查看>>
系统架构师-基础到企业应用架构-系统设计规范与原则[下篇]
查看>>
安装loadrunner时出现”命令行选项语法错误键入命令 \?获得帮助“的解决方法
查看>>
傻瓜入侵
查看>>
docker1.9网络新特性,overlay网络实现主机间容器互联
查看>>
关于网络层的负载均衡和热备
查看>>
CUDNN学习笔记(4)
查看>>