Spark, ML, StringIndexer: handling unseen labels
There's a way around this in Spark 1.6.
Here's the jira: https://issues.apache.org/jira/browse/SPARK-8764
Here's an example:
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.setHandleInvalid("skip") // new method. values are "error" or "skip"
I started using this, but ended up going back to KrisP's 2nd bullet point about fitting this particular Estimator to the full dataset.
You'll need the fitted model's labels later in the pipeline, when IndexToString converts the predicted indices back to category strings.
Here's the modified example:
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.fit(itemsDF) // Fit the Estimator and create a Model (Transformer)
... do some kind of classification ...
val categoryReverseIndexer = new IndexToString()
.setInputCol(classifier.getPredictionCol)
.setOutputCol("predictedCategory")
.setLabels(categoryIndexerModel.labels) // Use the labels from the Model
With Spark 2.2 (released July 2017) you can use the .setHandleInvalid("keep")
option when creating the indexer. With this option, the indexer does not fail or drop rows when it sees new labels; it assigns all of them to one extra index, numLabels.
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.setHandleInvalid("keep") // options are "keep", "error" or "skip"
From the documentation: there are three strategies regarding how StringIndexer will handle unseen labels when you have fit a StringIndexer on one dataset and then use it to transform another:
- 'error': throws an exception (which is the default)
- 'skip': skips the rows containing the unseen labels entirely (removes the rows on the output!)
- 'keep': puts unseen labels in a special additional bucket, at index numLabels
Please see the linked documentation for examples on how the output of StringIndexer looks for the different options.
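To make the three strategies concrete, here is a minimal sketch for Spark 2.2 or later. The toy data, the helper name indexer, and the local SparkSession are my own assumptions for illustration; they are not taken from the documentation.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

// Assumption: a local session for experimenting; in spark-shell this reuses the existing one.
val spark = SparkSession.builder().master("local[*]").appName("handle-invalid-demo").getOrCreate()
import spark.implicits._

// Hypothetical toy data: "d" appears only in the test set.
val trainDF = Seq("a", "b", "c", "a").toDF("category")
val testDF = Seq("a", "d").toDF("category")

// Hypothetical helper: fit an indexer on the training data with the given strategy.
def indexer(strategy: String) = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .setHandleInvalid(strategy)
  .fit(trainDF)

// "skip": the row holding "d" is dropped from the output.
val skipped = indexer("skip").transform(testDF)

// "keep": "d" is assigned index numLabels (three labels were seen during fit).
val kept = indexer("keep").transform(testDF)

// "error" (the default): transform is lazy, so the failure only surfaces when an action runs.
val failed = scala.util.Try(indexer("error").transform(testDF).collect()).isFailure
```

Calling skipped.show() and kept.show() side by side is a quick way to see which rows survive and which index the unseen label receives.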
In my case, I was running Spark ALS on a large data set, and the data was not available on all partitions, so I had to cache() the data appropriately. After that it worked like a charm.
No nice way to do it, I'm afraid. Either
- filter out the test examples with unknown labels before applying StringIndexer
- or fit StringIndexer to the union of the train and test dataframes, so you are assured all labels are there
- or transform the test examples with unknown labels to a known label
Here is some sample code for the filtering and mapping operations above:
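import org.apache.spark.sql.functions.udf
// get training labels from original train dataframe
// (on Spark 2.x, the map below also needs `import spark.implicits._` for the String encoder)
val trainlabels = traindf.select(colname).distinct.map(_.getString(0)).collect //Array[String]
// or, instead of the line above, get labels from a trained StringIndexer model:
// val trainlabels = simodel.labels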
// define an UDF on your dataframe that will be used for filtering
val filterudf = udf { label:String => trainlabels.contains(label)}
// filter out the bad examples
val filteredTestdf = testdf.filter( filterudf(testdf(colname)))
// transform unknown value to some value, say "a"
val mapudf = udf { label:String => if (trainlabels.contains(label)) label else "a"}
// add a new column to testdf:
val transformedTestdf = testdf.withColumn( "newcol", mapudf(testdf(colname)))
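The second option, fitting on the union of train and test, could look roughly like the sketch below. The toy data and column names are assumptions for illustration. Dataset.union is the Spark 2.x name; on Spark 1.x the equivalent call is unionAll.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

// Assumption: a local session for experimenting; in spark-shell this reuses the existing one.
val spark = SparkSession.builder().master("local[*]").appName("union-fit-demo").getOrCreate()
import spark.implicits._

// Hypothetical toy data: "z" appears only in the test set.
val traindf = Seq("a", "b").toDF("category")
val testdf = Seq("a", "z").toDF("category")

// Collect the label column from both sets so the indexer sees every label.
val allLabels = traindf.select("category").union(testdf.select("category"))

// Fit once on the combined labels; the model can then transform either set.
val simodel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("indexedCategory")
  .fit(allLabels)

// No unseen-label error here, because "z" was present during fit.
val indexedTest = simodel.transform(testdf)
```

Keep in mind that this only helps when the test labels are known at fit time; labels that first show up later, for example in production data, still need one of the other approaches.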