Automatically and Elegantly flatten DataFrame in Spark SQL
The short answer is, there's no "accepted" way to do this, but you can do it very elegantly with a recursive function that generates your select(...) statement by walking through the DataFrame.schema.
The recursive function should return an Array[Column]. Every time the function hits a StructType, it would call itself and append the returned Array[Column] to its own Array[Column].
Something like:
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions.col
def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
  schema.fields.flatMap(f => {
    // Build the fully qualified field name, e.g. "foo.bar" for field "bar" inside struct "foo"
    val colName = if (prefix == null) f.name else (prefix + "." + f.name)
    f.dataType match {
      // Recurse into nested structs, using the qualified name as the new prefix
      case st: StructType => flattenSchema(st, colName)
      case _ => Array(col(colName))
    }
  })
}
You would then use it like this:
df.select(flattenSchema(df.schema):_*)
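To see what this does, here's a minimal sketch; the nested df is made up for illustration, and the printSchema() output is shown as comments:

df.printSchema()
// root
//  |-- foo: struct (nullable = true)
//  |    |-- bar: string (nullable = true)
//  |    |-- baz: string (nullable = true)
//  |-- x: string (nullable = true)

df.select(flattenSchema(df.schema): _*).printSchema()
// root
//  |-- bar: string (nullable = true)
//  |-- baz: string (nullable = true)
//  |-- x: string (nullable = true)

Note that each flattened column is named after its leaf field (bar, not foo.bar); the answers below address renaming.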
I added a DataFrame#flattenSchema method to the open source spark-daria project.
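If you want to try it, add spark-daria to your build; a hedged SBT coordinate (check the spark-daria README for the exact group ID, Scala suffix, and current version):

libraryDependencies += "com.github.mrpowers" %% "spark-daria" % "<current version>"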
Here's how you can use the function with your code.
import com.github.mrpowers.spark.daria.sql.DataFrameExt._
df.flattenSchema().show()
+-------+-------+---------+----+---+
|foo.bar|foo.baz| x| y| z|
+-------+-------+---------+----+---+
| this| is|something|cool| ;)|
+-------+-------+---------+----+---+
You can also specify a different column name delimiter with the flattenSchema() method.
df.flattenSchema(delimiter = "_").show()
+-------+-------+---------+----+---+
|foo_bar|foo_baz| x| y| z|
+-------+-------+---------+----+---+
| this| is|something|cool| ;)|
+-------+-------+---------+----+---+
This delimiter parameter is surprisingly important. If you're flattening your schema to load the table into Redshift, you won't be able to use periods as the delimiter, because Redshift doesn't allow periods in ordinary (unquoted) column names; underscores work fine.
Here's the full code snippet to generate this output.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val data = Seq(
  Row(Row("this", "is"), "something", "cool", ";)")
)

val schema = StructType(
  Seq(
    StructField(
      "foo",
      StructType(
        Seq(
          StructField("bar", StringType, true),
          StructField("baz", StringType, true)
        )
      ),
      true
    ),
    StructField("x", StringType, true),
    StructField("y", StringType, true),
    StructField("z", StringType, true)
  )
)

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(data),
  schema
)
df.flattenSchema().show()
The underlying code is similar to David Griffin's code (in case you don't want to add the spark-daria dependency to your project).
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

object StructTypeHelpers {
  def flattenSchema(schema: StructType, delimiter: String = ".", prefix: String = null): Array[Column] = {
    schema.fields.flatMap(structField => {
      // the dot-separated path that Spark needs in order to resolve the nested field
      val codeColName = if (prefix == null) structField.name else prefix + "." + structField.name
      structField.dataType match {
        // recurse with the dot path so fields nested more than one level deep still resolve
        case st: StructType => flattenSchema(schema = st, delimiter = delimiter, prefix = codeColName)
        // alias the leaf column with the caller's delimiter in place of the dots
        case _ => Array(col(codeColName).alias(codeColName.replace(".", delimiter)))
      }
    })
  }
}

object DataFrameExt {
  implicit class DataFrameMethods(df: DataFrame) {
    def flattenSchema(delimiter: String = ".", prefix: String = null): DataFrame = {
      df.select(
        StructTypeHelpers.flattenSchema(df.schema, delimiter, prefix): _*
      )
    }
  }
}
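With these two objects in your project, the same implicit syntax works without the spark-daria dependency; a quick sanity check, reusing the df built in the snippet above:

import DataFrameExt._

df.flattenSchema(delimiter = "_").show()

+-------+-------+---------+----+---+
|foo_bar|foo_baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+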
I am improving my previous answer and offering a solution to my own problem stated in the comments of the accepted answer.
The accepted solution creates an array of Column objects and uses it to select these columns. In Spark, if you have a nested DataFrame, you can select a child column like this: df.select("Parent.Child"), and this returns a DataFrame with the values of the child column, named Child. But if you have identically named attributes under different parent structures, you lose the information about the parent, end up with identical column names, and can no longer access them by name because they are ambiguous.
This was my problem.
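A minimal sketch of the problem; the schema here is made up for illustration:

// root
//  |-- author: struct
//  |    |-- name: string
//  |-- publisher: struct
//  |    |-- name: string

df.select(flattenSchema(df.schema): _*).printSchema()
// root
//  |-- name: string
//  |-- name: string
// selecting "name" from this result now fails as ambiguous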
I found a solution to my problem; maybe it can help someone else as well. I called flattenSchema separately:
val flattenedSchema = flattenSchema(df.schema)
and this returned an Array of Column objects. Instead of passing it straight to select(), which would name every output column after its last-level child, I mapped each column to an alias built from its full path, so that selecting Parent.Child yields a column named Parent_Child instead of Child (I also replaced the dots with underscores for my convenience):
val renamedCols = flattenedSchema.map(name => col(name.toString()).as(name.toString().replace(".","_")))
And then you can use the select function as shown in the original answer:
val newDf = df.select(renamedCols: _*)
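On the hypothetical author/publisher schema sketched above, this produces unambiguous names:

newDf.printSchema()
// root
//  |-- author_name: string
//  |-- publisher_name: string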
Just wanted to share my solution for PySpark - it's more or less a translation of @David Griffin's solution, so it supports any level of nested objects.
from pyspark.sql.types import StructType, ArrayType

def flatten(schema, prefix=None):
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        # step into arrays so arrays of structs are flattened too
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            fields += flatten(dtype, prefix=name)
        else:
            fields.append(name)
    return fields
df.select(flatten(df.schema)).show()