Implementation of a softmax activation function for neural networks

alfa · Mar 28, 2012

I am using a softmax activation function in the last layer of a neural network, but I am having trouble implementing it in a numerically safe way.

A naive implementation would be this one:

Vector y = mlp(x); // output of the neural network without softmax activation function
for(int f = 0; f < y.rows(); f++)
  y(f) = exp(y(f));
y /= y.sum();

This does not work well with more than 100 hidden nodes because y ends up NaN in many cases: if y(f) > 709, exp(y(f)) returns inf, so y.sum() is inf and the final division computes inf / inf. I came up with this version:
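
The overflow can be reproduced in a few lines. This is only a sketch: it uses std::vector<double> as a stand-in for the Vector type above, assumes IEEE 754 doubles, and naiveSoftmax is a hypothetical name that does not appear in the original code.

```cpp
#include <cmath>
#include <vector>

// Naive softmax on a plain std::vector (a stand-in for the question's Vector type).
// naiveSoftmax is a hypothetical name used only for this illustration.
std::vector<double> naiveSoftmax(std::vector<double> y)
{
  double sum = 0.0;
  for(double& v : y)
  {
    v = std::exp(v);  // overflows to +inf once v exceeds roughly 709.78
    sum += v;
  }
  for(double& v : y)
    v /= sum;         // inf / inf yields NaN, finite / inf yields 0
  return y;
}
```

Calling naiveSoftmax({710.0, 1.0}) gives NaN for the first component and 0 for the second, although the exact result is very close to (1, 0).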

Vector y = mlp(x); // output of the neural network without softmax activation function
for(int f = 0; f < y.rows(); f++)
  y(f) = safeExp(y(f), y.rows());
y /= y.sum();

where safeExp is defined as

double safeExp(double x, int div)
{
  static const double maxX = std::log(std::numeric_limits<double>::max());
  const double max = maxX / (double) div;
  if(x > max)
    x = max;
  return std::exp(x);
}

This function limits the input of exp. It works in most cases but not in all of them, and I have not managed to find out which cases break it. With 800 hidden neurons in the previous layer it does not work at all.

However, even if this worked, I would still "distort" the result of the ANN. Can you think of any other way to calculate the correct solution? Are there any C++ libraries or tricks that I can use to calculate the exact output of this ANN?
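
To make the distortion concrete, here is a small sketch that wraps safeExp (copied from above) in a softmax over a std::vector<double>. The wrapper name clampedSoftmax is hypothetical, and std::vector stands in for the Vector type used in the question.

```cpp
#include <cmath>
#include <limits>
#include <vector>

// safeExp as defined in the question.
double safeExp(double x, int div)
{
  static const double maxX = std::log(std::numeric_limits<double>::max());
  const double max = maxX / (double) div;
  if(x > max)
    x = max;
  return std::exp(x);
}

// Softmax built on safeExp; clampedSoftmax is a hypothetical name for this sketch.
std::vector<double> clampedSoftmax(std::vector<double> y)
{
  const int n = static_cast<int>(y.size());
  double sum = 0.0;
  for(double& v : y)
  {
    v = safeExp(v, n);  // every input above maxX / n is mapped to the same value
    sum += v;
  }
  for(double& v : y)
    v /= sum;
  return y;
}
```

For the input {800.0, 1000.0} both components are clamped to the same limit, so the result is (0.5, 0.5). The exact softmax is approximately (exp(-200), 1), so the ordering information between the two outputs is lost.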

edit: The solution provided by Itamar Katz is:

Vector y = mlp(x); // output of the neural network without softmax activation function
double ymax = y(0); // maximal component of y
for(int f = 1; f < y.rows(); f++)
  if(y(f) > ymax)
    ymax = y(f);
for(int f = 0; f < y.rows(); f++)
  y(f) = exp(y(f) - ymax);
y /= y.sum();

And it really is mathematically the same. In practice, however, some small values underflow to 0 because of limited floating-point precision. I wonder why nobody ever writes these implementation details down in textbooks.
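
The two forms agree because a common factor cancels: exp(y_i - c) / sum_j exp(y_j - c) equals exp(y_i) / sum_j exp(y_j) for any constant c, and choosing c = ymax keeps every exponent at or below 0. A self-contained sketch of the same idea on std::vector<double> follows; stableSoftmax is a hypothetical name, and a non-empty input is assumed.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Max-subtraction softmax; stableSoftmax is a hypothetical name.
// Assumes y is non-empty.
std::vector<double> stableSoftmax(std::vector<double> y)
{
  const double ymax = *std::max_element(y.begin(), y.end());
  double sum = 0.0;
  for(double& v : y)
  {
    v = std::exp(v - ymax);  // exponent is <= 0, so exp cannot overflow
    sum += v;                // sum >= 1: the largest entry contributes exp(0) = 1
  }
  for(double& v : y)
    v /= sum;
  return y;
}
```

With the input {0.0, 1000.0} this returns exactly (0, 1): exp(-1000) is below the smallest representable double and underflows to 0, which is the effect described above. The denominator is always at least 1, so no NaN can arise from the division.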

Answer

Itamar Katz · Mar 28, 2012

First go to log scale, i.e., calculate log(y) instead of y. The log of the numerator is trivial. To calculate the log of the denominator, you can use the following 'trick': http://lingpipe-blog.com/2009/06/25/log-sum-of-exponentials/
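
A minimal sketch of this approach, assuming the trick in question is the usual max-shifted log-sum-exp, log(sum_j exp(y_j)) = ymax + log(sum_j exp(y_j - ymax)). The names logSumExp and logSoftmax are hypothetical, std::vector<double> stands in for the Vector type from the question, and a non-empty input is assumed.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// log(sum_j exp(y_j)) computed as ymax + log(sum_j exp(y_j - ymax)).
// logSumExp is a hypothetical name. Assumes y is non-empty.
double logSumExp(const std::vector<double>& y)
{
  const double ymax = *std::max_element(y.begin(), y.end());
  double sum = 0.0;
  for(double v : y)
    sum += std::exp(v - ymax);  // exponent is <= 0, so no overflow
  return ymax + std::log(sum);  // sum >= 1, so the log is well defined
}

// Log of the softmax output: the log of the numerator is y_i itself,
// the log of the denominator is logSumExp(y).
std::vector<double> logSoftmax(std::vector<double> y)
{
  const double lse = logSumExp(y);
  for(double& v : y)
    v -= lse;
  return y;
}
```

For the input {0.0, 1000.0} this returns (-1000, 0). The small component keeps a finite log-probability of -1000 instead of collapsing to 0, which is the benefit of staying in log scale; exponentiate only at the point where actual probabilities are needed.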