Aggregating multiple columns with custom function in Spark
Consider using the struct function to group the two columns together before collecting them as a list:
import org.apache.spark.sql.functions.{collect_list, struct}
import sqlContext.implicits._
val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")
df.groupBy($"name")
  .agg(collect_list(struct($"food", $"price")).as("foods"))
  .show(false)
Outputs:
+----+---------------------------------------------+
|name|foods                                        |
+----+---------------------------------------------+
|john|[[tomato,1.99], [carrot,0.45], [banana,1.29]]|
|bill|[[apple,0.99], [taco,2.59]]                  |
+----+---------------------------------------------+
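The collected structs can then be handed to any custom function via a UDF. A minimal sketch, assuming you want a per-name total price (totalPrice is an illustrative name, not part of the original answer); each struct arrives in the UDF as a Row, so the price is read by position:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// Each element of "foods" reaches the UDF as a Row(food, price);
// summing the field at index 1 yields the total price per name.
val totalPrice = udf((rows: Seq[Row]) => rows.map(_.getDouble(1)).sum)

df.groupBy($"name")
  .agg(collect_list(struct($"food", $"price")).as("foods"))
  .withColumn("total", totalPrice($"foods"))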
Alternatively, you can first collect the two columns into separate lists, and then use a UDF to zip the two lists together. Something like:
import org.apache.spark.sql.functions.{col, collect_list, udf}
import sqlContext.implicits._
// UDF that zips the two collected lists into a single list of (food, price) pairs
val zipper = udf[Seq[(String, Double)], Seq[String], Seq[Double]](_.zip(_))
val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")
val df2 = df.groupBy("name").agg(
  collect_list(col("food")) as "food",
  collect_list(col("price")) as "price"
).withColumn("food", zipper(col("food"), col("price"))).drop("price")
df2.show(false)
Outputs:
+----+---------------------------------------------+
|name|food                                         |
+----+---------------------------------------------+
|john|[[tomato,1.99], [carrot,0.45], [banana,1.29]]|
|bill|[[apple,0.99], [taco,2.59]]                  |
+----+---------------------------------------------+
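One caveat: zipping separately collected lists assumes both collect_list aggregations preserve the same row order, which Spark does not guarantee in general, so the struct approach above is the safer default. If you later need the pairs back as columns, the zipped tuples are encoded as structs with fields _1 and _2 and can be exploded. A minimal sketch, assuming df2 from above:
import org.apache.spark.sql.functions.explode

// Explode the zipped pairs back into one row per (name, food, price).
df2.select($"name", explode($"food").as("pair"))
  .select($"name", $"pair._1".as("food"), $"pair._2".as("price"))
  .show(false)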