Inferring Spark DataType from string literals
Spark/Scala probably already have a helper/util method that will do this for me.
You're right. Spark already has its own schema and data type inference code, which it uses to infer the schema of underlying data sources (CSV, JSON, etc.), so you can look at that to implement your own. (The actual implementation is private to Spark and tied to RDDs and internal classes, so it cannot be used directly from code outside Spark, but it should give you a good idea of how to go about it.)
Given that CSV is a flat format (whereas JSON can have nested structure), CSV schema inference is relatively more straightforward and should help with the task you're trying to achieve. So I will explain how CSV inference works. (JSON inference additionally has to take the possibly nested structure into account, but its data type inference is quite analogous.)
With that prologue, the thing you want to look at is the CSVInferSchema object, in particular its infer method, which takes an RDD[Array[String]] and infers the data type of each element of the array across the whole RDD. The way it does this is: it marks each field as NullType to begin with, and then, as it iterates over each subsequent row of values (Array[String]) in the RDD, it updates the already inferred DataType to a new DataType whenever the new row's value requires a wider (more general) type. This happens here:
```scala
val rootTypes: Array[DataType] =
  tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
```
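To make the aggregate step concrete, here is a small Spark-free sketch in plain Scala. It is a simplified model, not Spark's implementation: the `AggregateSketch` object, its `InferredType` cases and the `narrowestType` and `inferSchema` helpers are hypothetical, only null/int/long/double/string are modelled, and each inner `Seq` stands in for one RDD partition. Each partition is folded row by row (the job `inferRowType` does), and the partial results are then combined by widening column by column (the job `mergeRowTypes` does).

```scala
import scala.util.Try

// Simplified, hypothetical model of the aggregate step in CSVInferSchema (not Spark's code).
object AggregateSketch {
  sealed trait InferredType
  case object NullT   extends InferredType
  case object IntT    extends InferredType
  case object LongT   extends InferredType
  case object DoubleT extends InferredType
  case object StringT extends InferredType

  // Widening order: each type can represent every value of the types before it.
  private val order: Vector[InferredType] = Vector(NullT, IntT, LongT, DoubleT, StringT)

  // The wider of two types.
  def mergeTypes(a: InferredType, b: InferredType): InferredType =
    if (order.indexOf(a) >= order.indexOf(b)) a else b

  // Narrowest type that can hold a single non-empty value.
  private def narrowestType(field: String): InferredType =
    if (Try(field.toInt).isSuccess) IntT
    else if (Try(field.toLong).isSuccess) LongT
    else if (Try(field.toDouble).isSuccess) DoubleT
    else StringT

  // Role of inferField: empty values keep the type so far, others may widen it.
  def inferField(typeSoFar: InferredType, field: String): InferredType =
    if (field == null || field.isEmpty) typeSoFar
    else mergeTypes(typeSoFar, narrowestType(field))

  // Role of inferRowType: update each column's type with one row.
  def inferRowType(acc: Array[InferredType], row: Array[String]): Array[InferredType] =
    acc.zip(row).map { case (t, f) => inferField(t, f) }

  // Role of mergeRowTypes: combine two partial results column by column.
  def mergeRowTypes(a: Array[InferredType], b: Array[InferredType]): Array[InferredType] =
    a.zip(b).map { case (x, y) => mergeTypes(x, y) }

  // Local analogue of tokenRdd.aggregate(startType)(inferRowType, mergeRowTypes),
  // where each inner Seq stands in for one RDD partition.
  def inferSchema(partitions: Seq[Seq[Array[String]]], numCols: Int): Array[InferredType] = {
    val startType: Array[InferredType] = Array.fill[InferredType](numCols)(NullT)
    partitions
      .map(_.foldLeft(startType)(inferRowType))
      .foldLeft(startType)(mergeRowTypes)
  }
}
```

For example, with one "partition" holding the rows `("1", "a")` and `("2", "")` and another holding `("3000000000", "b")` and `("", "c")`, the first column ends up as `LongT` (3000000000 does not fit in an Int) and the second as `StringT`, while the empty strings leave the types untouched. Because types only ever widen, the per-partition results can be merged in any order.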
Now, inferRowType calls inferField for each field in the row. The inferField implementation is what you're probably looking for: it takes the type inferred so far for a particular field and the string value of that field in the current row as parameters, and returns either the existing inferred type or, if the current value requires a wider type, that new type.
The relevant section of the code is as follows:
```scala
typeSoFar match {
  case NullType => tryParseInteger(field, options)
  case IntegerType => tryParseInteger(field, options)
  case LongType => tryParseLong(field, options)
  case _: DecimalType => tryParseDecimal(field, options)
  case DoubleType => tryParseDouble(field, options)
  case TimestampType => tryParseTimestamp(field, options)
  case BooleanType => tryParseBoolean(field, options)
  case StringType => StringType
  case other: DataType =>
    throw new UnsupportedOperationException(s"Unexpected data type $other")
}
```
Note that if typeSoFar is NullType, it first tries to parse the value as an Integer, but the tryParseInteger call is the start of a chain of calls to parsers for progressively wider types. If the value cannot be parsed as an Integer, it invokes tryParseLong, which on failure invokes tryParseDecimal, which on failure invokes tryParseDouble, which on failure invokes tryParseTimestamp, which on failure invokes tryParseBoolean, which on failure finally returns StringType.
So you can use much the same logic to implement your use case. (If you do not need to merge types across rows, you can simply implement all the tryParse* methods verbatim and invoke tryParseInteger. There's no need to write your own regex.)
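Here is a rough sketch of that single-value approach in plain Scala, without any Spark dependency. The `InferFromString` object and its `SimpleType` cases are hypothetical stand-ins for Spark's `DataType`s, and the `tryParse*` bodies are simplified assumptions rather than copies of Spark's: in particular, the decimal rule (only whole numbers too large for `Long`) and the timestamp rule (ISO-8601 local date-times parsed with `java.time.LocalDateTime.parse`) may differ from what Spark actually does. What it does show is the fall-through chain described above, from `tryParseInteger` down to the string type.

```scala
import java.time.LocalDateTime
import scala.util.Try

object InferFromString {
  // Hypothetical stand-ins for Spark's IntegerType, LongType, DecimalType, ... (not Spark classes).
  sealed trait SimpleType
  case object IntegerT   extends SimpleType
  case object LongT      extends SimpleType
  case object DecimalT   extends SimpleType
  case object DoubleT    extends SimpleType
  case object TimestampT extends SimpleType
  case object BooleanT   extends SimpleType
  case object StringT    extends SimpleType

  // Each step returns its own type on success, otherwise delegates to the next, wider parser.
  def tryParseInteger(field: String): SimpleType =
    if (Try(field.toInt).isSuccess) IntegerT else tryParseLong(field)

  def tryParseLong(field: String): SimpleType =
    if (Try(field.toLong).isSuccess) LongT else tryParseDecimal(field)

  // Assumption: only whole numbers (scale <= 0) too big for Long become decimals;
  // values with a fractional part fall through to Double.
  def tryParseDecimal(field: String): SimpleType =
    if (Try(new java.math.BigDecimal(field)).toOption.exists(_.scale <= 0)) DecimalT
    else tryParseDouble(field)

  def tryParseDouble(field: String): SimpleType =
    if (Try(field.toDouble).isSuccess) DoubleT else tryParseTimestamp(field)

  // Assumption: timestamps are ISO-8601 local date-times such as 2017-03-01T10:15:30.
  def tryParseTimestamp(field: String): SimpleType =
    if (Try(LocalDateTime.parse(field)).isSuccess) TimestampT else tryParseBoolean(field)

  // Last resort before giving up and calling the value a string.
  def tryParseBoolean(field: String): SimpleType =
    if (Try(field.toBoolean).isSuccess) BooleanT else StringT

  // Single-value inference: no merging across rows, just start at the narrowest parser.
  def inferType(field: String): SimpleType = tryParseInteger(field)
}
```

So `inferType("42")` gives `IntegerT`, `inferType("3.14")` gives `DoubleT`, and `inferType("hello")` falls all the way through to `StringT`.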
Hope this helps.
Yes, of course Spark has the magic you need.
In Spark 2.x it's the CatalystSqlParser object, defined in the org.apache.spark.sql.catalyst.parser package.
For example:
```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

CatalystSqlParser.parseDataType("string") // StringType
CatalystSqlParser.parseDataType("int")    // IntegerType
```
And so on.
But as I understand it, it's not part of the public API, so it may change in future versions without warning.
So you may just implement your method as:
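```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.DataType

// Parses a SQL type name such as "int" or "string" into a Spark DataType.
def toSparkType(inputType: String): DataType = CatalystSqlParser.parseDataType(inputType)
```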