get first N elements from dataframe ArrayType column in pyspark
Here's how to do it with the API functions.
Suppose your DataFrame were the following:
df.show()
#+---+---------+
#| id| letters|
#+---+---------+
#| 1|[a, b, c]|
#| 2|[d, e, f]|
#| 3|[g, h, i]|
#+---+---------+
df.printSchema()
#root
# |-- id: long (nullable = true)
# |-- letters: array (nullable = true)
# | |-- element: string (containsNull = true)
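(In case you want to reproduce this, a DataFrame like the one above can be built with something along these lines; it assumes an existing SparkSession named spark and lets Spark infer the long and array<string> types from the Python literals.)
# recreate the example DataFrame with an ArrayType column
df = spark.createDataFrame(
    [(1, ["a", "b", "c"]), (2, ["d", "e", "f"]), (3, ["g", "h", "i"])],
    ["id", "letters"]
)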
You can use square brackets to access elements in the letters column by index, and wrap that in a call to pyspark.sql.functions.array() to create a new ArrayType column.
import pyspark.sql.functions as f
df.withColumn("first_two", f.array([f.col("letters")[0], f.col("letters")[1]])).show()
#+---+---------+---------+
#| id| letters|first_two|
#+---+---------+---------+
#| 1|[a, b, c]| [a, b]|
#| 2|[d, e, f]| [d, e]|
#| 3|[g, h, i]| [g, h]|
#+---+---------+---------+
Or, if you have too many indices to list out, you can use a list comprehension:
df.withColumn("first_two", f.array([f.col("letters")[i] for i in range(2)])).show()
#+---+---------+---------+
#| id| letters|first_two|
#+---+---------+---------+
#| 1|[a, b, c]| [a, b]|
#| 2|[d, e, f]| [d, e]|
#| 3|[g, h, i]| [g, h]|
#+---+---------+---------+
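If the number of elements you want varies, the same comprehension can be wrapped in a small helper. This is just a sketch of that idea; first_n is a made-up name, not something from pyspark:
import pyspark.sql.functions as f

def first_n(col_name, n):
    # build an ArrayType column from the first n elements of col_name
    return f.array([f.col(col_name)[i] for i in range(n)])

df.withColumn("first_two", first_n("letters", 2)).show()
Keep in mind that indexing past the end of an array returns null, so rows with fewer than n elements will come back padded with nulls rather than truncated.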
For pyspark versions 2.4+ you can also use pyspark.sql.functions.slice():
df.withColumn("first_two",f.slice("letters",start=1,length=2)).show()
#+---+---------+---------+
#| id| letters|first_two|
#+---+---------+---------+
#| 1|[a, b, c]| [a, b]|
#| 2|[d, e, f]| [d, e]|
#| 3|[g, h, i]| [g, h]|
#+---+---------+---------+
slice may have better performance for large arrays (note that the start index is 1, not 0).
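If the length isn't a fixed number but lives in another column, the Python slice() signature in these versions only takes plain integers (Column arguments were only added later, around Spark 3.1, if I remember correctly), but you can still reach the SQL form through f.expr(). A sketch, where the column n is hypothetical and added only for illustration:
# n is a hypothetical per-row length column, added here just for the example
df.withColumn("n", f.lit(2)) \
  .withColumn("first_n", f.expr("slice(letters, 1, n)")) \
  .show()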
Either my pyspark skills have gone rusty (I confess I don't hone them much nowadays), or this is a tough nut indeed... The only way I managed to do it is by using SQL statements:
spark.version
# u'2.3.1'
# dummy data:
from pyspark.sql import Row
x = [Row(col1="xx", col2="yy", col3="zz", col4=[123,234, 456])]
rdd = sc.parallelize(x)
df = spark.createDataFrame(rdd)
df.show()
# result:
+----+----+----+---------------+
|col1|col2|col3| col4|
+----+----+----+---------------+
| xx| yy| zz|[123, 234, 456]|
+----+----+----+---------------+
df.createOrReplaceTempView("df")
df2 = spark.sql("SELECT col1, col2, col3, array(col4[0], col4[1]) as col5 FROM df")
df2.show()
# result:
+----+----+----+----------+
|col1|col2|col3| col5|
+----+----+----+----------+
| xx| yy| zz|[123, 234]|
+----+----+----+----------+
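As a side note, writing (col4[0], col4[1]) without the array() wrapper is also accepted by Spark SQL, but as far as I can tell it builds a StructType column whose show() output looks the same; wrapping the elements in array() keeps col5 a genuine ArrayType. A quick schema check confirms it (just a sketch, the exact nullability flags may differ):
df2.printSchema()
# col5 should show up as an array of long elements,
# not as a struct with two long fields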
For future questions, it would be good to follow the suggested guidelines on How to make good reproducible Apache Spark Dataframe examples.