Is it ok to define your own cost function for logistic regression?

London guy · Aug 28, 2012 · Viewed 17.5k times

In least-squares models, the cost function is defined as the squared difference between the predicted value and the actual value, viewed as a function of the model parameters.

When we do logistic regression, we change the cost function to a logarithmic one (the log loss) instead of defining it as the squared difference between the sigmoid output and the actual label.
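
For concreteness, here's a minimal NumPy sketch of the two costs being contrasted (the function names and the 0/1 label encoding are my own choices, not from any particular library):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def squared_error_cost(w, X, y):
    # Least-squares-style cost: mean squared difference between the
    # sigmoid output and the 0/1 label. For a logistic model this is
    # non-convex in w, which is one reason it isn't the standard choice.
    p = sigmoid(X @ w)
    return np.mean((p - y) ** 2)

def log_cost(w, X, y):
    # The standard logistic-regression cost (the negative log-likelihood,
    # a.k.a. log loss). Convex in w, so gradient methods behave well.
    p = sigmoid(X @ w)
    return -np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
```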

Is it OK to change and define our own cost function to determine the parameters?

Answer

Fred Foo · Feb 27, 2013

Yes, you can define your own loss function, but if you're a novice, you're probably better off using one from the literature. There are conditions that loss functions should meet:

  1. They should approximate the actual loss you're trying to minimize. As was said in the other answer, the standard loss function for classification is zero-one loss (misclassification rate), and the loss functions used for training classifiers are approximations of that loss.

    The squared-error loss from linear regression isn't used because it doesn't approximate zero-one loss well: when your model predicts +50 for some sample while the intended answer was +1 (positive class), the prediction is on the correct side of the decision boundary, so the zero-one loss is zero, but the squared-error loss is still 49² = 2401. Some training algorithms will waste a lot of time getting predictions very close to {-1, +1} instead of focusing on getting just the sign/class label right (see the numeric sketch after this list).(*)

  2. The loss function should work with your intended optimization algorithm. That's why zero-one loss is not used directly: its gradient is zero almost everywhere and undefined at the decision boundary, so it gives gradient-based optimization methods nothing to work with, and it doesn't even have a subgradient the way the hinge loss for SVMs does.

    The main algorithm that optimizes the zero-one loss directly is the old perceptron algorithm.
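
To make both points concrete, here's the +50 example worked through for several standard losses (margin formulations with labels in {-1, +1}; the formulas are the usual textbook ones, the variable names are mine):

```python
import numpy as np

y, f = 1.0, 50.0  # true label +1, raw model output +50

zero_one = float(np.sign(f) != y)    # 0.0    -- correct side of the boundary
squared  = (f - y) ** 2              # 2401.0 -- huge penalty for a "too correct" answer
hinge    = max(0.0, 1.0 - y * f)     # 0.0    -- flat past the margin, but has a subgradient
log_loss = np.log1p(np.exp(-y * f))  # ~0.0   -- smooth and differentiable everywhere

print(zero_one, squared, hinge, log_loss)
```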

Also, when you plug in a custom loss function, you're no longer building a logistic regression model but some other kind of linear classifier.
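
For example, scikit-learn's SGDClassifier makes the loss an explicit parameter, so swapping it literally changes which linear classifier you fit (the loss names below are from recent scikit-learn versions; older versions spell "log_loss" as "log"):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

X, y = make_classification(n_samples=200, random_state=0)

# Same linear model family throughout; only the loss changes.
logreg = SGDClassifier(loss="log_loss").fit(X, y)   # logistic regression
svm    = SGDClassifier(loss="hinge").fit(X, y)      # linear SVM
perc   = SGDClassifier(loss="perceptron").fit(X, y) # perceptron-style updates
```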

(*) Squared error is used with linear discriminant analysis, but that is usually solved in closed form instead of iteratively.