How to take the max value and keep all columns (max records per group)?
This is a perfect use case for window operators (using the over function) or a join. Since you've already figured out how to use windows, I'll focus on the join approach exclusively.
scala> val inventory = Seq(
| ("a", "pref1", "b", 168),
| ("a", "pref3", "h", 168),
| ("a", "pref3", "t", 63)).toDF("uid", "k", "v", "count")
inventory: org.apache.spark.sql.DataFrame = [uid: string, k: string ... 2 more fields]
scala> val maxCount = inventory.groupBy("uid", "k").max("count")
maxCount: org.apache.spark.sql.DataFrame = [uid: string, k: string ... 1 more field]
scala> maxCount.show
+---+-----+----------+
|uid| k|max(count)|
+---+-----+----------+
| a|pref3| 168|
| a|pref1| 168|
+---+-----+----------+
scala> val maxCount = inventory.groupBy("uid", "k").agg(max("count") as "max")
maxCount: org.apache.spark.sql.DataFrame = [uid: string, k: string ... 1 more field]
scala> maxCount.show
+---+-----+---+
|uid| k|max|
+---+-----+---+
| a|pref3|168|
| a|pref1|168|
+---+-----+---+
scala> maxCount.join(inventory, Seq("uid", "k")).where($"max" === $"count").show
+---+-----+---+---+-----+
|uid| k|max| v|count|
+---+-----+---+---+-----+
| a|pref3|168| h| 168|
| a|pref1|168| b| 168|
+---+-----+---+---+-----+
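A slightly simpler variant (a sketch, using the same maxCount and inventory as above): rename max back to count so the equality predicate becomes part of the join keys, which also avoids carrying the duplicate column through the result.
maxCount
  .withColumnRenamed("max", "count")
  .join(inventory, Seq("uid", "k", "count"))
  .show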
Here's the best solution I came up with so far (note the filter on rank must come before the select that drops it, and the window helpers need to be imported):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank}

val w = Window.partitionBy("uid", "k").orderBy(col("count").desc)
df.withColumn("rank", dense_rank().over(w))
  .where(col("rank") === 1)
  .select("uid", "k", "v", "count")
  .show
You can use window functions: compute the per-group maximum with max over a window, then keep only the rows whose count matches it:
from pyspark.sql.functions import col, max as max_
from pyspark.sql.window import Window

w = Window.partitionBy("uid", "k")
df.withColumn("max_count", max_("count").over(w)) \
    .where(col("count") == col("max_count")) \
    .drop("max_count")