How to flatten columns of type array of structs (as returned by Spark ML API)?
Wow, my coworker came up with an extremely elegant solution:
scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
| 1|[1, 2]|
| 2|[0, 4]|
+------+------+
So maybe the Spark ML API isn't that difficult after all :)
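The same dot syntax extracts any field of the structs, each as its own array of values, so (a quick sketch against the same recs dataset that is defined below) you can pull out both columns in one go:

recs.select($"userId", $"recommendations.itemId", $"recommendations.rating")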
With a column of array type, e.g. recommendations, you can be quite productive using the explode function (or the more advanced flatMap operator; see the sketch near the end).
explode(e: Column): Column
Creates a new row for each element in the given array or map column.
That gives you bare structs to work with.
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.explode
import spark.implicits._  // pre-imported in spark-shell; needed for $ and toDF elsewhere

// the target element type: struct<itemId: int, rating: float>
val structType = new StructType().
  add($"itemId".int).
  add($"rating".float)
val arrayType = ArrayType(structType)

// sample data; the cast renames the auto-generated struct fields (_1, _2)
// to itemId and rating (and narrows the doubles to float)
val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations" cast arrayType)
// explode creates one output row per element of the recommendations array
val exploded = recs.withColumn("recs", explode($"recommendations"))
scala> exploded.show
+------+------------------+-------+
|userId| recommendations| recs|
+------+------------------+-------+
| 1|[[1,0.7], [2,0.5]]|[1,0.7]|
| 1|[[1,0.7], [2,0.5]]|[2,0.5]|
| 2|[[0,0.9], [4,0.1]]|[0,0.9]|
| 2|[[0,0.9], [4,0.1]]|[4,0.1]|
+------+------------------+-------+
Structs are nice in the select operator: use * (star) to flatten them into one column per struct field. You could do select($"element.*") for a struct column named element.
scala> exploded.select("userId", "recs.*").show
+------+------+------+
|userId|itemId|rating|
+------+------+------+
| 1| 1| 0.7|
| 1| 2| 0.5|
| 2| 0| 0.9|
| 2| 4| 0.1|
+------+------+------+
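By the way, the explode-plus-star combo can be a one-liner with the inline generator through selectExpr (a sketch; inline has been in Spark SQL since around 2.0, so double-check your version):

recs.selectExpr("userId", "inline(recommendations)")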
I think that could do what you're after.
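Since flatMap got a mention above, here's a minimal typed sketch for completeness; the case classes Rec and UserRecs are my own names, not anything from the question:

case class Rec(itemId: Int, rating: Float)
case class UserRecs(userId: Int, recommendations: Seq[Rec])

// one (userId, itemId, rating) triple per recommendation
val flat = recs.as[UserRecs].flatMap { ur =>
  ur.recommendations.map(r => (ur.userId, r.itemId, r.rating))
}.toDF("userId", "itemId", "rating")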
p.s. Stay away from UDFs as long as possible, since they "trigger" a conversion of rows from the internal binary format (InternalRow) to JVM objects, which can lead to excessive GC.
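To make that concrete, here's roughly what the UDF route would look like (a sketch of what to avoid): the whole array of structs gets deserialized into a Seq[Row] on every call before your function even runs.

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// every invocation copies the array out of the internal binary format
val itemIds = udf { rs: Seq[Row] => rs.map(_.getInt(0)) }
recs.select($"userId", itemIds($"recommendations") as "itemId")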