Find maximum row per group in Spark DataFrame
I think what you might be looking for are window functions: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=window#pyspark.sql.Window
https://databricks.com/blog/2015/07/15/introducing-window-functions-in-spark-sql.html
Here is an example in Scala (I don't have a Spark Shell with Hive available right now, so I was not able to test the code, but I think it should work):
case class MyRow(name: String, id_sa: String, id_sb: String)
val myDF = sc.parallelize(Array(
MyRow("n1", "a1", "b1"),
MyRow("n2", "a1", "b2"),
MyRow("n3", "a1", "b2"),
MyRow("n1", "a2", "b2")
)).toDF("name", "id_sa", "id_sb")
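import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.first
// Within each id_sa partition, order by id_sb descending so the first value is the maximum
val windowSpec = Window.partitionBy(myDF("id_sa")).orderBy(myDF("id_sb").desc)
// The new column must be named max_id_sb, because the filter below refers to it by that name
myDF.withColumn("max_id_sb", first(myDF("id_sb")).over(windowSpec)).filter("id_sb = max_id_sb")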
There are probably more efficient ways to achieve the same results with Window functions, but I hope this points you in the right direction.
Using join (this returns more than one row per group in case of ties):
import pyspark.sql.functions as F
from pyspark.sql.functions import count, col
cnts = df.groupBy("id_sa", "id_sb").agg(count("*").alias("cnt")).alias("cnts")
maxs = cnts.groupBy("id_sa").agg(F.max("cnt").alias("mx")).alias("maxs")
cnts.join(maxs,
(col("cnt") == col("mx")) & (col("cnts.id_sa") == col("maxs.id_sa"))
).select(col("cnts.id_sa"), col("cnts.id_sb"))
Using window functions (will drop ties):
from pyspark.sql.functions import row_number
from pyspark.sql.window import Window
w = Window.partitionBy("id_sa").orderBy(col("cnt").desc())
(cnts
.withColumn("rn", row_number().over(w))
.where(col("rn") == 1)
.select("id_sa", "id_sb"))
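To make the difference in tie handling concrete, here is a plain-Python sketch (not Spark code) that mimics both approaches on a small list of pre-aggregated counts. The sample data and the helper names max_by_join and max_by_row_number are made up for illustration.

```python
from collections import defaultdict

# Hypothetical pre-aggregated counts: (id_sa, id_sb, cnt)
cnts = [
    ("a1", "b1", 2),
    ("a1", "b2", 2),  # ties with ("a1", "b1") on cnt
    ("a2", "b2", 1),
]

def max_by_join(rows):
    """Mimic the join approach: keep every row whose cnt equals the group maximum."""
    mx = defaultdict(int)
    for id_sa, _, cnt in rows:
        mx[id_sa] = max(mx[id_sa], cnt)
    return [(id_sa, id_sb) for id_sa, id_sb, cnt in rows if cnt == mx[id_sa]]

def max_by_row_number(rows):
    """Mimic row_number() == 1: keep exactly one row per group."""
    best = {}
    for id_sa, id_sb, cnt in rows:
        # Strict '>' keeps the first row seen on ties; Spark makes no such promise
        if id_sa not in best or cnt > best[id_sa][1]:
            best[id_sa] = (id_sb, cnt)
    return [(id_sa, id_sb) for id_sa, (id_sb, _) in best.items()]

joined = max_by_join(cnts)        # both tied rows for a1 survive
ranked = max_by_row_number(cnts)  # exactly one row per id_sa
```

In Spark, which of the tied rows row_number keeps is not deterministic unless the orderBy includes a tie-breaking column.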
Using struct ordering:
from pyspark.sql.functions import struct
(cnts
.groupBy("id_sa")
.agg(F.max(struct(col("cnt"), col("id_sb"))).alias("max"))
.select(col("id_sa"), col("max.id_sb")))
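This works because Spark compares structs field by field, left to right, so the maximum struct is the one with the highest cnt, with id_sb breaking ties. Python tuples compare the same way, so the idea can be sketched without Spark (a rough illustration on made-up data, not Spark code):

```python
# Hypothetical pre-aggregated counts: (id_sa, id_sb, cnt)
cnts = [
    ("a1", "b1", 2),
    ("a1", "b2", 2),  # same cnt as the row above, so id_sb decides
    ("a2", "b2", 1),
]

best = {}
for id_sa, id_sb, cnt in cnts:
    candidate = (cnt, id_sb)  # plays the role of struct(cnt, id_sb)
    # Tuples compare lexicographically: cnt first, then id_sb
    if id_sa not in best or candidate > best[id_sa]:
        best[id_sa] = candidate

# Keep only id_sb from each winning tuple, like selecting max.id_sb
result = {id_sa: id_sb for id_sa, (_, id_sb) in best.items()}
```

Unlike the row_number variant, ties on cnt are resolved deterministically here, in favour of the larger id_sb.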
See also "How to select the first row of each group?"