How do I get a SQL row_number equivalent for a Spark RDD?

Glenn Strycker · Nov 20, 2014

I need to generate a full list of row_numbers for a data table with many columns.

In SQL, the query would look like this:

select
   key_value,
   col1,
   col2,
   col3,
   row_number() over (partition by key_value order by col1, col2 desc, col3)
from
   temp
;

Now, let's say in Spark I have an RDD of the form (K, V), where V=(col1, col2, col3), so my entries look like this:

(key1, (1,2,3))
(key1, (1,4,7))
(key1, (2,2,3))
(key2, (5,5,5))
(key2, (5,5,9))
(key2, (7,5,5))
etc.

I want to order these using commands like sortBy(), sortWith(), sortByKey(), zipWithIndex, etc., and get a new RDD with the correct row_number:

(key1, (1,2,3), 2)
(key1, (1,4,7), 1)
(key1, (2,2,3), 3)
(key2, (5,5,5), 1)
(key2, (5,5,9), 2)
(key2, (7,5,5), 3)
etc.

(I don't care about the parentheses, so the form can also be (K, (col1,col2,col3,rownum)) instead)

How do I do this?

Here's my first attempt:

val sample_data = Seq(((3,4),5,5,5),((3,4),5,5,9),((3,4),7,5,5),((1,2),1,2,3),((1,2),1,4,7),((1,2),2,2,3))

val temp1 = sc.parallelize(sample_data)

temp1.collect().foreach(println)

// ((3,4),5,5,5)
// ((3,4),5,5,9)
// ((3,4),7,5,5)
// ((1,2),1,2,3)
// ((1,2),1,4,7)
// ((1,2),2,2,3)

temp1.map(x => (x, 1)).sortByKey().zipWithIndex.collect().foreach(println)

// ((((1,2),1,2,3),1),0)
// ((((1,2),1,4,7),1),1)
// ((((1,2),2,2,3),1),2)
// ((((3,4),5,5,5),1),3)
// ((((3,4),5,5,9),1),4)
// ((((3,4),7,5,5),1),5)

// note that this isn't ordering with a partition on key value K!

val temp2 = temp1.???

Also note that I could not apply sortBy directly to the RDD; I had to run collect() first, and then the output is an array rather than an RDD:

temp1.collect().sortBy(a => a._2 -> -a._3 -> a._4).foreach(println)

// ((1,2),1,4,7)
// ((1,2),1,2,3)
// ((1,2),2,2,3)
// ((3,4),5,5,5)
// ((3,4),5,5,9)
// ((3,4),7,5,5)

Here's a little more progress, but still not partitioned:

val temp2 = sc.parallelize(temp1.map(a => (a._1,(a._2, a._3, a._4))).collect().sortBy(a => a._2._1 -> -a._2._2 -> a._2._3)).zipWithIndex.map(a => (a._1._1, a._1._2._1, a._1._2._2, a._1._2._3, a._2 + 1))

temp2.collect().foreach(println)

// ((1,2),1,4,7,1)
// ((1,2),1,2,3,2)
// ((1,2),2,2,3,3)
// ((3,4),5,5,5,4)
// ((3,4),5,5,9,5)
// ((3,4),7,5,5,6)
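
One way to turn the global index above into a per-key row number is to subtract each key's smallest index. The following is only a plain-Python sketch of that idea (no Spark involved), using the same sample data; the names ordered, indexed, first_index and numbered are made up for illustration. On an RDD the steps would presumably map to a sort, zipWithIndex, a per-key minimum, and a join, but that mapping is an assumption and is not tested here.

```python
# Same sample rows as above: (key, col1, col2, col3).
sample = [
    ((3, 4), 5, 5, 5), ((3, 4), 5, 5, 9), ((3, 4), 7, 5, 5),
    ((1, 2), 1, 2, 3), ((1, 2), 1, 4, 7), ((1, 2), 2, 2, 3),
]

# Step 1: global sort by key first, then col1 asc, col2 desc, col3 asc.
ordered = sorted(sample, key=lambda r: (r[0], r[1], -r[2], r[3]))

# Step 2: pair every row with a global 0-based index (what zipWithIndex does).
indexed = list(enumerate(ordered))

# Step 3: find the smallest global index seen for each key.
first_index = {}
for i, row in indexed:
    first_index[row[0]] = min(i, first_index.get(row[0], i))

# Step 4: subtract the key's offset so the numbering restarts at 1 per key.
numbered = [row + (i - first_index[row[0]] + 1,) for i, row in indexed]

for row in numbered:
    print(row)
```

Because the rows are sorted by key first, all rows of one key are contiguous in the global index, which is what makes the subtraction valid.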

Answer

dnlbrky · Jun 26, 2015

The row_number() over (partition by ... order by ...) functionality was added in Spark 1.4. This answer uses PySpark and DataFrames.

Create a test DataFrame:

from pyspark.sql import Row, functions as F

testDF = sc.parallelize(
    (Row(k="key1", v=(1,2,3)),
     Row(k="key1", v=(1,4,7)),
     Row(k="key1", v=(2,2,3)),
     Row(k="key2", v=(5,5,5)),
     Row(k="key2", v=(5,5,9)),
     Row(k="key2", v=(7,5,5))
    )
).toDF()

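Add the partitioned row number. Note that this example orders each window by k, which is also the partition key, so the order of rows within a key is not guaranteed; to reproduce the SQL in the question, the orderBy would need to reference the value columns instead: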

from pyspark.sql.window import Window

(testDF
 .select("k", "v",
         F.row_number()  # named rowNumber() in Spark 1.4 and 1.5
         .over(Window
               .partitionBy("k")
               .orderBy("k")
              )
         .alias("rowNum")
        )
 .show()
)

+----+-------+------+
|   k|      v|rowNum|
+----+-------+------+
|key1|[1,2,3]|     1|
|key1|[1,4,7]|     2|
|key1|[2,2,3]|     3|
|key2|[5,5,5]|     1|
|key2|[5,5,9]|     2|
|key2|[7,5,5]|     3|
+----+-------+------+
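
For the (K, V) RDD shape in the question, the per-key numbering itself can be sketched as an ordinary Python function. This is a minimal sketch, not a tested Spark job: number_group is a hypothetical helper name, and the grouping is simulated with a local dict so the snippet runs without Spark. On a real RDD, something like rdd.groupByKey().flatMap(lambda kv: number_group(kv[0], kv[1])) should apply the same function per key, assuming each key's values fit in memory on one executor.

```python
def number_group(key, values):
    """Sort one key's values by col1 asc, col2 desc, col3 asc and number them from 1.

    Hypothetical helper for illustration; not part of any Spark API.
    """
    ordered = sorted(values, key=lambda v: (v[0], -v[1], v[2]))
    return [(key, v, n) for n, v in enumerate(ordered, start=1)]


# Sample (K, V) pairs from the question.
pairs = [
    ("key1", (1, 2, 3)), ("key1", (1, 4, 7)), ("key1", (2, 2, 3)),
    ("key2", (5, 5, 5)), ("key2", (5, 5, 9)), ("key2", (7, 5, 5)),
]

# Local stand-in for groupByKey: collect the values of each key into a list.
groups = {}
for k, v in pairs:
    groups.setdefault(k, []).append(v)

# Apply the per-key numbering and flatten the result (stand-in for flatMap).
result = [row for k in sorted(groups) for row in number_group(k, groups[k])]

for row in result:
    print(row)
```

The row numbers match the ones asked for in the question: (1,4,7) gets 1 and (1,2,3) gets 2 within key1, because col2 is sorted in descending order.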