DataFrame equality in Apache Spark
Scala (see below for PySpark)
The spark-fast-tests library has two methods for making DataFrame comparisons (I'm the creator of the library):
The assertSmallDataFrameEquality method collects both DataFrames on the driver node and makes the comparison:
def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  if (!actualDF.collect().sameElements(expectedDF.collect())) {
    throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
  }
}
The assertLargeDataFrameEquality method compares DataFrames that are spread across multiple machines (the code is basically copied from spark-testing-base):
def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  try {
    // cache both RDDs since each is traversed twice (count, then zipWithIndex)
    actualDF.rdd.cache
    expectedDF.rdd.cache
    val actualCount = actualDF.rdd.count
    val expectedCount = expectedDF.rdd.count
    if (actualCount != expectedCount) {
      throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
    }
    // pair up rows by position and keep only the pairs that differ
    val actualIndexValue = zipWithIndex(actualDF.rdd)
    val expectedIndexValue = zipWithIndex(expectedDF.rdd)
    val unequalRDD = actualIndexValue
      .join(expectedIndexValue)
      .filter {
        case (idx, (r1, r2)) =>
          !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
      }
    val maxUnequalRowsToShow = 10
    assertEmpty(unequalRDD.take(maxUnequalRowsToShow))
  } finally {
    actualDF.rdd.unpersist()
    expectedDF.rdd.unpersist()
  }
}
assertSmallDataFrameEquality is faster for small DataFrame comparisons, and I've found it sufficient for my test suites.
PySpark
Here's a simple function that returns True if the DataFrames are equal:
def are_dfs_equal(df1, df2):
    if df1.schema != df2.schema:
        return False
    if df1.collect() != df2.collect():
        return False
    return True
or, simplified:

def are_dfs_equal(df1, df2):
    return (df1.schema == df2.schema) and (df1.collect() == df2.collect())
You'll typically perform DataFrame equality comparisons in a test suite and will want a descriptive error message when the comparisons fail (a True/False return value doesn't help much when debugging).
Use the chispa library to access the assert_df_equality method, which returns descriptive error messages for test suite workflows (it also has options for ignoring row or column ordering).
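A minimal sketch of that workflow, assuming chispa is installed and df1/df2 are the DataFrames under test:

from chispa import assert_df_equality

# raises an exception with a row-by-row diff when the DataFrames differ
assert_df_equality(df1, df2)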
I don't know about idiomatic, but I think the following gives you a robust way to compare DataFrames as you describe. (I'm using PySpark for illustration, but the approach carries across languages.)
a = spark.range(5)
b = spark.range(5)

# canonicalize each DataFrame: sort the column list and count duplicate rows
a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()

# equal iff each canonical form subtracts to nothing from the other
assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
This approach correctly handles cases where the DataFrames may have duplicate rows, rows in different orders, and/or columns in different orders.
For example:
a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])
a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()
c_prime = c.groupBy(sorted(c.columns)).count()
assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
assert a_prime.subtract(c_prime).count() != 0
This approach is quite expensive, but most of the expense is unavoidable given the need to perform a full diff. And this should scale fine, as it doesn't require collecting anything locally. If you relax the constraint that the comparison should account for duplicate rows, then you can drop the groupBy() and just do the subtract(), which would probably speed things up notably.
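A minimal sketch of the relaxed check, assuming a and b share the same column order (subtract() resolves columns by position and is set-based, so duplicate rows collapse):

assert a.subtract(b).count() == b.subtract(a).count() == 0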
There are some standard ways to do this in the Apache Spark test suites; however, most of these involve collecting the data locally, so if you want to do equality testing on large DataFrames then they are likely not a suitable solution.
You could check the schemas first, then intersect the two DataFrames into df3 and verify that the counts of df1, df2, and df3 are all equal (however, this only works if there aren't duplicate rows; if there are different duplicate rows, this method could still return true).
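A minimal sketch of that intersect-and-count check (the function name is mine, for illustration):

def are_dfs_equal_no_dupes(df1, df2):
    # only reliable when neither DataFrame contains duplicate rows
    if df1.schema != df2.schema:
        return False
    df3 = df1.intersect(df2)
    return df1.count() == df2.count() == df3.count()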
Another option would be to get the underlying RDDs of both DataFrames, map each row to (Row, 1), do a reduceByKey to count the occurrences of each Row, cogroup the two resulting RDDs, and then do a regular aggregate, returning false if any of the iterators are not equal.
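A sketch of that counting approach in PySpark, with hypothetical naming (an aggregate over the cogrouped RDD, as described, would work just as well as the isEmpty() check used here):

from operator import add

def are_dfs_equal_with_dupes(df1, df2):
    if df1.schema != df2.schema:
        return False
    # count the occurrences of each Row in each DataFrame
    counts1 = df1.rdd.map(lambda row: (row, 1)).reduceByKey(add)
    counts2 = df2.rdd.map(lambda row: (row, 1)).reduceByKey(add)
    # cogroup yields (row, (counts_from_df1, counts_from_df2)); a row whose two
    # count-iterators differ (including one being empty) is a mismatch
    mismatches = counts1.cogroup(counts2).filter(
        lambda kv: list(kv[1][0]) != list(kv[1][1])
    )
    return mismatches.isEmpty()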