How to flatten columns of type array of structs (as returned by Spark ML API)?

Wow, my coworker came up with an extremely elegant solution: select the nested field directly, which gives you an array of that field's values per row:

scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
|     1|[1, 2]|
|     2|[0, 4]|
+------+------+

So maybe the Spark ML API isn't that difficult after all :)


When a column's type is an array, e.g. recommendations, you'd be quite productive using the explode function (or the more advanced flatMap operator).

explode(e: Column): Column
Creates a new row for each element in the given array or map column.

That gives you bare structs to work with.

import org.apache.spark.sql.types._
// Target element type: struct<itemId:int, rating:float>
val structType = new StructType().
  add($"itemId".int).
  add($"rating".float)
val arrayType = ArrayType(structType)
// Sample data; the cast renames the tuple fields (_1, _2) to itemId and rating
val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations" cast arrayType)

val exploded = recs.withColumn("recs", explode($"recommendations"))
scala> exploded.show
+------+------------------+-------+
|userId|   recommendations|   recs|
+------+------------------+-------+
|     1|[[1,0.7], [2,0.5]]|[1,0.7]|
|     1|[[1,0.7], [2,0.5]]|[2,0.5]|
|     2|[[0,0.9], [4,0.1]]|[0,0.9]|
|     2|[[0,0.9], [4,0.1]]|[4,0.1]|
+------+------------------+-------+
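
If the original array column isn't needed afterwards, explode can go inside select instead of withColumn, so only the exploded struct is kept. This is a minimal sketch, assuming a local SparkSession and the same sample data as above; the alias rec, the name recType and the session settings are illustrative choices, not part of the original answer.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.types._

// Assumption: a local session; in spark-shell getOrCreate() reuses the existing one
val spark = SparkSession.builder().master("local[*]").appName("explode-sketch").getOrCreate()
import spark.implicits._

// Same shape as the sample above: array<struct<itemId:int, rating:float>>
val recType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
val recs = Seq((1, Array((1, 0.7), (2, 0.5))), (2, Array((0, 0.9), (4, 0.1))))
  .toDF("userId", "recommendations")
  .select($"userId", $"recommendations".cast(recType).as("recommendations"))

// explode inside select: one row per array element, original array column dropped
val exploded = recs.select($"userId", explode($"recommendations").as("rec"))
exploded.printSchema()
```

Compared with withColumn, this avoids carrying the full array along on every exploded row.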

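Structs work nicely with the select operator and * (star), which flattens a struct into one column per struct field.

In general that is select($"element.*"), where element stands for the name of the struct column (recs in this example).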

scala> exploded.select("userId", "recs.*").show
+------+------+------+
|userId|itemId|rating|
+------+------+------+
|     1|     1|   0.7|
|     1|     2|   0.5|
|     2|     0|   0.9|
|     2|     4|   0.1|
+------+------+------+
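
Both steps can also be chained into a single expression. The sketch below assumes a local SparkSession and the same sample data; rec and recType are illustrative names. It also shows the SQL generator function inline as a possible shortcut, on the assumption that your Spark version ships it (it is a built-in in Spark 2.x and later, as far as I know).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.types._

// Assumption: a local session; in spark-shell getOrCreate() reuses the existing one
val spark = SparkSession.builder().master("local[*]").appName("flatten-sketch").getOrCreate()
import spark.implicits._

val recType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
val recs = Seq((1, Array((1, 0.7), (2, 0.5))), (2, Array((0, 0.9), (4, 0.1))))
  .toDF("userId", "recommendations")
  .select($"userId", $"recommendations".cast(recType).as("recommendations"))

// Step 1: explode the array into one struct per row; step 2: star-expand the struct
val flat = recs
  .select($"userId", explode($"recommendations").as("rec"))
  .select($"userId", $"rec.*")

// Possible shortcut: inline explodes an array of structs straight into columns
val viaInline = recs.selectExpr("userId", "inline(recommendations)")

flat.show()
```

Both variants should produce the columns userId, itemId and rating.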

I think that could do what you're after.


P.S. Stay away from UDFs for as long as possible, since they trigger conversion of rows from the internal format (InternalRow) to JVM objects, which can lead to excessive GC.
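
For completeness, here is roughly what the flatMap route mentioned earlier could look like on a typed Dataset. It is a sketch under assumptions: the case classes Rec and UserRecs are hypothetical names that mirror the schema, and the session is local. A typed flatMap also deserializes rows into JVM objects, so the same caveat as for UDFs applies and explode is usually the better first choice.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical case classes mirroring the schema; in a compiled application
// (as opposed to spark-shell or a script) define them at the top level
case class Rec(itemId: Int, rating: Float)
case class UserRecs(userId: Int, recommendations: Seq[Rec])

// Assumption: a local session; in spark-shell getOrCreate() reuses the existing one
val spark = SparkSession.builder().master("local[*]").appName("flatmap-sketch").getOrCreate()
import spark.implicits._

val recs = Seq(
  UserRecs(1, Seq(Rec(1, 0.7f), Rec(2, 0.5f))),
  UserRecs(2, Seq(Rec(0, 0.9f), Rec(4, 0.1f)))
).toDS()

// One output row per (user, recommendation) pair
val flat = recs
  .flatMap(u => u.recommendations.map(r => (u.userId, r.itemId, r.rating)))
  .toDF("userId", "itemId", "rating")

flat.show()
```

With an untyped DataFrame such as the one built above, recs.as[UserRecs] should give the typed view first, provided the column and field names match the case classes.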