Removing duplicates from rows based on specific columns in an RDD/Spark DataFrame

Jason · May 15, 2015 · Viewed 130.8k times

Let's say I have a rather large dataset in the following form:

data = sc.parallelize([('Foo',41,'US',3),
                       ('Foo',39,'UK',1),
                       ('Bar',57,'CA',2),
                       ('Bar',72,'CA',2),
                       ('Baz',22,'US',6),
                       ('Baz',36,'US',6)])

What I would like to do is remove duplicate rows based on the values of the first, third, and fourth columns only.

Removing entirely duplicate rows is straightforward:

data = data.distinct()

but this only drops a row when all four of its values match another row, so none of the rows above would be removed.

How do I remove duplicate rows based on columns 1, 3 and 4 only? For example, remove either one of these:

('Baz',22,'US',6)
('Baz',36,'US',6)

In Python (pandas), this could be done by specifying columns with .drop_duplicates(). How can I achieve the same in Spark/PySpark?

Answer

vaer-k · Oct 5, 2016

PySpark does include a dropDuplicates() method for DataFrames, which was introduced in 1.4. It takes an optional list of column names to compare, which covers this case. https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.dropDuplicates

>>> from pyspark.sql import Row
>>> df = sc.parallelize([ \
...     Row(name='Alice', age=5, height=80), \
...     Row(name='Alice', age=5, height=80), \
...     Row(name='Alice', age=10, height=80)]).toDF()
>>> df.dropDuplicates().show()
+---+------+-----+
|age|height| name|
+---+------+-----+
|  5|    80|Alice|
| 10|    80|Alice|
+---+------+-----+

>>> df.dropDuplicates(['name', 'height']).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
|  5|    80|Alice|
+---+------+-----+