Spark dataframe transform multiple rows to column
Using zero323's dataframe,
df = sqlContext.createDataFrame([
("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
("e", 1, "m4"), ("e", 1, "m5")],
("a", "cnt", "major"))
you could also use pivot (added in Spark 1.6), which does the same reshaping in a single expression:
reshaped_df = df.groupby('a').pivot('major').max('cnt').fillna(0)
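The one-liner is easier to follow if you can see what it computes. Below is a plain-Python sketch of what groupby('a').pivot('major').max('cnt').fillna(0) does to the example rows. It needs no Spark, and the helper name pivot_max is hypothetical. The result has one output row per distinct a and one column per distinct major in sorted order. Each cell holds the maximum cnt for that (a, major) pair, and pairs that never occur are filled with 0.

```python
from collections import defaultdict

# Same rows as the example DataFrame: (a, cnt, major)
rows = [
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5"),
]


def pivot_max(rows, fill=0):
    """Hypothetical helper: plain-Python model of
    groupby('a').pivot('major').max('cnt').fillna(fill)."""
    # Distinct pivot values, sorted, become the output columns
    majors = sorted({major for _, _, major in rows})
    # a -> {major: max cnt seen so far}
    cells = defaultdict(dict)
    for a, cnt, major in rows:
        prev = cells[a].get(major)
        cells[a][major] = cnt if prev is None else max(prev, cnt)
    # One row per group key; (a, major) pairs that never occur get `fill`
    return majors, {a: [cells[a].get(m, fill) for m in majors]
                    for a in sorted(cells)}


columns, table = pivot_max(rows)
print(["a"] + columns)
for key, values in table.items():
    print([key] + values)
```

In this data every (a, major) pair occurs at most once, so max only selects the single value present. If duplicates could appear, the aggregate you choose decides how they are combined. If the majors are known up front, you can also pass them explicitly, as in pivot('major', ['m1', 'm2', 'm3', 'm4', 'm5']). Spark then skips the extra job it otherwise runs to collect the distinct pivot values.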
Let's start with some example data:
df = sqlContext.createDataFrame([
("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
("e", 1, "m4"), ("e", 1, "m5")],
("a", "cnt", "major"))
Please note that I've changed "count" to "cnt". COUNT is a reserved keyword in most SQL dialects, so it is not a good choice for a column name.
There are at least two ways to reshape this data:
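Aggregating over a DataFrame:
from pyspark.sql.functions import col, when, max as max_  # alias so the built-in max is not shadowed

# Distinct "major" values become the new column names
# (DataFrame.map is gone in Spark 2.x, so go through .rdd)
majors = sorted(df.select("major")
    .distinct()
    .rdd.map(lambda row: row[0])
    .collect())

# One column per major: cnt on matching rows, null everywhere else
cols = [when(col("major") == m, col("cnt")).otherwise(None).alias(m)
    for m in majors]
# max skips nulls, leaving the cnt for that (a, major) pair
maxs = [max_(col(m)).alias(m) for m in majors]

reshaped1 = (df
    .select(col("a"), *cols)
    .groupBy("a")
    .agg(*maxs)
    .na.fill(0))  # majors missing for a key become 0

reshaped1.show()
## +---+---+---+---+---+---+
## |  a| m1| m2| m3| m4| m5|
## +---+---+---+---+---+---+
## |  a|  1|  1|  2|  3|  0|
## |  b|  4|  1|  2|  0|  0|
## |  c|  3|  0|  4|  5|  0|
## |  d|  6|  1|  2|  3|  4|
## |  e|  4|  5|  1|  1|  1|
## +---+---+---+---+---+---+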
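The DataFrame approach above is conditional aggregation. Each input row is widened so that its cnt appears in exactly one major column and every other major column holds null. Within each group, max then ignores those nulls. Below is a plain-Python sketch of the two steps on a subset of the rows. The helper names widen and collapse are hypothetical, and None stands in for SQL null.

```python
from collections import defaultdict

rows = [("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"), ("a", 3, "m4"),
        ("b", 4, "m1"), ("b", 1, "m2"), ("b", 2, "m3")]
majors = sorted({major for _, _, major in rows})


def widen(row):
    """Mirror of when(col("major") == m, col("cnt")).otherwise(None) per major."""
    a, cnt, major = row
    return a, [cnt if major == m else None for m in majors]


def collapse(wide_rows, fill=0):
    """Mirror of groupBy("a") + max per column + na.fill(fill)."""
    groups = defaultdict(list)
    for a, values in wide_rows:
        groups[a].append(values)
    result = {}
    for a, value_lists in groups.items():
        row = []
        for column in zip(*value_lists):
            # Like SQL max, ignore nulls; a column with no values gets `fill`
            present = [v for v in column if v is not None]
            row.append(max(present) if present else fill)
        result[a] = row
    return result


wide = [widen(r) for r in rows]
print(wide[0])         # ('a', [1, None, None, None])
print(collapse(wide))  # {'a': [1, 1, 2, 3], 'b': [4, 1, 2, 0]}
```

In Spark the nulls cost nothing extra to aggregate. The same pattern works with other null-ignoring aggregates such as sum or first(..., ignorenulls=True) when those semantics fit better.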
groupBy over RDD:
from pyspark.sql import Row

# Key every record by "a"; groupByKey gathers that key's (major, cnt) pairs
grouped = (df.rdd
    .map(lambda row: (row.a, (row.major, row.cnt)))
    .groupByKey())

def make_row(kv):
    key, vs = kv
    tmp = dict(list(vs) + [("a", key)])
    # Reuses majors from the previous snippet; absent majors default to 0
    return Row(**{c: tmp.get(c, 0) for c in ["a"] + majors})

reshaped2 = sqlContext.createDataFrame(grouped.map(make_row))
reshaped2.show()
## +---+---+---+---+---+---+
## |  a| m1| m2| m3| m4| m5|
## +---+---+---+---+---+---+
## |  a|  1|  1|  2|  3|  0|
## |  e|  4|  5|  1|  1|  1|
## |  c|  3|  0|  4|  5|  0|
## |  b|  4|  1|  2|  0|  0|
## |  d|  6|  1|  2|  3|  4|
## +---+---+---+---+---+---+
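The RDD version relies on groupByKey to collect all (major, cnt) pairs for one key into an iterable, which make_row then turns into a dict. Below is a plain-Python sketch of the same logic. The helper group_by_key is a hypothetical stand-in for RDD.groupByKey, a plain dict replaces Row, and majors is hard-coded here to the five values from the example.

```python
from collections import defaultdict

rows = [("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"), ("a", 3, "m4"),
        ("c", 3, "m1"), ("c", 4, "m3"), ("c", 5, "m4")]
majors = ["m1", "m2", "m3", "m4", "m5"]


def group_by_key(pairs):
    """Hypothetical stand-in for RDD.groupByKey: key -> list of values."""
    grouped = defaultdict(list)
    for key, value in pairs:
        grouped[key].append(value)
    return grouped.items()


def make_row(kv):
    """Same logic as make_row above, returning a dict instead of a Row."""
    key, vs = kv
    tmp = dict(list(vs) + [("a", key)])
    # Majors absent for this key default to 0
    return {c: tmp.get(c, 0) for c in ["a"] + majors}


pairs = [(a, (major, cnt)) for a, cnt, major in rows]
reshaped = [make_row(kv) for kv in group_by_key(pairs)]
print(reshaped[1])  # {'a': 'c', 'm1': 3, 'm2': 0, 'm3': 4, 'm4': 5, 'm5': 0}
```

One design difference is worth noting. groupByKey ships every value to the reducer with no map-side combining, whereas groupBy(...).agg(max(...)) on a DataFrame can aggregate partially before the shuffle. groupByKey also does not preserve key order, which is why reshaped2 lists the rows in a different order than reshaped1.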