PySpark DataFrames - way to enumerate without converting to Pandas?
It doesn't work because:
- the second argument for withColumn should be a Column, not a collection. np.array won't work here
- when you pass "index in indexes" as a SQL expression to where, indexes is out of scope and it is not resolved as a valid identifier
PySpark >= 1.4.0
You can add row numbers using the respective window function and query using the Column.isin method or a properly formatted query string:
# Note: in later Spark releases rowNumber is named row_number
from pyspark.sql.functions import col, rowNumber
from pyspark.sql.window import Window
w = Window.orderBy()
indexed = df.withColumn("index", rowNumber().over(w))
# Using DSL
indexed.where(col("index").isin(set(indexes)))
# Using SQL expression
indexed.where("index in ({0})".format(",".join(str(x) for x in indexes)))
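To make the SQL-expression variant concrete, here is a minimal sketch of the string that the format call produces. It runs without Spark; the sample values in indexes are hypothetical and stand in for the list from the question.

```python
# Hypothetical sample values; `indexes` stands in for the list from the question.
indexes = [2, 5, 7]

# Same formatting expression as in the answer: the values are inlined
# into the query text, so no Python name has to be resolved by Spark.
expr = "index in ({0})".format(",".join(str(x) for x in indexes))
print(expr)
```

The resulting string contains only literals, which is why it works where "index in indexes" does not.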
It looks like window functions called without a PARTITION BY clause move all data to a single partition, so the above may not be the best solution after all.
Any faster and simpler way to deal with it?
Not really. Spark DataFrames don't support random row access.
PairedRDD can be accessed using the lookup method, which is relatively fast if the data is partitioned using HashPartitioner. There is also the indexed-rdd project, which supports efficient lookups.
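As a rough illustration of why lookup benefits from a known partitioner, here is a plain-Python sketch (not Spark code). The helper names hash_partition and lookup are hypothetical; the idea is only that a key's partition can be computed from its hash, so a single partition is scanned instead of all of them.

```python
from collections import defaultdict

def hash_partition(pairs, num_partitions):
    """Group (key, value) pairs into buckets by hash(key) % num_partitions."""
    partitions = defaultdict(list)
    for key, value in pairs:
        partitions[hash(key) % num_partitions].append((key, value))
    return partitions

def lookup(partitions, num_partitions, key):
    """Scan only the one bucket that the key maps to."""
    bucket = partitions.get(hash(key) % num_partitions, [])
    return [v for k, v in bucket if k == key]

# Hypothetical data: integer keys 0..14 paired with the letters a..o.
pairs = [(i, chr(97 + i)) for i in range(15)]
parts = hash_partition(pairs, 4)
print(lookup(parts, 4, 6))
```

Without a partitioner, every partition would have to be searched for the key.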
Edit:
Independent of the PySpark version, you can try something like this:
from pyspark.sql import Row
from pyspark.sql.functions import col  # needed for the final filter
from pyspark.sql.types import StructType, StructField, LongType
row = Row("char")
row_with_index = Row("char", "index")
df = sc.parallelize(row(chr(x)) for x in range(97, 112)).toDF()
df.show(5)
## +----+
## |char|
## +----+
## |   a|
## |   b|
## |   c|
## |   d|
## |   e|
## +----+
## only showing top 5 rows
# This part is not tested but should work and save some work later
schema = StructType(
    df.schema.fields[:] + [StructField("index", LongType(), False)])

indexed = (df.rdd  # Extract rdd
    .zipWithIndex()  # Add index
    .map(lambda ri: row_with_index(*list(ri[0]) + [ri[1]]))  # Map to rows
    .toDF(schema))  # It will work without schema but will be more expensive
# Use inSet instead of isin in Spark < 1.5
indexed.where(col("index").isin(indexes))
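For intuition about what zipWithIndex does, here is a plain-Python sketch (not Spark code) under the assumption that indices are ordered first by partition and then by position within each partition. The function name zip_with_index and the sample partitions are hypothetical.

```python
from itertools import accumulate

def zip_with_index(partitions):
    """Assign contiguous indices across a list of partitions."""
    # Count the elements in each partition first; in Spark this counting
    # step is why zipWithIndex needs an extra job for multi-partition RDDs.
    counts = [len(p) for p in partitions]
    # Each partition starts where the previous ones end.
    offsets = [0] + list(accumulate(counts))[:-1]
    return [
        [(item, offset + i) for i, item in enumerate(part)]
        for part, offset in zip(partitions, offsets)
    ]

# Hypothetical data split unevenly over three partitions.
partitions = [["a", "b", "c"], ["d", "e"], ["f"]]
indexed_parts = zip_with_index(partitions)
print(indexed_parts)
```

Unlike the window-function approach, this keeps the data in its original partitions and yields contiguous indices.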
If you want a number range that's guaranteed not to collide but does not require a .over(partitionBy()), then you can use monotonicallyIncreasingId().
# Note: in later Spark releases this function is named monotonically_increasing_id
from pyspark.sql.functions import monotonicallyIncreasingId
df.select(monotonicallyIncreasingId().alias("rowId"), "*")
Note though that the values are not particularly "neat". Each partition is given a value range and the output will not be contiguous, e.g. 0, 1, 2, 8589934592, 8589934593, 8589934594.
This was added to Spark on Apr 28, 2015 here: https://github.com/apache/spark/commit/d94cd1a733d5715792e6c4eac87f0d5c81aebbe2
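The gaps in those values can be reproduced with a small plain-Python sketch (not Spark code). It assumes the layout described in the Spark documentation, where the partition ID goes in the upper 31 bits and the record number within the partition in the lower 33 bits. The function name monotonic_ids and the partition sizes are hypothetical.

```python
def monotonic_ids(partition_sizes):
    """Generate IDs as (partition_id << 33) + record number within the partition."""
    ids = []
    for partition_id, size in enumerate(partition_sizes):
        base = partition_id << 33  # each partition gets its own value range
        ids.extend(base + n for n in range(size))
    return ids

# Two hypothetical partitions with three rows each.
print(monotonic_ids([3, 3]))
# [0, 1, 2, 8589934592, 8589934593, 8589934594]
```

The IDs are unique and increasing, but the jump between partitions is 2**33, so they cannot be used as contiguous row numbers.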