Filtering rows based on column values in a Spark DataFrame (Scala)
Hi, I found a solution using a Window function and a self-join.
val data = Seq(
  (3, 0, 2), (3, 1, 3), (3, 0, 1), (4, 1, 6), (4, 0, 5), (4, 0, 4),
  (1, 0, 7), (1, 1, 8), (1, 0, 9), (2, 1, 10), (2, 0, 11), (2, 0, 12)
).toDF("id", "value", "sorted")
data.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     2|
|  3|    1|     3|
|  3|    0|     1|
|  4|    1|     6|
|  4|    0|     5|
|  4|    0|     4|
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+
val sort_df = data.sort($"sorted")
sort_df.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     1|
|  3|    0|     2|
|  3|    1|     3|
|  4|    0|     4|
|  4|    0|     5|
|  4|    1|     6|
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+
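import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{min, row_number}
// Number the rows of each id in "sorted" order.
val window = Window.partitionBy("id").orderBy($"sorted")
val sort_idx = sort_df.select($"*", row_number().over(window).as("count_index"))
// Position of the first value == 1 row within each id.
val minIdx = sort_idx.filter($"value" === 1).groupBy("id").agg(min("count_index")).toDF("idx", "min_idx")
// Keep only the rows up to and including that position.
val result_id = sort_idx.join(minIdx, ($"id" === $"idx") && ($"count_index" <= $"min_idx"))
result_id.show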
+---+-----+------+-----------+---+-------+
| id|value|sorted|count_index|idx|min_idx|
+---+-----+------+-----------+---+-------+
|  1|    0|     7|          1|  1|      2|
|  1|    1|     8|          2|  1|      2|
|  2|    1|    10|          1|  2|      1|
|  3|    0|     1|          1|  3|      3|
|  3|    0|     2|          2|  3|      3|
|  3|    1|     3|          3|  3|      3|
|  4|    0|     4|          1|  4|      3|
|  4|    0|     5|          2|  4|      3|
|  4|    1|     6|          3|  4|      3|
+---+-----+------+-----------+---+-------+
I'm still looking for a more optimized solution. Thanks.
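One possible way to avoid the self-join, as a rough sketch rather than a tuned solution: count how many value == 1 rows come before the current row inside each id (ordered by sorted) and keep only the rows where that count is zero. The sketch assumes Spark 2.1 or later (for Window.unboundedPreceding); the column name ones_before and the app name are made up for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, lit, sum, when}

// Assumption: Spark 2.1+ is on the classpath. In spark-shell the session
// already exists, so getOrCreate() simply returns it.
val spark = SparkSession.builder().master("local[*]").appName("keep-until-first-one").getOrCreate()
import spark.implicits._

val data = Seq(
  (3, 0, 2), (3, 1, 3), (3, 0, 1), (4, 1, 6), (4, 0, 5), (4, 0, 4),
  (1, 0, 7), (1, 1, 8), (1, 0, 9), (2, 1, 10), (2, 0, 11), (2, 0, 12)
).toDF("id", "value", "sorted")

// Frame = every row strictly before the current one within the same id.
val before = Window
  .partitionBy($"id")
  .orderBy($"sorted")
  .rowsBetween(Window.unboundedPreceding, -1)

// Count of value == 1 rows seen so far. The first row of each id has an
// empty frame, so the sum is null there and is replaced by 0.
val onesBefore = coalesce(sum(when($"value" === 1, 1).otherwise(0)).over(before), lit(0))

val result = data
  .withColumn("ones_before", onesBefore) // hypothetical helper column
  .filter($"ones_before" === 0)
  .drop("ones_before")
  .orderBy($"id", $"sorted")

result.show()
```

On the sample data this should keep the same nine rows as the join version. One difference to be aware of: an id that has no value == 1 row at all is kept in full here, whereas the inner join above drops it.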
One way is to use monotonically_increasing_id() and a self-join:
val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data.show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  3|    0|
|  4|    1|
|  4|    0|
|  4|    0|
+---+-----+
Now we generate a column named idx with an increasing Long:
val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
// dataWithIndex.cache()
Now we get the min(idx) for each id where value = 1:
val minIdx = dataWithIndex
  .filter($"value" === 1)
  .groupBy($"id")
  .agg(min($"idx"))
  .toDF("r_id", "min_idx")
Now we join the min(idx) back to the original DataFrame:
dataWithIndex.join(
  minIdx,
  ($"r_id" === $"id") && ($"idx" <= $"min_idx")
).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  4|    1|
+---+-----+
Note: monotonically_increasing_id() generates its value based on the partition of the row, so the value may change each time dataWithIndex is re-evaluated. In my code above, because of lazy evaluation, monotonically_increasing_id() is only evaluated when I call the final show.
If you want to force the value to stay the same, for example so you can use show to evaluate the above step by step, uncomment this line above:
// dataWithIndex.cache()
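To make the partition dependence in the note concrete, here is a small sketch. It relies on what the Spark API docs describe as the current implementation of monotonically_increasing_id(): the partition ID goes in the upper 31 bits and the record number within the partition in the lower 33 bits. The repartition calls, the pid column, and the app name are only there for illustration and are not part of the answer above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{monotonically_increasing_id, spark_partition_id}

// Assumption: Spark 2.x+ on the classpath; in spark-shell getOrCreate()
// returns the existing session.
val spark = SparkSession.builder().master("local[2]").appName("mono-id-demo").getOrCreate()
import spark.implicits._

val data = Seq((3, 0), (3, 1), (3, 0), (4, 1), (4, 0), (4, 0)).toDF("id", "value")

// Everything in a single partition: the generated ids are simply 0 to 5.
val onePartition = data
  .repartition(1)
  .withColumn("idx", monotonically_increasing_id())

// Same rows spread over three partitions: the ids now jump between
// partitions because the partition ID is encoded in the upper bits.
val threePartitions = data
  .repartition(3)
  .select($"*", monotonically_increasing_id().as("idx"), spark_partition_id().as("pid"))

onePartition.show()
threePartitions.show()

// Recover the partition ID from idx by dropping the lower 33 bits.
val decoded = threePartitions.select($"idx", $"pid").as[(Long, Int)].collect()
decoded.foreach { case (idx, pid) => println(s"idx=$idx partition=${idx >> 33} pid=$pid") }
```

Because the ids follow the physical layout of the data, the idx <= min_idx comparison only reflects the original row order as long as that layout does, which is one more reason to cache dataWithIndex or to order by an explicit column when one is available.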