How to get accuracy, precision, recall and ROC from cross validation in Spark MLlib?

user3309479 · Jan 18, 2017 · Viewed 13.8k times

I am using Spark 2.0.2. I am also using the "ml" library for Machine Learning with Datasets. What I want to do is run algorithms with cross validation and extract the mentioned metrics (accuracy, precision, recall, ROC, confusion matrix). My data labels are binary.

By using the MulticlassClassificationEvaluator I can only get the accuracy of the algorithm (by accessing "avgMetrics"), and by using the BinaryClassificationEvaluator I can get the area under the ROC curve, but I cannot use both at once. So, is there a way to extract all of the metrics I want?

Answer

ShuoshuoFan · Jan 23, 2018

Try using the RDD-based MLlib API to evaluate your result: transform the predictions DataFrame into an RDD of (prediction, label) pairs, then use MulticlassMetrics from MLlib.

You can see a demo here: Spark DecisionTreeExample.scala

import org.apache.spark.ml.Transformer
// MetadataUtils is a Spark-internal helper (private[spark]); it is only
// accessible from code compiled inside the org.apache.spark package, as in
// Spark's own examples.
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.DataFrame

private[ml] def evaluateClassificationModel(
      model: Transformer,
      data: DataFrame,
      labelColName: String): Unit = {
    val fullPredictions = model.transform(data).cache()
    val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0))
    val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0))
    // Print number of classes for reference.
    val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
      case Some(n) => n
      case None => throw new RuntimeException(
        "Unknown failure when indexing labels for classification.")
    }
    // MulticlassMetrics comes from the RDD-based MLlib API.
    val accuracy = new MulticlassMetrics(predictions.zip(labels)).accuracy
    println(s"  Accuracy ($numClasses classes): $accuracy")
  }