Is there a built-in KL divergence loss function in TensorFlow?

Transcendental · Jan 26, 2017 · Viewed 15.1k times

I have two tensors, prob_a and prob_b with shape [None, 1000], and I want to compute the KL divergence from prob_a to prob_b. Is there a built-in function for this in TensorFlow? I tried using tf.contrib.distributions.kl(prob_a, prob_b), but it gives:

NotImplementedError: No KL(dist_a || dist_b) registered for dist_a type Tensor and dist_b type Tensor

If there is no built-in function, what would be a good workaround?

Answer

meferne · Jun 25, 2018

Assuming that your input tensors prob_a and prob_b are probability tensors that sum to 1 along the last axis, you could do it like this:

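import tensorflow as tf  # TensorFlow 1.x API

def kl(x, y):
    # Treat each row as the probabilities of a categorical distribution.
    X = tf.distributions.Categorical(probs=x)
    Y = tf.distributions.Categorical(probs=y)
    # KL(X || Y), one value per row.
    return tf.distributions.kl_divergence(X, Y)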

result = kl(prob_a, prob_b)

A simple example:

import numpy as np
import tensorflow as tf
a = np.array([[0.25, 0.1, 0.65], [0.8, 0.15, 0.05]])
b = np.array([[0.7, 0.2, 0.1], [0.15, 0.8, 0.05]])
sess = tf.Session()
print(kl(a, b).eval(session=sess))  # [0.88995184 1.08808468]

You would get the same result with

np.sum(a * np.log(a / b), axis=1) 

However, this implementation has numerical problems with zero probabilities (checked in TensorFlow 1.8.0).

If a contains zero probabilities, e.g. if you try [0.8, 0.2, 0.0] instead of [0.8, 0.15, 0.05], you will get nan, even though by the definition of the Kullback-Leibler divergence a term 0 * log(0 / b) should contribute zero.
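The convention that a zero in a contributes nothing can be made explicit by masking those entries out. Here is a minimal NumPy sketch of that idea; kl_masked is a hypothetical helper written for illustration, not a TensorFlow function, and it assumes b is strictly positive wherever a is positive:

```python
import numpy as np

def kl_masked(a, b):
    """KL(a || b) along the last axis, treating 0 * log(0 / b) as 0."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    mask = a > 0
    # Where a == 0, substitute 1 for both arguments so the log ratio is
    # log(1 / 1) = 0 and the term becomes 0 * 0 = 0 instead of nan.
    safe_a = np.where(mask, a, 1.0)
    safe_b = np.where(mask, b, 1.0)
    return np.sum(a * np.log(safe_a / safe_b), axis=-1)

a = np.array([[0.25, 0.1, 0.65], [0.8, 0.2, 0.0]])
b = np.array([[0.7, 0.2, 0.1], [0.15, 0.8, 0.05]])
print(kl_masked(a, b))  # both entries are finite, no nan
```

For rows without zeros this should agree with np.sum(a * np.log(a / b), axis=1).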

To avoid the nan, add a small numerical constant to the probabilities before computing the divergence. It is also prudent to call tf.distributions.kl_divergence(X, Y, allow_nan_stats=False) so that a nan raises a runtime error instead of propagating silently.
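One way to read "add a small constant" is to add an epsilon to every probability and renormalize so that each row still sums to 1. The NumPy sketch below works under that assumption; the names smooth and kl_smoothed and the default eps=1e-8 are illustrative choices, not taken from TensorFlow:

```python
import numpy as np

def smooth(p, eps=1e-8):
    """Add eps to every entry and renormalize along the last axis."""
    p = np.asarray(p, dtype=np.float64) + eps
    return p / p.sum(axis=-1, keepdims=True)

def kl_smoothed(a, b, eps=1e-8):
    """KL(a || b) along the last axis after epsilon smoothing."""
    a = smooth(a, eps)
    b = smooth(b, eps)
    # After smoothing, every entry is strictly positive, so the log is finite.
    return np.sum(a * np.log(a / b), axis=-1)

a = np.array([[0.8, 0.2, 0.0]])   # zero in a
b = np.array([[0.15, 0.8, 0.05]])
print(kl_smoothed(a, b))          # finite value instead of nan
```

The smoothing slightly biases the result, so eps should be small relative to the probabilities of interest.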

Also, if b contains zeros, you will get inf values, which are not caught by the allow_nan_stats=False option, so those have to be handled separately.
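Since nan and inf both signal a problem here, a single finiteness check covers the two cases. The following NumPy sketch illustrates the check outside of TensorFlow; kl_checked is a hypothetical helper, and raising FloatingPointError is just one possible way to report the failure:

```python
import numpy as np

def kl_checked(a, b):
    """Naive KL(a || b) along the last axis; raises if any value is nan or inf."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    # Silence NumPy's divide-by-zero and invalid-value warnings;
    # the explicit check below reports the problem instead.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = np.sum(a * np.log(a / b), axis=-1)
    if not np.all(np.isfinite(result)):
        raise FloatingPointError("KL divergence contains nan or inf")
    return result

try:
    # The zero in b where a is positive produces inf.
    kl_checked([[0.8, 0.15, 0.05]], [[0.8, 0.2, 0.0]])
except FloatingPointError as err:
    print("caught:", err)
```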