Spark Dataframes UPSERT to Postgres Table

If you are going to do it manually, via option 1 mentioned by zero323, you should take a look at the Spark source code for the insert statement:

  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
    val columns = rddSchema.fields.map(_.name).mkString(",")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    conn.prepareStatement(sql)
  }

The PreparedStatement is part of java.sql and has methods such as execute() and executeUpdate(). You still have to modify the SQL accordingly, of course; on PostgreSQL 9.5 or later that usually means appending an ON CONFLICT (...) DO UPDATE clause.
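A minimal sketch of that modification follows. It assumes PostgreSQL 9.5 or later (the version that added INSERT ... ON CONFLICT) and a primary key or unique index on the conflict columns. UpsertSql and upsertSql are hypothetical names. The function takes plain column names instead of a StructType so it runs without Spark; inside Spark you would pass rddSchema.fields.map(_.name) and hand the result to conn.prepareStatement. Column names are not quoted, so mixed-case or reserved identifiers would need extra handling.

```scala
// Hypothetical helper modeled on Spark's insertStatement; it only builds the
// SQL string, so it can be tried without a SparkSession or a database.
object UpsertSql {
  /** Builds a PostgreSQL (9.5+) INSERT ... ON CONFLICT statement with one
    * "?" placeholder per column, in the same order as `columns`. */
  def upsertSql(table: String, columns: Seq[String], conflictColumns: Seq[String]): String = {
    val cols = columns.mkString(",")
    val placeholders = columns.map(_ => "?").mkString(",")
    // Every non-key column is overwritten with the value from the incoming row
    val updates = columns.filterNot(c => conflictColumns.contains(c))
    val action =
      if (updates.isEmpty) "DO NOTHING" // nothing to update: keep the existing row
      else "DO UPDATE SET " + updates.map(c => s"$c = EXCLUDED.$c").mkString(", ")
    s"INSERT INTO $table ($cols) VALUES ($placeholders) " +
      s"ON CONFLICT (${conflictColumns.mkString(",")}) $action"
  }
}
```

For example, UpsertSql.upsertSql("people", Seq("id", "name", "age"), Seq("id")) yields INSERT INTO people (id,name,age) VALUES (?,?,?) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, age = EXCLUDED.age. When every column is part of the key there is nothing left to update, so the sketch falls back to DO NOTHING.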


It is not supported. DataFrameWriter can either append to or overwrite an existing table. If your application requires more complex logic, you'll have to deal with it manually.

One option is to use an action (foreach, foreachPartition) with a standard JDBC connection. Another is to write to a temporary table and handle the rest directly in the database.
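For the second option, here is a minimal sketch of the database-side step. It assumes the DataFrame has already been written to a staging table (for instance with SaveMode.Overwrite) and that the target table has a unique constraint on the key columns. StagingMerge, mergeSql and the table names are hypothetical; the returned string would be executed on the driver through an ordinary java.sql.Statement. DISTINCT ON keeps one arbitrary row per key, because PostgreSQL rejects an ON CONFLICT DO UPDATE that would affect the same target row twice.

```scala
// Hypothetical helper: the SQL to run in PostgreSQL after Spark has written
// the new rows into a staging table with the same columns as the target.
object StagingMerge {
  def mergeSql(target: String, staging: String, columns: Seq[String], keys: Seq[String]): String = {
    val cols = columns.mkString(", ")
    val keyList = keys.mkString(", ")
    val updates = columns.filterNot(c => keys.contains(c)).map(c => s"$c = EXCLUDED.$c")
    val action = if (updates.isEmpty) "DO NOTHING" else "DO UPDATE SET " + updates.mkString(", ")
    // DISTINCT ON deduplicates the staging rows per key before the upsert
    s"INSERT INTO $target ($cols) " +
      s"SELECT DISTINCT ON ($keyList) $cols FROM $staging " +
      s"ON CONFLICT ($keyList) $action"
  }
}
```

For example, mergeSql("people", "people_staging", Seq("id", "name"), Seq("id")) produces INSERT INTO people (id, name) SELECT DISTINCT ON (id) id, name FROM people_staging ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name. The whole merge then runs as a single statement inside the database instead of row by row from Spark.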

See also SPARK-19335 (Spark should support doing an efficient DataFrame Upsert via JDBC) and related proposals.


To insert via JDBC you can use

dataframe.write.mode(SaveMode.Append).jdbc(jdbc_url,table_name,connection_properties)
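The connection_properties argument is a java.util.Properties object. A sketch of building one for PostgreSQL follows; the URL, user name and password are placeholders, and org.postgresql.Driver assumes the PostgreSQL JDBC driver jar is on the classpath. The Spark call itself is left as a comment so the snippet runs without Spark.

```scala
import java.util.Properties

object JdbcProps {
  // Placeholder URL and credentials: replace with your own
  val jdbcUrl = "jdbc:postgresql://localhost:5432/mydb"

  val connectionProperties: Properties = {
    val p = new Properties()
    p.setProperty("user", "spark")                    // placeholder user
    p.setProperty("password", "secret")               // placeholder password
    p.setProperty("driver", "org.postgresql.Driver")  // PostgreSQL JDBC driver class
    p
  }
}

// With Spark on the classpath:
// dataframe.write.mode(SaveMode.Append)
//   .jdbc(JdbcProps.jdbcUrl, "my_table", JdbcProps.connectionProperties)
```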

Also, DataFrame.write gives you a DataFrameWriter, which has some methods for inserting the DataFrame:

def insertInto(tableName: String): Unit

Inserts the content of the DataFrame to the specified table. It requires that the schema of the DataFrame is the same as the schema of the table.

Because it inserts data to an existing table, format or options will be ignored.

http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.DataFrameWriter

There is nothing yet in Spark for updating individual records out of the box, though.


KrisP has the right of it. The best way to do an upsert is not through a prepared statement. It's important to note that this method inserts one row at a time, with as many partitions as the number of workers you have. If you want to do this in batches, you can do that as well:

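import java.sql.{Connection, DriverManager, PreparedStatement}
import org.apache.spark.sql.Row

val numWorkers = 4   // "NUMBER OF WORKERS": partitions writing in parallel
val batchSize = 1000 // "# Of Rows you want per batch"

// foreachPartition runs on the executors: one connection per partition
dataframe.coalesce(numWorkers).foreachPartition { (rows: Iterator[Row]) =>
  val dbc: Connection = DriverManager.getConnection("JDBCURL")
  try {
    val st: PreparedStatement = dbc.prepareStatement("YOUR PREPARED STATEMENT")
    rows.grouped(batchSize).foreach { batch =>
      batch.foreach { x =>
        // JDBC parameters are 1-based, Row fields are 0-based
        st.setDouble(1, x.getDouble(0))
        st.addBatch()
      }
      st.executeBatch() // one round trip per batch
    }
    st.close()
  } finally {
    dbc.close() // release the connection even if a batch fails
  }
}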

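This executes the batches on each partition and closes the DB connection when the partition is done. It gives you control over how many partitions (and therefore concurrent connections) write to the database and how many rows go into each batch, so you can work within those limits.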
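To see what the batch size controls, here is the same grouping logic with the JDBC calls abstracted away. Batching and writeInBatches are hypothetical names: bind stands in for the setXxx/addBatch calls and flush for executeBatch, so each flush corresponds to one round trip to the database.

```scala
// Hypothetical, JDBC-free version of the per-partition loop above.
object Batching {
  /** Feeds every row to `bind` and calls `flush` once per group of
    * `batchSize` rows (the last group may be smaller).
    * Returns how many times `flush` was called. */
  def writeInBatches[T](rows: Iterator[T], batchSize: Int)(bind: T => Unit)(flush: () => Unit): Int = {
    require(batchSize > 0, "batchSize must be positive")
    var flushes = 0
    rows.grouped(batchSize).foreach { batch =>
      batch.foreach(bind) // like setXxx + addBatch for each row
      flush()             // like executeBatch
      flushes += 1
    }
    flushes
  }
}
```

With 2,500 rows and a batch size of 1,000, flush runs three times (1,000 + 1,000 + 500 rows); an empty partition never opens a batch at all.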