Adding a group count column to a PySpark dataframe
When you do a groupBy(), you have to specify the aggregation before you can display the results. For example:
import pyspark.sql.functions as f
data = [
('a', 5),
('a', 8),
('a', 7),
('b', 1),
]
df = sqlCtx.createDataFrame(data, ["x", "y"])
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()
#+---+---+
#| x| n|
#+---+---+
#| b| 1|
#| a| 3|
#+---+---+
Here I used alias() to rename the column. But this only returns one row per group. If you want all rows with the count appended, you can do this with a Window:
from pyspark.sql import Window
w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#| x| y| n|
#+---+---+---+
#| a| 5| 3|
#| a| 7| 3|
#| a| 8| 3|
#| b| 1| 1|
#+---+---+---+
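If it helps to see the difference between the two results without a Spark session, here is a minimal plain-Python sketch of the same semantics. It is only an illustration, not how Spark computes it; the names group_sizes and rows_with_n are made up for this example.

```python
from collections import Counter

# Same sample rows as above, as plain (x, y) tuples.
data = [('a', 5), ('a', 8), ('a', 7), ('b', 1)]

# Like groupBy('x').count(): one entry per group, the rows are collapsed.
# (Spark's count('x') skips nulls; there are none in this sample.)
group_sizes = Counter(x for x, _ in data)

# Like count('x').over(Window.partitionBy('x')): every row is kept and
# the size of its group is appended as a third field.
rows_with_n = sorted((x, y, group_sizes[x]) for x, y in data)

print(dict(group_sizes))
for row in rows_with_n:
    print(row)
```

The grouped version has two entries, while the windowed version still has all four rows.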
Or if you're more comfortable with SQL, you can register the dataframe as a temporary table and take advantage of pyspark-sql to do the same thing:
df.registerTempTable('table')
sqlCtx.sql(
'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#| x| y| n|
#+---+---+---+
#| a| 5| 3|
#| a| 7| 3|
#| a| 8| 3|
#| b| 1| 1|
#+---+---+---+
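If you just want to try the window query itself without Spark, the same SELECT can be run against an in-memory SQLite database from Python's standard library. This is a swapped-in engine, not Spark SQL, and it assumes your SQLite build is 3.25 or newer (the first release with window functions). The table name t is an assumption for this sketch, since "table" is a reserved word in SQLite.

```python
import sqlite3

# In-memory database; nothing is written to disk.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE t (x TEXT, y INTEGER)')
conn.executemany(
    'INSERT INTO t VALUES (?, ?)',
    [('a', 5), ('a', 8), ('a', 7), ('b', 1)],
)

# Same window query as above, with the table renamed to t.
rows = conn.execute(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM t ORDER BY x, y'
).fetchall()
conn.close()

for row in rows:
    print(row)
```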
I found we can get even closer to the tidyverse example by using withColumn():
from pyspark.sql import Window
w = Window.partitionBy('x')
df.withColumn('n', f.count('x').over(w)).sort('x', 'y').show()
As an appendix to @pault's answer, the grouped count can also be written with agg():
import pyspark.sql.functions as F
...
(df
.groupBy(F.col('x'))
.agg(F.count('x').alias('n'))
.show())
#+---+---+
#| x| n|
#+---+---+
#| b| 1|
#| a| 3|
#+---+---+
enjoy