I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for that in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.
I assume I can write a custom function of some sort that will do this, but I'm not even sure how to start, especially since I am a novice with Spark. If anyone knows how to do this with built-in functionality, or has suggestions for how to write something in Scala, it would be greatly appreciated.
As mentioned by David Anderson, Spark has provided a pivot function since version 1.6. The general syntax looks as follows:
df
.groupBy(grouping_columns)
.pivot(pivot_column, [values])
.agg(aggregate_expressions)
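To make the shape of the result concrete, here is a minimal plain-Python sketch (no Spark involved) of what a group-by / pivot / aggregate chain computes: one output row per group and one output column per pivot value. The function name `pivot_avg` and the toy rows are hypothetical, made up for illustration; this models the semantics only and is not how Spark implements it.

```python
from collections import defaultdict

def pivot_avg(rows, group_keys, pivot_key, value_key, values=None):
    """Toy model of groupBy(group_keys).pivot(pivot_key, values).agg(avg(value_key)).

    rows is a list of dicts. Illustration only, not Spark's implementation.
    """
    if values is None:
        # Without an explicit list, the distinct pivot values have to be
        # collected first, which costs an extra pass over the data.
        values = sorted({r[pivot_key] for r in rows})
    wanted = set(values)
    groups = {}                 # dict used as an insertion-ordered set of groups
    cells = defaultdict(list)   # (group, pivot value) -> values to average
    for r in rows:
        g = tuple(r[k] for k in group_keys)
        groups.setdefault(g, None)
        # Pivot values outside the list are ignored; None values are skipped.
        if r[pivot_key] in wanted and r[value_key] is not None:
            cells[(g, r[pivot_key])].append(r[value_key])
    out = []
    for g in groups:
        row = dict(zip(group_keys, g))
        for v in values:
            xs = cells.get((g, v))
            row[str(v)] = sum(xs) / len(xs) if xs else None  # empty cell -> None
        out.append(row)
    return out

# Toy rows (made up for this sketch).
rows = [
    {"origin": "EWR", "dest": "IAH", "carrier": "UA", "hour": 5, "arr_delay": 11},
    {"origin": "EWR", "dest": "IAH", "carrier": "UA", "hour": 5, "arr_delay": 21},
    {"origin": "EWR", "dest": "IAH", "carrier": "UA", "hour": 6, "arr_delay": -4},
    {"origin": "JFK", "dest": "MIA", "carrier": "AA", "hour": 5, "arr_delay": 33},
]
result = pivot_avg(rows, ("origin", "dest", "carrier"), "hour", "arr_delay")
for row in result:
    print(row)
# {'origin': 'EWR', 'dest': 'IAH', 'carrier': 'UA', '5': 16.0, '6': -4.0}
# {'origin': 'JFK', 'dest': 'MIA', 'carrier': 'AA', '5': 33.0, '6': None}
```

Passing `values=[5]` in this sketch would keep only the `5` column, which mirrors how an explicit values list fixes the set of output columns up front.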
Usage examples using nycflights13 and csv format:
Python:
from pyspark.sql.functions import avg
flights = (sqlContext
.read
.format("csv")
.options(inferSchema="true", header="true")
.load("flights.csv")
.na.drop())
flights.registerTempTable("flights")
sqlContext.cacheTable("flights")
gexprs = ("origin", "dest", "carrier")
aggexpr = avg("arr_delay")
flights.count()
## 336776
%timeit -n10 flights.groupBy(*gexprs).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop
Scala:
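import org.apache.spark.sql.functions.avg
import sqlContext.implicits._  // enables the $"col" syntax (already imported in spark-shell)
val flights = sqlContext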
.read
.format("csv")
.options(Map("inferSchema" -> "true", "header" -> "true"))
.load("flights.csv")
flights
.groupBy($"origin", $"dest", $"carrier")
.pivot("hour")
.agg(avg($"arr_delay"))
Java:
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.*;
Dataset<Row> df = spark.read().format("csv")
.option("inferSchema", "true")
.option("header", "true")
.load("flights.csv");
df.groupBy(col("origin"), col("dest"), col("carrier"))
.pivot("hour")
.agg(avg(col("arr_delay")));
R / SparkR:
library(magrittr)
flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)
flights %>%
groupBy("origin", "dest", "carrier") %>%
pivot("hour") %>%
agg(avg(column("arr_delay")))
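R / sparklyr:
library(dplyr)
library(sparklyr)  # provides spark_read_csv, sdf_pivot, invoke, invoke_static
# sc is assumed to be an existing connection created with spark_connect()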
flights <- spark_read_csv(sc, "flights", "flights.csv")
avg.arr.delay <- function(gdf) {
expr <- invoke_static(
sc,
"org.apache.spark.sql.functions",
"avg",
"arr_delay"
)
gdf %>% invoke("agg", expr, list())
}
flights %>%
sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)
SQL:
Note that the PIVOT keyword in Spark SQL is supported starting from version 2.4.
CREATE TEMPORARY VIEW flights
USING csv
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;
SELECT * FROM (
SELECT origin, dest, carrier, arr_delay, hour FROM flights
) PIVOT (
avg(arr_delay)
FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
);
Example data:
"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00
Performance considerations:
Generally speaking, pivoting is an expensive operation.
If you can, try to provide the values list, as this avoids an extra hit to compute the uniques:
vs = list(range(25))
%timeit -n10 flights.groupBy(*gexprs).pivot("hour", vs).agg(aggexpr).count()
## 10 loops, best of 3: 392 ms per loop
In some cases it proved to be beneficial to repartition and / or pre-aggregate the data (likely no longer worth the effort in 2.0 or later).
For reshaping only, you can use first: see "Pivot String column on Pyspark Dataframe".
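As a rough illustration of the "reshaping only" case, the plain-Python sketch below (again, no Spark involved) turns long rows into wide rows by keeping the first value seen for each cell, which is why it also works for string values where an average would make no sense. The helper `pivot_first` and the toy rows are hypothetical; in Spark the equivalent idea is to use `first` as the aggregate in the pivot.

```python
def pivot_first(rows, group_key, pivot_key, value_key):
    """Toy long-to-wide reshape.

    Keeps the first non-None value seen for each (group, pivot value) cell.
    Illustration only, not Spark's implementation.
    """
    columns = sorted({r[pivot_key] for r in rows})
    wide = {}  # group value -> {pivot value -> cell value}
    for r in rows:
        cells = wide.setdefault(r[group_key], {c: None for c in columns})
        if cells[r[pivot_key]] is None:
            cells[r[pivot_key]] = r[value_key]
    return [{group_key: g, **cells} for g, cells in wide.items()]

# Toy rows (made up for this sketch): one string value per (id, key) pair.
long_rows = [
    {"id": 1, "key": "color", "value": "red"},
    {"id": 1, "key": "size", "value": "L"},
    {"id": 2, "key": "color", "value": "blue"},
]
wide_rows = pivot_first(long_rows, "id", "key", "value")
for row in wide_rows:
    print(row)
# {'id': 1, 'color': 'red', 'size': 'L'}
# {'id': 2, 'color': 'blue', 'size': None}
```

When each cell holds at most one value, no real aggregation happens, so taking the first value is enough to reshape the data.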