How to check the number of partitions of a Spark DataFrame without incurring the cost of .rdd
There is no inherent cost of the rdd component in rdd.getNumPartitions, because the returned RDD is never evaluated.
While you can easily determine this empirically, using a debugger (I'll leave this as an exercise for the reader), or by establishing that no jobs are triggered in the base-case scenario
Spark session available as 'spark'.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/ '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.4.0
      /_/
Using Scala version 2.11.12 (OpenJDK 64-Bit Server VM, Java 1.8.0_181)
Type in expressions to have them evaluated.
Type :help for more information.
scala> val ds = spark.read.text("README.md")
ds: org.apache.spark.sql.DataFrame = [value: string]
scala> ds.rdd.getNumPartitions
res0: Int = 1
scala> spark.sparkContext.statusTracker.getJobIdsForGroup(null).isEmpty // Check if there are any known jobs
res1: Boolean = true
it might not be enough to convince you. So let's approach this in a more systematic way:
- rdd returns a MapPartitionsRDD (ds as defined above):

  scala> ds.rdd.getClass
  res2: Class[_ <: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row]] = class org.apache.spark.rdd.MapPartitionsRDD
- RDD.getNumPartitions invokes RDD.partitions.
- In the non-checkpointed scenario RDD.partitions invokes getPartitions (feel free to trace the checkpoint path as well).
- RDD.getPartitions is abstract.
- So the actual implementation used in this case is MapPartitionsRDD.getPartitions, which simply delegates the call to the parent (see the simplified sketch after this list).
- There are only MapPartitionsRDDs between rdd and the source:

  scala> ds.rdd.toDebugString
  res3: String =
  (1) MapPartitionsRDD[3] at rdd at <console>:26 []
   |  MapPartitionsRDD[2] at rdd at <console>:26 []
   |  MapPartitionsRDD[1] at rdd at <console>:26 []
   |  FileScanRDD[0] at rdd at <console>:26 []
Similarly, if the Dataset contained an exchange, we would follow the parents to the nearest shuffle:

scala> ds.orderBy("value").rdd.toDebugString
res4: String =
(67) MapPartitionsRDD[13] at rdd at <console>:26 []
 |   MapPartitionsRDD[12] at rdd at <console>:26 []
 |   MapPartitionsRDD[11] at rdd at <console>:26 []
 |   ShuffledRowRDD[10] at rdd at <console>:26 []
 +-(1) MapPartitionsRDD[9] at rdd at <console>:26 []
    |  MapPartitionsRDD[5] at rdd at <console>:26 []
    |  FileScanRDD[4] at rdd at <console>:26 []
Note that this case is particularly interesting, because we actually triggered a job:
scala> spark.sparkContext.statusTracker.getJobIdsForGroup(null).isEmpty
res5: Boolean = false

scala> spark.sparkContext.statusTracker.getJobIdsForGroup(null)
res6: Array[Int] = Array(0)
That's because we've encountered a scenario where the partitions cannot be determined statically (see Number of dataframe partitions after sorting? and Why does sortBy transformation trigger a Spark job?).
In such a scenario getNumPartitions will also trigger a job:

scala> ds.orderBy("value").rdd.getNumPartitions
res7: Int = 67

scala> spark.sparkContext.statusTracker.getJobIdsForGroup(null) // Note new job id
res8: Array[Int] = Array(1, 0)
However, that doesn't mean that the observed cost is somehow related to the .rdd call. Instead, it is an intrinsic cost of finding the partitions in cases where there is no static formula (for example with some Hadoop input formats, where a full scan of the data is required).
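The same effect can be reproduced with the plain RDD API, which is what the linked sortBy question is about. A minimal sketch, assuming the spark session from the transcript above:

val lines = spark.sparkContext.textFile("README.md")

// Partition count comes from input-split planning (file metadata),
// so no Spark job is needed here.
lines.getNumPartitions

// sortBy has to sample the data to build a RangePartitioner, so even this
// "transformation" already runs a job; the cost lies in determining the
// partition boundaries, not in any .rdd conversion.
val sorted = lines.sortBy(identity)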
Please note that the points made here shouldn't be extrapolated to other applications of Dataset.rdd. For example, ds.rdd.count would indeed be expensive and wasteful.
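To make the contrast concrete, a quick sketch using the ds from the transcript above:

// Cheap: builds the RDD lineage and reads partition metadata only.
ds.rdd.getNumPartitions

// Expensive: converts every row to an RDD[Row] and scans the whole dataset
// just to count it; ds.count stays in the optimized Dataset execution path.
ds.rdd.count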
In my experience df.rdd.getNumPartitions is very fast; I have never encountered it taking more than a second or so.
Alternatively, you could also try
val numPartitions: Long = df
  .select(org.apache.spark.sql.functions.spark_partition_id()).distinct().count()
which would avoid using .rdd.
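For completeness, a sketch of the trade-off between the two approaches (df stands for any DataFrame): unlike rdd.getNumPartitions, the spark_partition_id variant executes a full Spark job that touches every partition in order to count the distinct ids.

import org.apache.spark.sql.functions.spark_partition_id

// Metadata only in the simple (non-shuffle) case; no job is run.
val viaRdd: Int = df.rdd.getNumPartitions

// Stays within the Dataset API, but runs a job over the data.
val viaPartitionId: Long =
  df.select(spark_partition_id()).distinct().count()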