Select Specific Columns from Spark DataFrame

If you want to split your dataframe into two different ones, do two selects on it, each with the columns you want.

 val sourceDF = spark.read.csv(...)
 val df1 = sourceDF.select("first column", "second column", "third column")
 val df2 = sourceDF.select("fourth column", "fifth column", "sixth column")

Note that this of course means that sourceDF would be evaluated twice, so if it can fit into distributed memory and you use most of the columns across both dataframes, it might be a good idea to cache it. If it has many extra columns that you don't need, then you can do a select on it first to pick only the columns you will need, so it won't store all that extra data in memory.
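A minimal sketch of that trim-then-cache approach (the column names here are placeholders):

// keep only the columns the child dataframes will need, then cache
val trimmedDF = sourceDF.select("first column", "second column", "third column")
trimmedDF.cache()

val df1 = trimmedDF.select("first column", "second column")
val df2 = trimmedDF.select("second column", "third column")

df1.show()
df2.show()

trimmedDF.unpersist() // release the cached data once both children have been computed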


Let's say our parent DataFrame has 'n' columns.

We can create 'x' child DataFrames (let's consider 2 in our case).

The columns for a child DataFrame can be chosen as desired from the parent DataFrame's columns.

Consider that the source has 10 columns and we want to split it into 2 DataFrames containing columns referenced from the parent DataFrame.

The columns for the child DataFrames can be decided using the select DataFrame API:

val parentDF = spark.read.format("csv").load("/path of the CSV file")

val Child1_DF = parentDF.select("col1", "col2", "col3", "col9", "col10")
Child1_DF.show()

val child2_DF = parentDF.select("col5", "col6", "col7", "col8", "col1", "col2")
child2_DF.show()

Notice that the column counts of the child DataFrames can differ from each other, and each will be less than or equal to the parent DataFrame's column count.

We can also refer to columns without mentioning their real names, using the default positional names (_c0, _c1, ...) that Spark assigns when a CSV is loaded without a header.

First import the Spark implicits, which act as a helper for the $-notation used to access columns further below:

import spark.implicits._
import org.apache.spark.sql.functions._

val child3_DF = parentDF.select("_c0", "_c1", "_c2", "_c8", "_c9")
child3_DF.show()
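With the implicits in scope, the same selection can also be written with $-notation (a sketch; child3_DF_dollar is just an illustrative name):

val child3_DF_dollar = parentDF.select($"_c0", $"_c1", $"_c2", $"_c8", $"_c9") // same columns as above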

We can also select columns based on certain conditions. Let's say we want only the even-numbered columns in the child DataFrame. By even we mean even-indexed columns, with the index starting from 0.

val parentColumns = parentDF.columns.toList
// parentColumns: List[String] = List(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9)

val evenParentColumns = parentColumns.zipWithIndex.filter(_._2 % 2 == 0).map(_._1)
// evenParentColumns: List[String] = List(_c0, _c2, _c4, _c6, _c8)

Now feed these columns into a select on the parentDF. Note that this overload of the select API takes a first column name plus varargs, so we pass the head and expand the remaining columns with : _*.

val child4_DF = parentDF.select(evenParentColumns.head, evenParentColumns.tail: _*)
child4_DF.show()

This will show the even-indexed columns from the parent DataFrame:


+-----------+----+----+---+---+
|        _c0| _c2| _c4|_c6|_c8|
+-----------+----+----+---+---+
|ITE00100554|TMAX|null|  E|  1|
| TE00100554|TMIN|null|  E|  4|
|GM000010962|PRCP|null|  E|  7|
+-----------+----+----+---+---+

So now we are left with the even-indexed columns in the DataFrame.

Similarly, we can also apply other operations to DataFrame columns, as shown below:

val child5_DF = parentDF.select($"_c0", $"_c8" + 1)
child5_DF.show()

So, as shown, there are many ways to select columns from a DataFrame.


There are multiple options (especially in Scala) to select a subset of columns of a DataFrame. The following lines show the options; most of them are documented in the ScalaDocs of Column:

import spark.implicits._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, column, expr}

inputDf.select(col("colA"), col("colB"))
inputDf.select(inputDf.col("colA"), inputDf.col("colB"))
inputDf.select(column("colA"), column("colB"))
inputDf.select(expr("colA"), expr("colB"))

// only available in Scala
inputDf.select($"colA", $"colB")
inputDf.select('colA, 'colB) // makes use of Scala's Symbol

// selecting columns based on a given iterable of Strings
val selectedColumns: Seq[Column] = Seq("colA", "colB").map(c => col(c))
inputDf.select(selectedColumns: _*)

// Special cases
col("columnName.field")     // Extracting a struct field
col("`a.column.with.dots`") // Escape `.` in column names.

// select the first or last 2 columns
inputDf.selectExpr(inputDf.columns.take(2): _*)
inputDf.selectExpr(inputDf.columns.takeRight(2): _*)

The usage of $ is possible because Spark provides an implicit class (via spark.implicits._) that converts a StringContext into a ColumnName using the method $:

implicit class StringToColumn(val sc : scala.StringContext) extends scala.AnyRef {
  def $(args : scala.Any*) : org.apache.spark.sql.ColumnName = { /* compiled code */ }
}

Typically, when you want to derive multiple DataFrames from one DataFrame, it might improve your performance if you persist the original DataFrame before creating the others. At the end you can unpersist the original DataFrame.
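A minimal sketch of that persist/unpersist pattern, reusing the inputDf from below (the selected subsets are placeholders):

inputDf.persist() // cache the parsed rows for the derived selects

val aDf = inputDf.select("colA")
val bDf = inputDf.select("colB", "colC")

aDf.show()
bDf.show()

inputDf.unpersist() // free the cached data afterwards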

Keep in mind that columns are not resolved at compile time but only when they are compared to the column names of your catalog, which happens during the analyzer phase of query execution. In case you need stronger type safety, you could create a Dataset.
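For illustration, a sketch of the Dataset route, assuming a case class that matches the csv example at the end of this answer (Record is a hypothetical name):

// hypothetical case class mirroring the csv schema below (all columns read as strings)
case class Record(colA: String, colB: String, colC: String)

import spark.implicits._
val inputDs = inputDf.as[Record] // fails during analysis if names/types don't match

// typed access: a typo like r.colX is now a compile-time error
val projected = inputDs.map(r => (r.colA, r.colB))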

For completeness, here is the csv used to try out the above code:

// csv file:
// colA,colB,colC
// 1,"foo","bar"

val inputDf = spark.read.format("csv").option("header", "true").load(csvFilePath)

// resulting DataFrame schema
root
 |-- colA: string (nullable = true)
 |-- colB: string (nullable = true)
 |-- colC: string (nullable = true)