PySpark: Take average of a column after using filter function
In the dictionary passed to agg, the column name is the key and the aggregation function is the value:
df.filter(df['salary'] > 100000).agg({"age": "avg"})
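Note that agg returns a one-row DataFrame rather than a plain number. A minimal sketch of pulling the value out, assuming a SparkSession named spark and made-up example data:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical sample data using the column names from the snippets above
df = spark.createDataFrame(
    [("alice", 120000, 34), ("bob", 90000, 41), ("carol", 150000, 29)],
    ["name", "salary", "age"],
)

# agg returns a one-row DataFrame; first()[0] extracts the Python value
result = df.filter(df['salary'] > 100000).agg({"age": "avg"})
avg_age = result.first()[0]  # 31.5 for the sample rows above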
Alternatively, you can use pyspark.sql.functions (avg also accepts a column name string directly, so col is optional):
from pyspark.sql.functions import avg, col
df.filter(df['salary'] > 100000).agg(avg(col("age")))
It is also possible to express this as conditional aggregation with CASE ... WHEN; rows where the condition is false become NULL, and avg skips NULLs, so no separate filter is needed:
from pyspark.sql.functions import avg, when
df.select(avg(when(df['salary'] > 100000, df['age'])))
You can also aggregate over the whole DataFrame with an empty groupBy:
df.filter(df['salary'] > 100000).groupBy().avg('age')
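Whichever variant you use, the result column gets a generated name such as avg(age). If you need a stable name, a small sketch using alias (continuing the assumed df from above):
from pyspark.sql.functions import avg

# alias renames the aggregate column from "avg(age)" to "avg_age"
df.filter(df['salary'] > 100000).agg(avg('age').alias('avg_age')).show()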