How to store custom objects in Dataset?

  1. Using generic encoders.

    There are two generic encoders available for now, kryo and javaSerialization, where the latter is explicitly described as:

    extremely inefficient and should only be used as the last resort.

    Assuming the following class:

    class Bar(i: Int) {
      override def toString = s"bar $i"
      def bar = i
    }
    

    you can use these encoders by adding an implicit encoder:

    object BarEncoders {
      implicit def barEncoder: org.apache.spark.sql.Encoder[Bar] = 
      org.apache.spark.sql.Encoders.kryo[Bar]
    }
    

    which can be used together as follows:

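    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    
    object Main {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local",  "test", new SparkConf())
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._
        import BarEncoders._  // brings the kryo-backed Encoder[Bar] into scope
    
        val ds = Seq(new Bar(1)).toDS
        ds.show
    
        sc.stop()
      }
    }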
    

    It stores objects as a binary column, so when converted to a DataFrame you get the following schema:

    root
     |-- value: binary (nullable = true)
    

    It is also possible to encode tuples using the kryo encoder for a specific field:

    import org.apache.spark.sql.Encoders
    
    val longBarEncoder = Encoders.tuple(Encoders.scalaLong, Encoders.kryo[Bar])
    
    spark.createDataset(Seq((1L, new Bar(1))))(longBarEncoder)
    // org.apache.spark.sql.Dataset[(Long, Bar)] = [_1: bigint, _2: binary]
    

    Note that we don't depend on implicit encoders here but pass the encoder explicitly, so this most likely won't work with the toDS method.

  2. Using implicit conversions:

    Provide implicit conversions between a representation which can be encoded and the custom class, for example:

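    import scala.language.implicitConversions
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    
    object BarConversions {
      implicit def toInt(bar: Bar): Int = bar.bar
      implicit def toBar(i: Int): Bar = new Bar(i)
    }
    
    object Main {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local",  "test", new SparkConf())
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._
        import BarConversions._
    
        type EncodedBar = Int
    
        // each Bar is converted to Int by toInt before it is stored
        val bars: RDD[EncodedBar]  = sc.parallelize(Seq(new Bar(1)))
        val barsDS = bars.toDS
    
        barsDS.show
        barsDS.map(_.bar).show  // toBar converts each Int back to a Bar
    
        sc.stop()
      }
    }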
    

Related questions:

  • How to create encoder for Option type constructor, e.g. Option[Int]?

Update

This answer is still valid and informative, although things are now better since 2.2/2.3, which adds built-in encoder support for Set, Seq, Map, Date, Timestamp, and BigDecimal. If you stick to making types with only case classes and the usual Scala types, you should be fine with just the implicit in SQLImplicits.


Unfortunately, virtually nothing has been added to help with this. Searching for @since 2.0.0 in Encoders.scala or SQLImplicits.scala finds things mostly to do with primitive types (and some tweaking of case classes). So, first thing to say: there is currently no really good support for custom class encoders. With that out of the way, what follows are some tricks which do as good a job as we can ever hope to, given what we currently have at our disposal. As an upfront disclaimer: this won't work perfectly and I'll do my best to make all limitations clear and upfront.

What exactly is the problem

When you want to make a dataset, Spark "requires an encoder (to convert a JVM object of type T to and from the internal Spark SQL representation) that is generally created automatically through implicits from a SparkSession, or can be created explicitly by calling static methods on Encoders" (taken from the docs on createDataset). An encoder will take the form Encoder[T] where T is the type you are encoding. The first suggestion is to add import spark.implicits._ (which gives you these implicit encoders) and the second suggestion is to explicitly pass in the implicit encoder using this set of encoder related functions.

There is no encoder available for regular classes, so

import spark.implicits._
class MyObj(val i: Int)
// ...
val d = spark.createDataset(Seq(new MyObj(1),new MyObj(2),new MyObj(3)))

will give you the following implicit-related compile-time error:

Unable to find encoder for type stored in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing sqlContext.implicits._ Support for serializing other types will be added in future releases

However, if you wrap whatever type you just used to get the above error in some class that extends Product, the error confusingly gets delayed to runtime, so

import spark.implicits._
case class Wrap[T](unwrap: T)
class MyObj(val i: Int)
// ...
val d = spark.createDataset(Seq(Wrap(new MyObj(1)),Wrap(new MyObj(2)),Wrap(new MyObj(3))))

Compiles just fine, but fails at runtime with

java.lang.UnsupportedOperationException: No Encoder found for MyObj

The reason for this is that the encoders Spark creates with the implicits are actually only made at runtime (via Scala reflection). In this case, all Spark checks at compile time is that the outermost class extends Product (which all case classes do), and it only realizes at runtime that it still doesn't know what to do with MyObj (the same problem occurs if I try to make a Dataset[(Int,MyObj)] - Spark waits until runtime to barf on MyObj). These are central problems that are in dire need of being fixed:

  • some classes that extend Product compile despite always crashing at runtime and
  • there is no way of passing in custom encoders for nested types (I have no way of feeding Spark an encoder for just MyObj such that it then knows how to encode Wrap[MyObj] or (Int,MyObj)).

Just use kryo

The solution everyone suggests is to use the kryo encoder.

import spark.implicits._
class MyObj(val i: Int)
implicit val myObjEncoder = org.apache.spark.sql.Encoders.kryo[MyObj]
// ...
val d = spark.createDataset(Seq(new MyObj(1),new MyObj(2),new MyObj(3)))

This gets pretty tedious fast though. Especially if your code is manipulating all sorts of datasets, joining, grouping etc. You end up racking up a bunch of extra implicits. So, why not just make an implicit that does this all automatically?

import scala.reflect.ClassTag
implicit def kryoEncoder[A](implicit ct: ClassTag[A]) = 
  org.apache.spark.sql.Encoders.kryo[A](ct)

And now, it seems like I can do almost anything I want (the example below won't work in the spark-shell where spark.implicits._ is automatically imported)

class MyObj(val i: Int)

val d1 = spark.createDataset(Seq(new MyObj(1),new MyObj(2),new MyObj(3)))
val d2 = d1.map(d => (d.i+1,d)).alias("d2") // mapping works fine and ..
val d3 = d1.map(d => (d.i,  d)).alias("d3") // .. deals with the new type
val d4 = d2.joinWith(d3, $"d2._1" === $"d3._1") // Boom!

Or almost. The problem is that using kryo leads to Spark just storing every row in the dataset as a flat binary object. For map, filter, foreach that is enough, but for operations like join, Spark really needs these to be separated into columns. Inspecting the schema for d2 or d3, you see there is just one binary column:

d2.printSchema
// root
//  |-- value: binary (nullable = true)

Partial solution for tuples

So, using the magic of implicits in Scala (more in 6.26.3 Overloading Resolution), I can make myself a series of implicits that will do as good a job as possible, at least for tuples, and will work well with existing implicits:

import org.apache.spark.sql.{Encoder,Encoders}
import scala.reflect.ClassTag
import spark.implicits._  // we can still take advantage of all the old implicits

implicit def single[A](implicit c: ClassTag[A]): Encoder[A] = Encoders.kryo[A](c)

implicit def tuple2[A1, A2](
  implicit e1: Encoder[A1],
           e2: Encoder[A2]
): Encoder[(A1,A2)] = Encoders.tuple[A1,A2](e1, e2)

implicit def tuple3[A1, A2, A3](
  implicit e1: Encoder[A1],
           e2: Encoder[A2],
           e3: Encoder[A3]
): Encoder[(A1,A2,A3)] = Encoders.tuple[A1,A2,A3](e1, e2, e3)

// ... you can keep making these

Then, armed with these implicits, I can make my example above work, albeit with some column renaming

class MyObj(val i: Int)

val d1 = spark.createDataset(Seq(new MyObj(1),new MyObj(2),new MyObj(3)))
val d2 = d1.map(d => (d.i+1,d)).toDF("_1","_2").as[(Int,MyObj)].alias("d2")
val d3 = d1.map(d => (d.i  ,d)).toDF("_1","_2").as[(Int,MyObj)].alias("d3")
val d4 = d2.joinWith(d3, $"d2._1" === $"d3._1")

I haven't yet figured out how to get the expected tuple names (_1, _2, ...) by default without renaming them - if someone else wants to play around with this, this is where the name "value" gets introduced and this is where the tuple names are usually added. However, the key point is that I now have a nice structured schema:

d4.printSchema
// root
//  |-- _1: struct (nullable = false)
//  |    |-- _1: integer (nullable = true)
//  |    |-- _2: binary (nullable = true)
//  |-- _2: struct (nullable = false)
//  |    |-- _1: integer (nullable = true)
//  |    |-- _2: binary (nullable = true)

So, in summary, this workaround:

  • allows us to get separate columns for tuples (so we can join on tuples again, yay!)
  • lets us again just rely on the implicits (so no need to be passing in kryo all over the place)
  • is almost entirely backwards compatible with import spark.implicits._ (with some renaming involved)
  • does not let us join on the kryo-serialized binary columns, let alone on fields those may have
  • has the unpleasant side-effect of renaming some of the tuple columns to "value" (if necessary, this can be undone by converting .toDF, specifying new column names, and converting back to a dataset - and the schema names seem to be preserved through joins, where they are most needed).
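
Why a catch-all encoder and a tuple encoder can coexist with more specific implicits is easier to see without Spark. The following is a minimal sketch of the resolution mechanism only: Enc, LowPriorityEnc, fallback, intEnc and the schema strings are hypothetical stand-ins, not Spark API, and the low-priority trait is an assumption added here to keep the toy unambiguous.

```scala
object EncoderResolutionSketch {
  // Toy stand-in for an encoder: it only records the column layout it would produce.
  final case class Enc[A](schema: String)

  trait LowPriorityEnc {
    // Catch-all fallback, playing the role of the kryo-backed `single` implicit.
    implicit def fallback[A]: Enc[A] = Enc[A]("binary")
  }

  object Enc extends LowPriorityEnc {
    // A specific instance, playing the role of an encoder from spark.implicits._
    implicit val intEnc: Enc[Int] = Enc[Int]("integer")

    // Tuple encoder assembled from whatever encoder is found for each component.
    implicit def tuple2[A, B](implicit a: Enc[A], b: Enc[B]): Enc[(A, B)] =
      Enc[(A, B)](s"struct<${a.schema},${b.schema}>")
  }

  class MyObj(val i: Int)

  // Only the fallback matches a bare MyObj.
  val plain: String = implicitly[Enc[MyObj]].schema

  // tuple2 is more specific than the fallback, so the tuple keeps its structure
  // and only the MyObj component falls back to "binary".
  val tupled: String = implicitly[Enc[(Int, MyObj)]].schema
}
```

Here plain is "binary" and tupled is "struct<integer,binary>", which mirrors the shape of the schema obtained with the tuple implicits: the Int stays a real column and only the custom class collapses into a blob.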

Partial solution for classes in general

This one is less pleasant and has no good solution. However, now that we have the tuple solution above, I have a hunch the implicit conversion solution from another answer will be a bit less painful too since you can convert your more complex classes to tuples. Then, after creating the dataset, you'd probably rename the columns using the dataframe approach. If all goes well, this is really an improvement since I can now perform joins on the fields of my classes. If I had just used one flat binary kryo serializer that wouldn't have been possible.

Here is an example that does a bit of everything: I have a class MyObj which has fields of types Int, java.util.UUID, and Set[String]. The first takes care of itself. The second, although I could serialize it using kryo, would be more useful if stored as a String (since UUIDs are usually something I'll want to join against). The third really just belongs in a binary column.

class MyObj(val i: Int, val u: java.util.UUID, val s: Set[String])

// alias for the type to convert to and from
type MyObjEncoded = (Int, String, Set[String])

// implicit conversions
implicit def toEncoded(o: MyObj): MyObjEncoded = (o.i, o.u.toString, o.s)
implicit def fromEncoded(e: MyObjEncoded): MyObj =
  new MyObj(e._1, java.util.UUID.fromString(e._2), e._3)

Now, I can create a dataset with a nice schema using this machinery:

val d = spark.createDataset(Seq[MyObjEncoded](
  new MyObj(1, java.util.UUID.randomUUID, Set("foo")),
  new MyObj(2, java.util.UUID.randomUUID, Set("bar"))
)).toDF("i","u","s").as[MyObjEncoded]

And the schema shows me columns with the right names, and the first two are both things I can join against.

d.printSchema
// root
//  |-- i: integer (nullable = false)
//  |-- u: string (nullable = true)
//  |-- s: binary (nullable = true)
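
Since everything now hinges on the two conversions being inverses of each other, it may be worth checking them on their own before involving Spark. This is a minimal sketch: the wrapper object ConversionRoundTrip and the fixed UUID literal are assumptions made for illustration, while the class and the conversions are the ones defined above.

```scala
import java.util.UUID
import scala.language.implicitConversions

object ConversionRoundTrip {
  class MyObj(val i: Int, val u: UUID, val s: Set[String])

  // alias for the type to convert to and from
  type MyObjEncoded = (Int, String, Set[String])

  implicit def toEncoded(o: MyObj): MyObjEncoded = (o.i, o.u.toString, o.s)
  implicit def fromEncoded(e: MyObjEncoded): MyObj =
    new MyObj(e._1, UUID.fromString(e._2), e._3)

  // Fixed UUID (an arbitrary example value) so the result is reproducible.
  val original = new MyObj(1, UUID.fromString("123e4567-e89b-12d3-a456-426614174000"), Set("foo"))

  val encoded: MyObjEncoded = original  // toEncoded is applied implicitly
  val decoded: MyObj = encoded          // fromEncoded is applied implicitly

  // MyObj is a plain class without equals, so compare field by field.
  val roundTripOk: Boolean =
    decoded.i == original.i && decoded.u == original.u && decoded.s == original.s
}
```

The same fromEncoded conversion is what turns collected rows of the dataset back into MyObj instances.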

You can use UDTRegistration and then Case Classes, Tuples, etc... all work correctly with your User Defined Type!

Say you want to use a custom Enum:

trait CustomEnum { def value:String }
case object Foo extends CustomEnum  { val value = "F" }
case object Bar extends CustomEnum  { val value = "B" }
object CustomEnum {
  def fromString(str:String) = Seq(Foo, Bar).find(_.value == str).get
}

Register it like this:

import org.apache.spark.sql.types.{DataType, UDTRegistration, UserDefinedType}

// First define a UDT class for it:
class CustomEnumUDT extends UserDefinedType[CustomEnum] {
  override def sqlType: DataType = org.apache.spark.sql.types.StringType
  override def serialize(obj: CustomEnum): Any = org.apache.spark.unsafe.types.UTF8String.fromString(obj.value)
  // Note that this will be a UTF8String type
  override def deserialize(datum: Any): CustomEnum = CustomEnum.fromString(datum.toString)
  override def userClass: Class[CustomEnum] = classOf[CustomEnum]
}

// Then Register the UDT Class!
// NOTE: you have to put this file into the org.apache.spark package!
UDTRegistration.register(classOf[CustomEnum].getName, classOf[CustomEnumUDT].getName)

Then USE IT!

case class UsingCustomEnum(id:Int, en:CustomEnum)

val seq = Seq(
  UsingCustomEnum(1, Foo),
  UsingCustomEnum(2, Bar),
  UsingCustomEnum(3, Foo)
).toDS()
seq.filter(_.en == Foo).show()
seq.collect().foreach(println)  // println on the array itself would only print its reference

Say you want to use a Polymorphic Record:

trait CustomPoly
case class FooPoly(id:Int) extends CustomPoly
case class BarPoly(value:String, secondValue:Long) extends CustomPoly

... and then use it like this:

case class UsingPoly(id:Int, poly:CustomPoly)

val polySeq = Seq(
  UsingPoly(1, new FooPoly(1)),
  UsingPoly(2, new BarPoly("Blah", 123)),
  UsingPoly(3, new FooPoly(1))
).toDS

polySeq.filter(_.poly match {
  case FooPoly(value) => value == 1
  case _ => false
}).show()

You can write a custom UDT that encodes everything to bytes (I'm using java serialization here but it's probably better to instrument Spark's Kryo context).

First define the UDT class:

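import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.sql.types.{DataType, UserDefinedType}

class CustomPolyUDT extends UserDefinedType[CustomPoly] {
  // Plain Java serialization is used below, so no Kryo instance is needed here.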
  override def sqlType: DataType = org.apache.spark.sql.types.BinaryType
  override def serialize(obj: CustomPoly): Any = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)

    bos.toByteArray
  }
  override def deserialize(datum: Any): CustomPoly = {
    val bis = new ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
    val ois = new ObjectInputStream(bis)
    val obj = ois.readObject()
    obj.asInstanceOf[CustomPoly]
  }

  override def userClass: Class[CustomPoly] = classOf[CustomPoly]
}

Then register it:

// NOTE: The file you do this in has to be inside of the org.apache.spark package!
UDTRegistration.register(classOf[CustomPoly].getName, classOf[CustomPolyUDT].getName)

Then you can use it!

// As shown above:
case class UsingPoly(id:Int, poly:CustomPoly)

val polySeq = Seq(
  UsingPoly(1, new FooPoly(1)),
  UsingPoly(2, new BarPoly("Blah", 123)),
  UsingPoly(3, new FooPoly(1))
).toDS

polySeq.filter(_.poly match {
  case FooPoly(value) => value == 1
  case _ => false
}).show()
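
The serialize and deserialize methods of CustomPolyUDT are ordinary Java serialization, so the byte round trip can be checked on its own, without Spark. A minimal sketch, assuming the same CustomPoly hierarchy as above; the wrapper object PolyRoundTrip and the helper names toBytes and fromBytes are hypothetical, introduced only for this illustration.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object PolyRoundTrip {
  trait CustomPoly
  case class FooPoly(id: Int) extends CustomPoly
  case class BarPoly(value: String, secondValue: Long) extends CustomPoly

  // Same steps as CustomPolyUDT.serialize: write the object into an in-memory buffer.
  def toBytes(obj: CustomPoly): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    try oos.writeObject(obj) finally oos.close()
    bos.toByteArray
  }

  // Same steps as CustomPolyUDT.deserialize: read the object back and cast it.
  def fromBytes(bytes: Array[Byte]): CustomPoly = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try ois.readObject().asInstanceOf[CustomPoly] finally ois.close()
  }

  val restored: CustomPoly = fromBytes(toBytes(BarPoly("Blah", 123L)))
}
```

This works because case classes are Serializable by default; a subclass of CustomPoly that is a plain class would need to extend Serializable itself, otherwise writeObject throws a NotSerializableException.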