Maybe it's just because I'm relatively new to the API, but I feel like Spark ML methods often return DFs that are unnecessarily difficult to work with.
This time, it's the ALS model that's tripping me up. Specifically, the recommendForAllUsers method. Let's reconstruct the type of DF it would return:
scala> import org.apache.spark.sql.types._
scala> val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
scala> val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
     |   toDF("userId", "recommendations").
     |   select($"userId", $"recommendations".cast(arrayType))
scala> recs.show()
+------+------------------+
|userId|   recommendations|
+------+------------------+
|     1|[[1,0.7], [2,0.5]]|
|     2|[[0,0.9], [4,0.1]]|
+------+------------------+
scala> recs.printSchema
root
 |-- userId: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- itemId: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)
Now, I only care about the itemId in the recommendations column. After all, the method is recommendForAllUsers, not recommendAndScoreForAllUsers (ok ok I'll stop being sassy...)
How do I do this??
I thought I had it when I created a UDF:
scala> val itemIds = udf((arr: Array[(Int, Float)]) => arr.map(_._1))
but that produces an error:
scala> recs.withColumn("items", itemIds($"recommendations"))
org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(recommendations)' due to data type mismatch: argument 1 requires array<struct<_1:int,_2:float>> type, however, '`recommendations`' is of array<struct<itemId:int,rating:float>> type.;;
'Project [userId#87, recommendations#92, UDF(recommendations#92) AS items#238]
+- Project [userId#87, cast(recommendations#88 as array<struct<itemId:int,rating:float>>) AS recommendations#92]
+- Project [_1#84 AS userId#87, _2#85 AS recommendations#88]
+- LocalRelation [_1#84, _2#85]
Any ideas? Thanks!
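Wow, my coworker came up with an extremely elegant solution. Selecting a struct field through an array column returns an array of that field's values, so no UDF is needed: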
scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
|     1|[1, 2]|
|     2|[0, 4]|
+------+------+
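For completeness, the UDF from the question can probably be made to work too. As far as I understand it, Spark hands struct values to a Scala UDF as Row objects rather than tuples, which is why the Array[(Int, Float)] signature was rejected. Below is a minimal sketch under that assumption: it takes Seq[Row] and reads the field by name. The ItemIdsExample object, its run() helper, and the local SparkSession settings are all made up for illustration; in spark-shell only the udf line and the withColumn call would be needed.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, FloatType, IntegerType, StructType}

// Hypothetical wrapper object, only here so the sketch is self-contained.
object ItemIdsExample {
  // Rebuilds the sample DF from the question and returns userId -> item ids.
  def run(): Map[Int, Seq[Int]] = {
    // Assumption: a throwaway local session; spark-shell already provides `spark`.
    val spark = SparkSession.builder().master("local[1]").appName("item-ids").getOrCreate()
    import spark.implicits._

    val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
    val recs = Seq((1, Array((1, 0.7), (2, 0.5))), (2, Array((0, 0.9), (4, 0.1))))
      .toDF("userId", "recommendations")
      .select($"userId", $"recommendations".cast(arrayType))

    // Each array element arrives as a Row, so read itemId by field name
    // instead of pattern-matching on a tuple.
    val itemIds = udf((arr: Seq[Row]) => arr.map(_.getAs[Int]("itemId")))

    val result = recs
      .withColumn("items", itemIds($"recommendations"))
      .select($"userId", $"items")
      .collect()
      .map(r => r.getInt(0) -> r.getSeq[Int](1).toList)
      .toMap

    spark.stop()
    result
  }
}
```

This should give the same item ids per user as the select above. The plain column selection is still the better choice when all you need is one field, since it avoids a UDF entirely; the Row-based UDF is mainly useful if you need extra logic per element.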
So maybe the Spark ML API isn't that difficult after all :)