How to flatten columns of type array of structs (as returned by Spark ML API)?

Nick Resnick · Oct 13, 2017

Maybe it's just because I'm relatively new to the API, but I feel like Spark ML methods often return DFs that are unnecessarily difficult to work with.

This time, it's the ALS model that's tripping me up. Specifically, the recommendForAllUsers method. Let's reconstruct the type of DF it would return:

scala> import org.apache.spark.sql.types._

scala> val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))

scala> val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations".cast(arrayType))

scala> recs.show()
+------+------------------+
|userId|   recommendations|
+------+------------------+
|     1|[[1,0.7], [2,0.5]]|
|     2|[[0,0.9], [4,0.1]]|
+------+------------------+

scala> recs.printSchema
root
 |-- userId: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- itemId: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)

Now, I only care about the itemId in the recommendations column. After all, the method is recommendForAllUsers, not recommendAndScoreForAllUsers (OK, OK, I'll stop being sassy...).

How do I do this??

I thought I had it when I created a UDF:

scala> import org.apache.spark.sql.functions.udf

scala> val itemIds = udf((arr: Array[(Int, Float)]) => arr.map(_._1))

but that produces an error:

scala> recs.withColumn("items", itemIds($"recommendations"))
org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(recommendations)' due to data type mismatch: argument 1 requires array<struct<_1:int,_2:float>> type, however, '`recommendations`' is of array<struct<itemId:int,rating:float>> type.;;
'Project [userId#87, recommendations#92, UDF(recommendations#92) AS items#238]
+- Project [userId#87, cast(recommendations#88 as array<struct<itemId:int,rating:float>>) AS recommendations#92]
   +- Project [_1#84 AS userId#87, _2#85 AS recommendations#88]
      +- LocalRelation [_1#84, _2#85]

Any ideas? Thanks!

Answer

Nick Resnick · Oct 13, 2017

Wow, my coworker came up with an extremely elegant solution:

scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
|     1|[1, 2]|
|     2|[0, 4]|
+------+------+
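
This works because selecting a nested field through an array-of-structs column projects that field out of every element, yielding an array<int> per row. For the record, the UDF in the question failed because Spark hands struct elements to a Scala UDF as Row, not as a Scala tuple. A minimal (untested) sketch of a UDF that should work, assuming the schema above:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// Struct elements arrive as Rows, so read the field out by name
val itemIds = udf((recs: Seq[Row]) => recs.map(_.getAs[Int]("itemId")))

recs.withColumn("items", itemIds($"recommendations"))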

So maybe the Spark ML API isn't that difficult after all :)
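
And if you ever want one row per (user, item) pair instead of an array per user, explode flattens it the rest of the way. A quick sketch, again assuming the schema above:

import org.apache.spark.sql.functions.explode

// One output row per element of the recommendations array
recs.select($"userId", explode($"recommendations").as("rec")).
  select($"userId", $"rec.itemId", $"rec.rating")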