Spark/Scala repeated calls to withColumn() using the same function on multiple columns

Damian Satterthwaite-Phillips · Dec 30, 2016 · Viewed 16.6k times

I currently have code that applies the same procedure to multiple DataFrame columns through a long chain of .withColumn calls, and I would like to write a function that streamlines it. In my case, I am computing cumulative sums over several columns, partitioned by a key and ordered by time:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val newDF = oldDF
  .withColumn("cumA", sum("A").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumB", sum("B").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumC", sum("C").over(Window.partitionBy("ID").orderBy("time")))
  //.withColumn(...)

What I would like is either something like:

def createCumulativeColumns(cols: Array[String], df: DataFrame): DataFrame = {
  // Implement the above cumulative sums, partitioning, and ordering
  ???
}

or better yet:

def withColumns(cols: Array[String], df: DataFrame, f: String => Column): DataFrame = {
  // Implement a udf/arbitrary function on all the specified columns
  ???
}

Answer

zero323 · Dec 30, 2016

You can use select with varargs, including * to keep the existing columns:

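import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import spark.implicits._  // assumes a SparkSession named spark is in scope

df.select($"*" +: Seq("A", "B", "C").map(c =>
  sum(c).over(Window.partitionBy("ID").orderBy("time")).alias(s"cum$c")
): _*)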

This:

  • Maps column names to window expressions with Seq("A", ...).map(...).
  • Prepends all pre-existing columns with $"*" +: ....
  • Unpacks the combined sequence as varargs with ...: _*.

and can be generalized as:

import org.apache.spark.sql.{Column, DataFrame}

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 */
def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
  df.select($"*" +: cols.map(f): _*)
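To see the select-based helper end to end, here is a minimal sketch that should run against a local SparkSession. The sample rows, the object name CumulativeSumsSketch, and the local[*] master are assumptions for illustration only, not part of the original answer. It uses col("*") instead of $"*" so that the helper does not depend on spark.implicits._ being imported where it is defined.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}

// Hypothetical wrapper object, used only to make the sketch self-contained.
object CumulativeSumsSketch {
  // Keep every existing column, then append one derived column per name in cols.
  def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
    df.select(col("*") +: cols.map(f): _*)

  def main(args: Array[String]): Unit = {
    // Assumed local session; in spark-shell a session named spark already exists.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("cumulative-sums-sketch")
      .getOrCreate()
    import spark.implicits._

    // Made-up sample data with the column names used in the question.
    val df = Seq(
      (1, 1, 10, 100),
      (1, 2, 20, 200),
      (2, 1, 5, 50)
    ).toDF("ID", "time", "A", "B")

    // Build the window once and reuse it for every column.
    val w = Window.partitionBy("ID").orderBy("time")

    val result = withColumns(Seq("A", "B"), df, c => sum(c).over(w).alias(s"cum$c"))

    // For ID 1, cumA is 10 and then 30; for ID 2 it is 5.
    result.orderBy("ID", "time").show()

    spark.stop()
  }
}
```

Because f returns a Column, the output name has to be set inside f (here with alias); otherwise Spark generates a name from the expression.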

If you find the withColumn syntax more readable, you can use foldLeft:

Seq("A", "B", "C").foldLeft(df)((acc, c) =>
  acc.withColumn(s"cum$c", sum(c).over(Window.partitionBy("ID").orderBy("time")))
)

which can be generalized, for example, to:

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 * @param name a function mapping from input to output name.
 */
def withColumns(cols: Seq[String], df: DataFrame,
    f: String => Column, name: String => String = identity): DataFrame =
  cols.foldLeft(df)((acc, c) => acc.withColumn(name(c), f(c)))
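As a rough illustration of how the name parameter changes the result, the sketch below calls the foldLeft version twice on the same input. The sample rows, the object name FoldLeftWithColumnsSketch, and the local[*] master are assumptions made up for this example.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// Hypothetical wrapper object, used only to make the sketch self-contained.
object FoldLeftWithColumnsSketch {
  // Thread the DataFrame through foldLeft, adding (or replacing) one column per step.
  def withColumns(cols: Seq[String], df: DataFrame,
      f: String => Column, name: String => String = identity): DataFrame =
    cols.foldLeft(df)((acc, c) => acc.withColumn(name(c), f(c)))

  def main(args: Array[String]): Unit = {
    // Assumed local session; in spark-shell a session named spark already exists.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("foldleft-with-columns-sketch")
      .getOrCreate()
    import spark.implicits._

    // Made-up sample data with the column names used in the question.
    val df = Seq(
      (1, 1, 10, 100),
      (1, 2, 20, 200),
      (2, 1, 5, 50)
    ).toDF("ID", "time", "A", "B")

    val w = Window.partitionBy("ID").orderBy("time")
    val cumSum: String => Column = c => sum(c).over(w)

    // Explicit naming function: cumA and cumB are added next to A and B.
    withColumns(Seq("A", "B"), df, cumSum, c => s"cum$c").show()

    // Default name (identity): A and B themselves are replaced by their running totals.
    withColumns(Seq("A", "B"), df, cumSum).show()

    spark.stop()
  }
}
```

With the default name, withColumn reuses an existing column name, so the original values are overwritten rather than kept alongside the new ones. Note also that each withColumn call adds a projection to the plan; for a large number of columns, the select-based version builds the same result in a single projection.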