Spark: How to translate count(distinct(value)) into the DataFrame API

Fabio Fantoni · May 13, 2015

I'm trying to compare different ways to aggregate my data.

This is my input data; each record has two fields (page, visitor):

(PAG1,V1)
(PAG1,V1)
(PAG2,V1)
(PAG2,V2)
(PAG2,V1)
(PAG1,V1)
(PAG1,V2)
(PAG1,V1)
(PAG1,V2)
(PAG1,V1)
(PAG2,V2)
(PAG1,V3)

Running a SQL query through Spark SQL with this code:

import sqlContext.implicits._
case class Log(page: String, visitor: String)
// data is assumed to be an RDD[(String, String)] of (page, visitor) pairs
val logs = data.map(p => Log(p._1,p._2)).toDF()
logs.registerTempTable("logs")
val sqlResult= sqlContext.sql(
                              """select page
                                       ,count(distinct visitor) as visitor
                                   from logs
                               group by page
                              """)
val result = sqlResult.map(x=>(x(0).toString,x(1).toString))
result.foreach(println)

I get this output:

(PAG1,3) // PAG1 has been visited by 3 different visitors
(PAG2,2) // PAG2 has been visited by 2 different visitors
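The semantics of that query can be checked without Spark. The sketch below is a minimal model using plain Scala collections (not Spark): it groups the twelve (page, visitor) pairs from the question by page and counts the distinct visitors in each group. The object and method names are hypothetical, chosen only for illustration.

```scala
// Plain Scala collections model of the SQL query; no Spark is involved.
// The 12 records are the (page, visitor) pairs listed in the question.
object DistinctVisitors {
  val data: Seq[(String, String)] = Seq(
    ("PAG1", "V1"), ("PAG1", "V1"), ("PAG2", "V1"), ("PAG2", "V2"),
    ("PAG2", "V1"), ("PAG1", "V1"), ("PAG1", "V2"), ("PAG1", "V1"),
    ("PAG1", "V2"), ("PAG1", "V1"), ("PAG2", "V2"), ("PAG1", "V3")
  )

  // GROUP BY page, then COUNT(DISTINCT visitor) inside each group.
  def countDistinctVisitors(rows: Seq[(String, String)]): Map[String, Int] =
    rows.groupBy(_._1).map { case (page, group) =>
      page -> group.map(_._2).distinct.size
    }

  // Print the per-page results sorted by page name.
  def main(args: Array[String]): Unit =
    countDistinctVisitors(data).toSeq.sortBy(_._1).foreach(println)
}
```

Running `main` prints (PAG1,3) and (PAG2,2), the same figures as the Spark SQL output.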

Now I would like to get the same result using DataFrames and their API, but I can't get the same output:

import sqlContext.implicits._
case class Log(page: String, visitor: String)
val logs = data.map(p => Log(p._1,p._2)).toDF()
val result = logs.select("page","visitor").groupBy("page").count().distinct
result.foreach(println)

Instead, this is the output I get:

[PAG1,8]  // just the simple page count for every page
[PAG2,4]
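The difference comes from the order of operations: groupBy("page").count() counts every row per page, duplicates included, and the .distinct applied afterwards only removes duplicate result rows, of which there are none because each page already appears once. A minimal plain-Scala model of that pipeline (hypothetical names, no Spark involved) reproduces the 8 and 4:

```scala
// Plain Scala model of logs.groupBy("page").count().distinct; no Spark is involved.
object RowCountThenDistinct {
  val data: Seq[(String, String)] = Seq(
    ("PAG1", "V1"), ("PAG1", "V1"), ("PAG2", "V1"), ("PAG2", "V2"),
    ("PAG2", "V1"), ("PAG1", "V1"), ("PAG1", "V2"), ("PAG1", "V1"),
    ("PAG1", "V2"), ("PAG1", "V1"), ("PAG2", "V2"), ("PAG1", "V3")
  )

  // Models groupBy("page").count(): number of rows per page, duplicates included.
  def rowCountPerPage(rows: Seq[(String, String)]): Seq[(String, Long)] =
    rows.groupBy(_._1).toSeq.map { case (page, group) => (page, group.size.toLong) }

  // Models the trailing .distinct: it only drops identical (page, count) rows,
  // and since each page occurs once after grouping, nothing is removed.
  def countThenDistinct(rows: Seq[(String, String)]): Seq[(String, Long)] =
    rowCountPerPage(rows).distinct

  def main(args: Array[String]): Unit =
    countThenDistinct(data).sortBy(_._1).foreach(println)
}
```

Running `main` prints (PAG1,8) and (PAG2,4), matching the counts from the DataFrame attempt.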

It's probably something dumb, but I can't see it right now.

Thanks in advance!

FF

Answer

yjshen · May 13, 2015

What you need is the DataFrame aggregation function countDistinct:

import sqlContext.implicits._
import org.apache.spark.sql.functions._

case class Log(page: String, visitor: String)

val logs = data.map(p => Log(p._1,p._2))
            .toDF()

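// countDistinct counts the distinct non-null visitors in each page group.
// Since Spark 1.4, groupBy().agg() keeps the grouping column by default, so it
// does not need to be listed again inside agg() (in Spark 1.3 it did).
val result = logs.select("page","visitor")
            .groupBy("page")
            .agg(countDistinct("visitor").as("visitor"))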

result.foreach(println)
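Another DataFrame formulation that should give the same result on this data is to remove duplicate (page, visitor) rows first and count afterwards: logs.select("page", "visitor").distinct().groupBy("page").count(). One caveat: countDistinct ignores null visitors, whereas this version counts a (page, null) row like any other row. The sketch below models the distinct-then-count order with plain Scala collections rather than Spark (hypothetical names) on the question's data:

```scala
// Plain Scala model of logs.select("page", "visitor").distinct().groupBy("page").count().
// No Spark is involved; the data are the (page, visitor) pairs from the question.
object DistinctThenCount {
  val data: Seq[(String, String)] = Seq(
    ("PAG1", "V1"), ("PAG1", "V1"), ("PAG2", "V1"), ("PAG2", "V2"),
    ("PAG2", "V1"), ("PAG1", "V1"), ("PAG1", "V2"), ("PAG1", "V1"),
    ("PAG1", "V2"), ("PAG1", "V1"), ("PAG2", "V2"), ("PAG1", "V3")
  )

  // Drop duplicate (page, visitor) pairs first, then count the remaining rows per page.
  def distinctThenCount(rows: Seq[(String, String)]): Map[String, Long] =
    rows.distinct.groupBy(_._1).map { case (page, group) => page -> group.size.toLong }

  def main(args: Array[String]): Unit =
    distinctThenCount(data).toSeq.sortBy(_._1).foreach(println)
}
```

Running `main` prints (PAG1,3) and (PAG2,2), the same counts that countDistinct produces.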