Guided Back-propagation in TensorFlow

Peter · Jul 13, 2016 · Viewed 10.3k times

I would like to implement in TensorFlow the technique of "guided back-propagation" introduced in this paper and described in this recipe.

Computationally, that means that when I compute the gradient of, e.g., the output of the NN with respect to the input, I will have to modify the gradients computed at every ReLU unit. Concretely, the back-propagated signal at those units must be thresholded at zero for this technique to work. In other words, negative partial derivatives arriving at a ReLU must be ignored.
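To make that rule concrete, here is a minimal numpy sketch of what the modified backward pass of a single ReLU would compute; the function and variable names are illustrative only, not part of any TF API:

import numpy as np

def guided_relu_backward(relu_output, grad_in):
    # pass the gradient only where the ReLU fired in the forward pass (standard ReLU gradient)
    # AND where the incoming gradient is positive (the guided-backprop thresholding)
    return grad_in * (relu_output > 0) * (grad_in > 0)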

Given that I am interested in applying these gradient computations only to test examples, i.e., I don't want to update the model's parameters, how can I do it?

I tried (unsuccessfully) two things so far:

  1. Use tf.py_func to wrap my simple numpy version of a ReLU, which is then eligible to have its gradient operation redefined via the g.gradient_override_map context manager.

  2. Gather the forward/backward values during back-propagation and apply the thresholding to those stemming from ReLUs.

I failed with both approaches because they require some knowledge of TF internals that I currently don't have.

Can anyone suggest any other route, or sketch the code?

Thanks a lot.

Answer

Falcon · Aug 5, 2016

The better solution is your approach 1: use ops.RegisterGradient and tf.Graph.gradient_override_map. Together they override the gradient computation for a pre-defined op, e.g. Relu, within the gradient_override_map context, using only Python code.

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops

@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    # standard ReLU gradient, additionally zeroed wherever the incoming gradient is negative
    return tf.where(0. < grad, gen_nn_ops._relu_grad(grad, op.outputs[0]), tf.zeros(grad.get_shape()))

...
# g is the tf.Graph in which the network is built, e.g. g = tf.get_default_graph()
with g.gradient_override_map({'Relu': 'GuidedRelu'}):
    y = tf.nn.relu(x)  # this ReLU will use the GuidedRelu gradient

Here is the full example implementation of guided ReLU: https://gist.github.com/falcondai/561d5eec7fed9ebf48751d124a77b087
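For completeness, here is a minimal sketch of how the override could be used to obtain guided-backprop gradients for a test input without touching the model's parameters. It assumes the GuidedRelu registration above has already run; the toy network, shapes, and variable names are illustrative assumptions, not taken from the gist:

import numpy as np
import tensorflow as tf

g = tf.get_default_graph()
x = tf.placeholder(tf.float32, [None, 4])          # test input (illustrative shape)
w = tf.Variable(tf.random_normal([4, 3]))

with g.gradient_override_map({'Relu': 'GuidedRelu'}):
    h = tf.nn.relu(tf.matmul(x, w))                 # this ReLU uses the GuidedRelu gradient
score = tf.reduce_sum(h)                            # scalar score to explain

guided_grads = tf.gradients(score, x)[0]            # guided back-prop signal w.r.t. the input

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    grads_val = sess.run(guided_grads,
                         feed_dict={x: np.random.rand(2, 4).astype(np.float32)})

Since only tf.gradients is evaluated and no optimizer is run, the model's parameters are left untouched, which matches the test-time use case in the question.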

Update: in TensorFlow >= 1.0, tf.select was renamed to tf.where. I updated the snippet accordingly. (Thanks @sbond for bringing this to my attention :)