Spark: How to aggregate/reduce records based on time difference?

For Pandas users, there is a common programming pattern using shift() + cumsum() to set up a group label that identifies consecutive rows matching some specific pattern/condition. With pyspark, we can use the Window functions lag() + sum() to do the same and find this group label (d2 in the code below):
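
For reference, here is that Pandas pattern in miniature, on a hypothetical frame holding only a speed column; d1 and d2 mirror the flags built below:

import pandas as pd

pdf = pd.DataFrame({'speed': [44.55, 47.20, 42.14, 39.20, 35.30]})

# d1 = 0 where the previous speed is greater (a drop), else 1;
# shift() yields NaN on the first row, and NaN > x is False, so d1 = 1 there
pdf['d1'] = (~(pdf['speed'].shift() > pdf['speed'])).astype(int)

# cumsum() turns the d1 flags into a running group label (d2)
pdf['d2'] = pdf['d1'].cumsum()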

Data Setup:

from pyspark.sql import functions as F, Window

>>> df.orderBy('timestamp').show()
+-------+----------+-----+
|trip-id| timestamp|speed|
+-------+----------+-----+
|    001|1538204192|44.55|
|    001|1538204193|47.20|
|    001|1538204194|42.14|
|    001|1538204195|39.20|
|    001|1538204196|35.30|
|    001|1538204197|32.22|
|    001|1538204198|34.80|
|    001|1538204199|37.10|
|    001|1538204221|55.30|
|    001|1538204222|57.20|
|    001|1538204223|54.60|
|    001|1538204224|52.15|
|    001|1538204225|49.27|
|    001|1538204226|47.89|
|    001|1538204227|50.57|
|    001|1538204228|53.72|
+-------+----------+-----+

>>> df.printSchema()
root
 |-- trip-id: string (nullable = true)
 |-- timestamp: integer (nullable = true)
 |-- speed: double (nullable = true)
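
For reproducibility, here is a minimal sketch that builds the sample DataFrame (assuming an active SparkSession named spark; note that Python ints land as bigint rather than integer, so cast if the exact type matters):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [
    ('001', 1538204192, 44.55), ('001', 1538204193, 47.20),
    ('001', 1538204194, 42.14), ('001', 1538204195, 39.20),
    ('001', 1538204196, 35.30), ('001', 1538204197, 32.22),
    ('001', 1538204198, 34.80), ('001', 1538204199, 37.10),
    ('001', 1538204221, 55.30), ('001', 1538204222, 57.20),
    ('001', 1538204223, 54.60), ('001', 1538204224, 52.15),
    ('001', 1538204225, 49.27), ('001', 1538204226, 47.89),
    ('001', 1538204227, 50.57), ('001', 1538204228, 53.72),
]
df = spark.createDataFrame(data, ['trip-id', 'timestamp', 'speed'])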

Set up two Window Specs (w1, w2):

# Window spec used to find the previous speed via F.lag('speed').over(w1)
# and to run the cumulative sum that builds flag `d2`
w1 = Window.partitionBy('trip-id').orderBy('timestamp')

# Window spec used to find the minimal value of flag `d1` over the partition (`trip-id`, `d2`)
w2 = Window.partitionBy('trip-id', 'd2').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
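
A side note on the rowsBetween clause: assuming Spark's default frame semantics (when a window spec has no orderBy, the frame already spans the whole partition), the explicit bounds are mainly for readability, and a shorter spec should behave the same here:

w2_alt = Window.partitionBy('trip-id', 'd2')  # default frame covers the entire partition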

Three flags (d1, d2, d3):

  • d1 : flag to identify whether the previous speed is greater than the current speed; if so, d1 = 0, else d1 = 1 (on the first row of a partition, lag() returns null, so the otherwise branch fires and d1 = 1)
  • d2 : flag to mark the consecutive speed-drop rows with the same unique number
  • d3 : the minimal value of d1 over the partition ('trip-id', 'd2'); only when d3 == 0 can a row belong to a speed-drop group, so it is used to filter out unrelated rows

df_1 = df.withColumn('d1', F.when(F.lag('speed').over(w1) > F.col('speed'), 0).otherwise(1)) \
         .withColumn('d2', F.sum('d1').over(w1)) \
         .withColumn('d3', F.min('d1').over(w2))

>>> df_1.orderBy('timestamp').show()
+-------+----------+-----+---+---+---+
|trip-id| timestamp|speed| d1| d2| d3|
+-------+----------+-----+---+---+---+
|    001|1538204192|44.55|  1|  1|  1|
|    001|1538204193|47.20|  1|  2|  0|
|    001|1538204194|42.14|  0|  2|  0|
|    001|1538204195|39.20|  0|  2|  0|
|    001|1538204196|35.30|  0|  2|  0|
|    001|1538204197|32.22|  0|  2|  0|
|    001|1538204198|34.80|  1|  3|  1|
|    001|1538204199|37.10|  1|  4|  1|
|    001|1538204221|55.30|  1|  5|  1|
|    001|1538204222|57.20|  1|  6|  0|
|    001|1538204223|54.60|  0|  6|  0|
|    001|1538204224|52.15|  0|  6|  0|
|    001|1538204225|49.27|  0|  6|  0|
|    001|1538204226|47.89|  0|  6|  0|
|    001|1538204227|50.57|  1|  7|  1|
|    001|1538204228|53.72|  1|  8|  1|
+-------+----------+-----+---+---+---+

Remove the rows we are not concerned with:

df_1 = df_1.where('d3 == 0')

>>> df_1.orderBy('timestamp').show()
+-------+----------+-----+---+---+---+
|trip-id| timestamp|speed| d1| d2| d3|
+-------+----------+-----+---+---+---+
|    001|1538204193|47.20|  1|  2|  0|
|    001|1538204194|42.14|  0|  2|  0|
|    001|1538204195|39.20|  0|  2|  0|
|    001|1538204196|35.30|  0|  2|  0|
|    001|1538204197|32.22|  0|  2|  0|
|    001|1538204222|57.20|  1|  6|  0|
|    001|1538204223|54.60|  0|  6|  0|
|    001|1538204224|52.15|  0|  6|  0|
|    001|1538204225|49.27|  0|  6|  0|
|    001|1538204226|47.89|  0|  6|  0|
+-------+----------+-----+---+---+---+

Final Step:

Now for df_1, group by trip-id and d2, and take the min and max of F.struct('timestamp', 'speed'), which return the first and last records in each group (structs compare field by field in declaration order, so timestamp drives the comparison). Then select the corresponding fields from the two structs to get the final result:

df_new = df_1.groupby('trip-id', 'd2').agg(
          F.min(F.struct('timestamp', 'speed')).alias('start')
        , F.max(F.struct('timestamp', 'speed')).alias('end')
).select(
      'trip-id'
    , F.col('start.timestamp').alias('start timestamp')
    , F.col('end.timestamp').alias('end timestamp')
    , F.col('start.speed').alias('start speed')
    , F.col('end.speed').alias('end speed')
)

>>> df_new.show()
+-------+---------------+-------------+-----------+---------+
|trip-id|start timestamp|end timestamp|start speed|end speed|
+-------+---------------+-------------+-----------+---------+
|    001|     1538204193|   1538204197|      47.20|    32.22|
|    001|     1538204222|   1538204226|      57.20|    47.89|
+-------+---------------+-------------+-----------+---------+
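
A quick sanity check of that struct-ordering behaviour, on a hypothetical two-row frame (again assuming an active SparkSession named spark):

probe = spark.createDataFrame([(1538204197, 32.22), (1538204193, 47.20)], ['timestamp', 'speed'])

# min() picks the struct with the smallest leading field, i.e. the earliest timestamp
probe.agg(F.min(F.struct('timestamp', 'speed')).alias('start')).show(truncate=False)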

Note: removing the intermediate dataframe df_1, we can chain everything as follows:

df_new = df.withColumn('d1', F.when(F.lag('speed').over(w1) > F.col('speed'), 0).otherwise(1))\
           .withColumn('d2', F.sum('d1').over(w1)) \
           .withColumn('d3', F.min('d1').over(w2)) \
           .where('d3 == 0') \
           .groupby('trip-id', 'd2').agg(
                F.min(F.struct('timestamp', 'speed')).alias('start')
              , F.max(F.struct('timestamp', 'speed')).alias('end')
            )\
           .select(
                'trip-id'
              , F.col('start.timestamp').alias('start timestamp')
              , F.col('end.timestamp').alias('end timestamp')
              , F.col('start.speed').alias('start speed')
              , F.col('end.speed').alias('end speed')
            )
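
If a deterministic display order matters, one can append an explicit sort, since row order after a groupby is not guaranteed:

df_new.orderBy(F.col('start timestamp')).show()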

Hope this helps.

Here is an alternative approach in Scala, flagging the deceleration ("break") events directly with lead()/lag().

Output

+-------------+---------------+-------------+-----------+---------+
|      breakID|start timestamp|end timestamp|start speed|end speed|
+-------------+---------------+-------------+-----------+---------+
|0011538204193|     1538204193|   1538204196|       47.2|     35.3|
|0011538204222|     1538204222|   1538204225|       57.2|    49.27|
+-------------+---------------+-------------+-----------+---------+

CODE

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._

scala> df.show
+-------+----------+-----+
|trip-id| timestamp|speed|
+-------+----------+-----+
|    001|1538204192|44.55|
|    001|1538204193| 47.2|
|    001|1538204194|42.14|
|    001|1538204195| 39.2|
|    001|1538204196| 35.3|
|    001|1538204197|32.22|
|    001|1538204198| 34.8|
|    001|1538204199| 37.1|
|    001|1538204221| 55.3|
|    001|1538204222| 57.2|
|    001|1538204223| 54.6|
|    001|1538204224|52.15|
|    001|1538204225|49.27|
|    001|1538204226|47.89|
|    001|1538204227|50.57|
|    001|1538204228|53.72|
+-------+----------+-----+

val overColumns = Window.partitionBy("trip-id").orderBy("timestamp")

// speeddiff: the next row's speed minus the current one; a negative value means
// the speed drops on the following row, which is flagged as breaking = 1
val breaksDF = df
  .withColumn("speeddiff", lead("speed", 1).over(overColumns) - $"speed")
  .withColumn("breaking", when($"speeddiff" < 0, 1).otherwise(0))

scala> breaksDF.show
+-------+----------+-----+-------------------+--------+
|trip-id| timestamp|speed|          speeddiff|breaking|
+-------+----------+-----+-------------------+--------+
|    001|1538204192|44.55| 2.6500000000000057|       0|
|    001|1538204193| 47.2| -5.060000000000002|       1|
|    001|1538204194|42.14|-2.9399999999999977|       1|
|    001|1538204195| 39.2|-3.9000000000000057|       1|
|    001|1538204196| 35.3|-3.0799999999999983|       1|
|    001|1538204197|32.22| 2.5799999999999983|       0|
|    001|1538204198| 34.8| 2.3000000000000043|       0|
|    001|1538204199| 37.1| 18.199999999999996|       0|
|    001|1538204221| 55.3| 1.9000000000000057|       0|
|    001|1538204222| 57.2|-2.6000000000000014|       1|
|    001|1538204223| 54.6| -2.450000000000003|       1|
|    001|1538204224|52.15|-2.8799999999999955|       1|
|    001|1538204225|49.27|-1.3800000000000026|       1|
|    001|1538204226|47.89| 2.6799999999999997|       0|
|    001|1538204227|50.57| 3.1499999999999986|       0|
|    001|1538204228|53.72|               null|       0|
+-------+----------+-----+-------------------+--------+


// A row starts a break when `breaking` flips 0 -> 1 relative to the previous row,
// and ends one when it flips 1 -> 0 relative to the next row; all other rows stay null
val outputDF = breaksDF
  .withColumn("breakevent",
    when(($"breaking" - lag($"breaking", 1).over(overColumns)) === 1, "start of break")
    .when(($"breaking" - lead($"breaking", 1).over(overColumns)) === 1, "end of break"))

scala> outputDF.show
+-------+----------+-----+-------------------+--------+--------------+
|trip-id| timestamp|speed|          speeddiff|breaking|    breakevent|
+-------+----------+-----+-------------------+--------+--------------+
|    001|1538204192|44.55| 2.6500000000000057|       0|          null|
|    001|1538204193| 47.2| -5.060000000000002|       1|start of break|
|    001|1538204194|42.14|-2.9399999999999977|       1|          null|
|    001|1538204195| 39.2|-3.9000000000000057|       1|          null|
|    001|1538204196| 35.3|-3.0799999999999983|       1|  end of break|
|    001|1538204197|32.22| 2.5799999999999983|       0|          null|
|    001|1538204198| 34.8| 2.3000000000000043|       0|          null|
|    001|1538204199| 37.1| 18.199999999999996|       0|          null|
|    001|1538204221| 55.3| 1.9000000000000057|       0|          null|
|    001|1538204222| 57.2|-2.6000000000000014|       1|start of break|
|    001|1538204223| 54.6| -2.450000000000003|       1|          null|
|    001|1538204224|52.15|-2.8799999999999955|       1|          null|
|    001|1538204225|49.27|-1.3800000000000026|       1|  end of break|
|    001|1538204226|47.89| 2.6799999999999997|       0|          null|
|    001|1538204227|50.57| 3.1499999999999986|       0|          null|
|    001|1538204228|53.72|               null|       0|          null|
+-------+----------+-----+-------------------+--------+--------------+


scala> outputDF.filter("breakevent is not null").select("trip-id", "timestamp", "speed", "breakevent").show
+-------+----------+-----+--------------+
|trip-id| timestamp|speed|    breakevent|
+-------+----------+-----+--------------+
|    001|1538204193| 47.2|start of break|
|    001|1538204196| 35.3|  end of break|
|    001|1538204222| 57.2|start of break|
|    001|1538204225|49.27|  end of break|
+-------+----------+-----+--------------+

outputDF.filter("breakevent is not null")
  .withColumn("breakID",
    when($"breakevent" === "start of break", concat($"trip-id", $"timestamp"))
    .when($"breakevent" === "end of break", concat($"trip-id", lag($"timestamp", 1).over(overColumns))))
  .groupBy("breakID")
  .agg(
    first($"timestamp") as "start timestamp",
    last($"timestamp") as "end timestamp",
    first($"speed") as "start speed",
    last($"speed") as "end speed")
  .show


+-------------+---------------+-------------+-----------+---------+
|      breakID|start timestamp|end timestamp|start speed|end speed|
+-------------+---------------+-------------+-----------+---------+
|0011538204193|     1538204193|   1538204196|       47.2|     35.3|
|0011538204222|     1538204222|   1538204225|       57.2|    49.27|
+-------------+---------------+-------------+-----------+---------+