Spark SQL replacement for MySQL's GROUP_CONCAT aggregate function

Before you proceed: This operation is yet another groupByKey. While it has multiple legitimate applications, it is relatively expensive, so be sure to use it only when required.


This is not exactly a concise or efficient solution, but you can use UserDefinedAggregateFunction, introduced in Spark 1.5.0:

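import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.mutable.ArrayBuffer

object GroupConcat extends UserDefinedAggregateFunction {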
    def inputSchema = new StructType().add("x", StringType)
    def bufferSchema = new StructType().add("buff", ArrayType(StringType))
    def dataType = StringType
    def deterministic = true 

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, ArrayBuffer.empty[String])
    }

    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) 
        buffer.update(0, buffer.getSeq[String](0) :+ input.getString(0))
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
    }

    def evaluate(buffer: Row) = UTF8String.fromString(
      buffer.getSeq[String](0).mkString(","))
}

Example usage:

val df = sc.parallelize(Seq(
  ("username1", "friend1"),
  ("username1", "friend2"),
  ("username2", "friend1"),
  ("username2", "friend3")
)).toDF("username", "friend")

df.groupBy($"username").agg(GroupConcat($"friend").alias("friends")).show

## +---------+---------------+
## | username|        friends|
## +---------+---------------+
## |username1|friend1,friend2|
## |username2|friend1,friend3|
## +---------+---------------+
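
To make the roles of the four callbacks (initialize, update, merge, evaluate) easier to follow, here is a minimal plain-Python model of the aggregation lifecycle. It is not Spark API: the function names simply mirror the UDAF methods above, and the split of rows into two partitions is a made-up assumption for illustration.

```python
from functools import reduce

# Plain-Python stand-ins for the four UDAF callbacks.
# This models the lifecycle only; none of it is Spark API.
def initialize():
    return []  # empty buffer, like ArrayBuffer.empty[String]

def update(buffer, value):
    # Skip nulls (None), mirroring the isNullAt(0) check.
    return buffer if value is None else buffer + [value]

def merge(buffer1, buffer2):
    # Combine partial buffers built on two different partitions.
    return buffer1 + buffer2

def evaluate(buffer, sep=","):
    return sep.join(buffer)

def aggregate_partition(values):
    # Fold every row of one partition into a fresh buffer.
    return reduce(update, values, initialize())

# Hypothetical split of one group's rows across two partitions.
partition_a = ["friend1", None]
partition_b = ["friend2", "friend3"]

partials = [aggregate_partition(p) for p in (partition_a, partition_b)]
result = evaluate(reduce(merge, partials, initialize()))
print(result)  # friend1,friend2,friend3
```

Note that in Spark the order in which partial buffers are merged is not guaranteed, so the order of elements in the concatenated string should not be relied upon.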

You can also create a Python wrapper as shown in "Spark: How to map Python with Scala or Java User Defined Functions?"

In practice, it can be faster to extract the RDD, apply groupByKey and mkString, and rebuild the DataFrame.

You can get a similar effect by combining the collect_list function (Spark >= 1.6.0) with concat_ws:

import org.apache.spark.sql.functions.{collect_list, concat_ws}

df.groupBy($"username")
  .agg(concat_ws(",", collect_list($"friend")).alias("friends"))

You can try the collect_list function:

sqlContext.sql("select A, collect_list(B), collect_list(C) from Table1 group by A")

Or you can register a UDF, something like:

sqlContext.udf.register("myzip",(a:Long,b:Long)=>(a+","+b))

and you can use this function in the query:

sqlContext.sql("select A,collect_list(myzip(B,C)) from tbl group by A")

Here is a function you can use in PySpark:

import pyspark.sql.functions as F
from pyspark.sql.types import StringType

def group_concat(col, distinct=False, sep=','):
    if distinct:
        collect = F.collect_set(col.cast(StringType()))
    else:
        collect = F.collect_list(col.cast(StringType()))
    return F.concat_ws(sep, collect)


table.groupby('username').agg(group_concat(F.col('friends')).alias('friends'))
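
To illustrate what the distinct flag and null handling do to a single group, here is a rough plain-Python model of the function above. It does not call Spark: group_concat_model is a hypothetical name, str() is only a stand-in for the cast to StringType, and the model assumes (as I understand collect_list, collect_set and concat_ws to behave) that nulls are dropped and that an all-null group therefore produces an empty string.

```python
def group_concat_model(values, distinct=False, sep=","):
    """Plain-Python model of one group's result; not Spark API."""
    # Drop nulls (None), as collect_list / collect_set do, then stringify.
    items = [str(v) for v in values if v is not None]
    if distinct:
        # dict.fromkeys keeps first-seen order here; Spark's collect_set
        # makes no ordering guarantee at all.
        items = list(dict.fromkeys(items))
    return sep.join(items)

print(group_concat_model(["a", None, "b", "a"]))                 # a,b,a
print(group_concat_model(["a", None, "b", "a"], distinct=True))  # a,b
print(repr(group_concat_model([None, None])))                    # ''
```

The last case is worth keeping in mind when porting queries: as far as I know, MySQL's GROUP_CONCAT returns NULL when a group has no non-NULL values, whereas this approach yields an empty string.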

In SQL:

select username, concat_ws(',', collect_list(friends)) as friends
from table
group by username