I have some data in the following format (either RDD or Spark DataFrame):
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01', 41, 'US', 3),
                      ('X01', 41, 'UK', 1),
                      ('X01', 41, 'CA', 2),
                      ('X02', 72, 'US', 4),
                      ('X02', 72, 'UK', 6),
                      ('X02', 72, 'CA', 7),
                      ('X02', 72, 'XX', 8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)
What I would like to do is to 'reshape' the data, converting certain values in Country (specifically US, UK and CA) into columns:
ID    Age  US  UK  CA
'X01'  41   3   1   2
'X02'  72   4   6   7
Essentially, I need something along the lines of pandas' pivot workflow:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index='ID',
                                                  columns='Country',
                                                  values='Score')
My dataset is rather large, so I can't really collect() the data and ingest it into memory to do the reshaping in Python itself. Is there a way to achieve the equivalent of pandas' .pivot() while working with either an RDD or a Spark DataFrame? Any help would be appreciated!
Since Spark 1.6 you can use the pivot function on GroupedData and provide an aggregate expression.
pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))           # alternatively you can use .agg(expr)
pivoted.show()
## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41| 3| 1| 2|
## |X02| 72| 4| 6| 7|
## +---+---+---+---+---+
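The .agg(expr) alternative mentioned in the comment looks roughly like this (a minimal sketch; F.first is my choice of aggregate here, valid because each (ID, Age, Country) combination carries a single Score):

from pyspark.sql import functions as F

# Same pivot, expressed with an explicit aggregate expression instead of
# the sum() shortcut; first() keeps the single Score per group unchanged.
pivoted_agg = (df
    .groupBy("ID", "Age")
    .pivot("Country", ['US', 'UK', 'CA'])
    .agg(F.first("Score")))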
Levels can be omitted, but if provided they can both boost performance and serve as an internal filter.
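For example, dropping the list of levels still works, but Spark then has to compute the distinct Country values first, and every value found (including 'XX' above) becomes a column (a sketch, not part of the original answer):

# No explicit levels: Spark scans for distinct Country values itself,
# so 'XX' also shows up as a column in the result.
pivoted_all = (df
    .groupBy("ID", "Age")
    .pivot("Country")
    .sum("Score"))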
This method is still relatively slow, but it certainly beats passing data manually between the JVM and Python.
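For context, here is a rough sketch (my own, assuming the rdd from the question) of the kind of manual reshaping on the Python side that the built-in pivot replaces, using aggregateByKey:

# Key by (ID, Age), collect Country -> Score into a dict per key,
# then unpack the wanted countries into a wide row.
def add_pair(acc, pair):
    country, score = pair
    acc[country] = score
    return acc

def merge_dicts(acc1, acc2):
    acc1.update(acc2)
    return acc1

manual = (rdd
    .map(lambda r: ((r[0], r[1]), (r[2], r[3])))
    .aggregateByKey({}, add_pair, merge_dicts)
    .map(lambda kv: (kv[0][0], kv[0][1],
                     kv[1].get('US'), kv[1].get('UK'), kv[1].get('CA'))))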