Spark Dataframe groupBy and sort results into a list
Just wanted to add another hint to the answer of Daniel de Paula regarding the sort_array solution.
If you want to sort elements according to a different column, you can form a struct of two fields:
- the sort by field
- the result field
Since structs are sorted field by field, putting the sort-by field first gives you the order you want. All that remains is to get rid of the sort-by column in each element of the resulting list.
The same approach can be applied with several sort-by columns when needed: place them at the front of the struct in order of priority.
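The field-by-field ordering can be illustrated without Spark at all. The sketch below is only an analogy in plain Scala, where tuples compare lexicographically much like structs do; the `level` key and the integer salaries are made up for illustration and are not part of the example data in this answer.

```scala
// Plain Scala analogy (no Spark): tuples, like structs, compare field by field.
// Salaries are Ints here only to keep the sketch free of floating-point ordering details.
val salaryNames = Seq((20, "JSMITH"), (650, "AJOHNSON"), (651, "CBAKER"), (13, "TGREEN"))

// Sort descending on (salary, name), then drop the sort key.
val namesBySalary = salaryNames.sorted.reverse.map { case (_, name) => name }

// Two sort keys: (level, salary, name) orders by level first, then by salary.
val keyed = Seq((2, 20, "JSMITH"), (1, 650, "AJOHNSON"), (2, 651, "CBAKER"), (1, 13, "TGREEN"))
val namesByLevelThenSalary = keyed.sorted.map { case (_, _, name) => name }

println(namesBySalary)          // List(CBAKER, AJOHNSON, JSMITH, TGREEN)
println(namesByLevelThenSalary) // List(TGREEN, AJOHNSON, JSMITH, CBAKER)
```

The result field sits last in each tuple, so it only takes part in the comparison when all the sort keys before it are equal.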
Here's an example that can be run in a local spark-shell (use :paste mode):
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._ // collect_list, sort_array, struct, udf
import spark.implicits._

case class Employee(name: String, department: String, salary: Double)

val employees = Seq(
  Employee("JSMITH", "A", 20.0),
  Employee("AJOHNSON", "A", 650.0),
  Employee("CBAKER", "A", 650.2),
  Employee("TGREEN", "A", 13.0),
  Employee("CHORTON", "B", 111.0),
  Employee("AIVANOV", "B", 233.0),
  Employee("VSMIRNOV", "B", 11.0)
)
val employeesDF = spark.createDataFrame(employees)

// Drops the sort key (salary) from each struct, keeping only the name.
val getNames = udf { salaryNames: Seq[Row] =>
  salaryNames.map { case Row(_: Double, name: String) => name }
}

// The sort key goes first in the struct so that sort_array orders by it.
employeesDF
  .groupBy($"department")
  .agg(collect_list(struct($"salary", $"name")).as("salaryNames"))
  .withColumn("namesSortedBySalary", getNames(sort_array($"salaryNames", asc = false)))
  .show(truncate = false)
The result:
+----------+--------------------------------------------------------------------+----------------------------------+
|department|salaryNames                                                         |namesSortedBySalary               |
+----------+--------------------------------------------------------------------+----------------------------------+
|B         |[[111.0, CHORTON], [233.0, AIVANOV], [11.0, VSMIRNOV]]              |[AIVANOV, CHORTON, VSMIRNOV]      |
|A         |[[20.0, JSMITH], [650.0, AJOHNSON], [650.2, CBAKER], [13.0, TGREEN]]|[CBAKER, AJOHNSON, JSMITH, TGREEN]|
+----------+--------------------------------------------------------------------+----------------------------------+
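For several sort-by columns, the same pattern might look like the sketch below. The `level` column, the `staff` data and the `getNamesTwoKeys` helper are hypothetical and exist only for this illustration; the code assumes a SparkSession named `spark` is in scope, as in spark-shell.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes `spark` is in scope, as in spark-shell

// Hypothetical data with an extra sort key: (name, department, level, salary).
val staff = Seq(
  ("JSMITH", "A", 2, 20.0),
  ("AJOHNSON", "A", 1, 650.0),
  ("CBAKER", "A", 2, 650.2),
  ("TGREEN", "A", 1, 13.0)
).toDF("name", "department", "level", "salary")

// Drops both sort keys (level, salary) from each struct, keeping only the name.
val getNamesTwoKeys = udf { rows: Seq[Row] =>
  rows.map { case Row(_: Int, _: Double, name: String) => name }
}

// Sort keys come first in the struct, in order of priority.
val byLevelThenSalary = staff
  .groupBy($"department")
  .agg(collect_list(struct($"level", $"salary", $"name")).as("keyed"))
  .withColumn("names", getNamesTwoKeys(sort_array($"keyed")))

byLevelThenSalary.show(truncate = false)
```

With this data, department A should list TGREEN, AJOHNSON, JSMITH, CBAKER: ascending by level, then by salary within each level. Note that sort_array applies one direction to the whole struct, so this sketch does not cover mixing ascending and descending keys.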
You could try the function sort_array, available in the functions package:
import org.apache.spark.sql.functions._
df.groupBy("columnA").agg(sort_array(collect_list("columnB")))
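To try this end to end, here is a minimal sketch with made-up toy data standing in for df, columnA and columnB. The `sortedB` alias and the values are assumptions for illustration, and a SparkSession named `spark` is assumed to be in scope, as in spark-shell.

```scala
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes `spark` is in scope, as in spark-shell

// Hypothetical toy data standing in for df, columnA and columnB.
val df = Seq(("x", 3), ("x", 1), ("y", 2), ("x", 2)).toDF("columnA", "columnB")

// Collect columnB per group, then sort each collected list in ascending order.
val sortedLists = df
  .groupBy("columnA")
  .agg(sort_array(collect_list("columnB")).as("sortedB"))

sortedLists.show()
```

Group x should come out as [1, 2, 3] and group y as [2]. Passing asc = false as the second argument to sort_array reverses the order within each list.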