Scikit-learn cross validation scoring for regression

clwen · Jun 10, 2014 · Viewed 53.1k times

How can one use cross_val_score for regression? The default scoring seems to be accuracy, which is not very meaningful for regression. I would like to use mean squared error; is it possible to specify that in cross_val_score?

I tried the following two, but neither works:

scores = cross_validation.cross_val_score(svr, diabetes.data, diabetes.target, cv=5, scoring='mean_squared_error') 

and

scores = cross_validation.cross_val_score(svr, diabetes.data, diabetes.target, cv=5, scoring=metrics.mean_squared_error)

The first one generates a list of negative numbers while mean squared error should always be non-negative. The second one complains that:

mean_squared_error() takes exactly 2 arguments (3 given)

Answer

Sirrah · Jun 15, 2014

I don't have the reputation to comment, but I want to provide this link for you and/or any passersby, where the negative output of the MSE in scikit-learn is discussed: https://github.com/scikit-learn/scikit-learn/issues/2439

In addition (to make this a real answer), your first option is correct: not only is MSE the metric you want to use to compare models, but R^2 cannot be calculated, depending (I think) on the type of cross-validation you are using. With leave-one-out, for example, each test fold holds a single sample, and R^2 is undefined for a single sample.

If you choose MSE as the scorer, cross_val_score returns one error value per fold, which you can then average, like so:

# Doing linear regression with leave-one-out cross-validation

from sklearn import cross_validation, linear_model
import numpy as np

# x and y stand for your own features and targets (not defined here).
# Including this to remind you that it is necessary to use numpy arrays rather
# than lists, otherwise you will get an error
X_digits = np.array(x)
Y_digits = np.array(y)

# One fold per sample: each sample is held out once as the test set
loo = cross_validation.LeaveOneOut(len(Y_digits))

regr = linear_model.LinearRegression()

scores = cross_validation.cross_val_score(regr, X_digits, Y_digits, scoring='mean_squared_error', cv=loo)

# This will print the mean of the list of errors that were output and
# provide your metric for evaluation. The values are negated MSEs
# (see the linked issue), so flip the sign to read them as errors.
print(scores.mean())
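
For readers on a newer scikit-learn, here is a minimal sketch of both attempts from the question. It assumes release 0.18 or later, where cross_val_score lives in sklearn.model_selection and the scorer string is 'neg_mean_squared_error'. The second attempt in the question most likely failed because a callable passed as scoring is called with (estimator, X, y) rather than (y_true, y_pred). Wrapping the metric with make_scorer is one way around that. The SVR with default parameters and the variable names are illustrative choices, not taken from the original post.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.metrics import make_scorer, mean_squared_error
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVR

diabetes = load_diabetes()
svr = SVR()  # default parameters, purely for illustration

# Attempt 1: string scorer. Scorers follow a "higher is better" convention,
# so the MSE comes back negated; flip the sign to recover the usual MSE.
neg_mse = cross_val_score(svr, diabetes.data, diabetes.target,
                          cv=5, scoring="neg_mean_squared_error")
mse_per_fold = -neg_mse

# Attempt 2: wrap the metric function so it has the scorer signature
# (estimator, X, y). greater_is_better=False negates the result as above.
mse_scorer = make_scorer(mean_squared_error, greater_is_better=False)
neg_mse_wrapped = cross_val_score(svr, diabetes.data, diabetes.target,
                                  cv=5, scoring=mse_scorer)

print("MSE per fold:", mse_per_fold)
print("Mean MSE:", mse_per_fold.mean())
```

Both calls should give the same per-fold values, since the default 5-fold split for a regressor is not shuffled. Negating the scores once, right after the call, keeps the rest of the analysis in ordinary non-negative MSE units.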