I'm trying to use Spark dataframes instead of RDDs since they are higher-level than RDDs and tend to produce more readable code.
In a 14-node Google Dataproc cluster, I have about 6 million names that are translated to ids by two different systems: sa and sb. Each Row contains name, id_sa and id_sb. My goal is to produce a mapping from id_sa to id_sb such that for each id_sa, the corresponding id_sb is the most frequent id among all names attached to id_sa.
Let's try to clarify with an example. If I have the following rows:
[Row(name='n1', id_sa='a1', id_sb='b1'),
Row(name='n2', id_sa='a1', id_sb='b2'),
Row(name='n3', id_sa='a1', id_sb='b2'),
Row(name='n4', id_sa='a2', id_sb='b2')]
My goal is to produce a mapping from a1 to b2. Indeed, the names associated with a1 are n1, n2 and n3, which map respectively to b1, b2 and b2, so b2 is the most frequent mapping among the names associated with a1. In the same way, a2 will be mapped to b2. It's OK to assume that there will always be a winner: no need to break ties.
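To pin down the expected result on this example, here is a minimal plain-Python sketch (no Spark involved). The rows are modeled as plain tuples and the helper name most_frequent_mapping is made up for illustration:

```python
from collections import Counter, defaultdict

# The example rows as (name, id_sa, id_sb) tuples.
rows = [
    ("n1", "a1", "b1"),
    ("n2", "a1", "b2"),
    ("n3", "a1", "b2"),
    ("n4", "a2", "b2"),
]


def most_frequent_mapping(rows):
    """Map each id_sa to the id_sb seen most often among its names."""
    counts = defaultdict(Counter)
    for _name, id_sa, id_sb in rows:
        counts[id_sa][id_sb] += 1
    # most_common(1) returns [(id_sb, count)] for the top id_sb.
    return {id_sa: c.most_common(1)[0][0] for id_sa, c in counts.items()}


mapping = most_frequent_mapping(rows)
print(mapping)  # {'a1': 'b2', 'a2': 'b2'}
```

This is only a reference for what the dataframe aggregation should return; it would not scale to the full dataset on a single machine.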
I was hoping that I could use groupBy(df.id_sa) on my dataframe, but I don't know what to do next. I was hoping for an aggregation that could produce, in the end, the following rows:
[Row(id_sa='a1', max_id_sb='b2'),
Row(id_sa='a2', max_id_sb='b2')]
But maybe I'm trying to use the wrong tool and I should just go back to using RDDs.
Using join (this returns more than one row per group in case of ties):
import pyspark.sql.functions as F
from pyspark.sql.functions import count, col
cnts = df.groupBy("id_sa", "id_sb").agg(count("*").alias("cnt")).alias("cnts")
maxs = cnts.groupBy("id_sa").agg(F.max("cnt").alias("mx")).alias("maxs")
cnts.join(
    maxs,
    (col("cnt") == col("mx")) & (col("cnts.id_sa") == col("maxs.id_sa"))
).select(col("cnts.id_sa"), col("cnts.id_sb"))
Using window functions (will drop ties):
from pyspark.sql.functions import row_number
from pyspark.sql.window import Window
w = Window.partitionBy("id_sa").orderBy(col("cnt").desc())
(cnts
    .withColumn("rn", row_number().over(w))
    .where(col("rn") == 1)
    .select("id_sa", "id_sb"))
Using struct ordering:
from pyspark.sql.functions import struct
(cnts
    .groupBy("id_sa")
    .agg(F.max(struct(col("cnt"), col("id_sb"))).alias("max"))
    .select(col("id_sa"), col("max.id_sb")))
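The struct version works because Spark compares structs field by field, left to right, so max picks the highest cnt and, on equal counts, the largest id_sb. Python tuples compare the same way, so the idea can be sketched without Spark. This is plain Python, not the Spark API: the counts list stands in for cnts, and the a3 rows are a made-up tie that the question's data does not contain:

```python
# (id_sa, id_sb, cnt) triples standing in for the rows of cnts.
counts = [
    ("a1", "b1", 1),
    ("a1", "b2", 2),
    ("a2", "b2", 1),
    ("a3", "b5", 3),  # hypothetical tie for a3, for illustration only
    ("a3", "b7", 3),
]

best = {}
for id_sa, id_sb, cnt in counts:
    # Same field order as struct(cnt, id_sb): count first, then id_sb.
    candidate = (cnt, id_sb)
    if id_sa not in best or candidate > best[id_sa]:
        best[id_sa] = candidate

mapping = {id_sa: id_sb for id_sa, (_cnt, id_sb) in best.items()}
print(mapping)  # {'a1': 'b2', 'a2': 'b2', 'a3': 'b7'}
```

Putting cnt first in the struct is what makes the count the primary sort key; swapping the two fields would instead return the lexicographically largest id_sb regardless of frequency.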