TensorFlow: check if a scalar boolean tensor is True

Tu Bui · Apr 6, 2017

I want to control the execution of a function using a boolean placeholder, but I keep getting the error "Using a tf.Tensor as a Python bool is not allowed". Here is the code that produces it:

import tensorflow as tf
def foo(c):
  if c:  # raises TypeError: c is a symbolic tensor with no value yet
    print('This is true')
    #heavy code here
    return 10
  else:
    print('This is false')
    #different code here
    return 0

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()

I also tried changing if c to if c is not None, without luck. How can I control foo by switching the placeholder a on and off?
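To see why neither version of the if can work, here is a small pure-Python sketch. SymbolicBool is a hypothetical stand-in for a graph-mode tf.Tensor, whose value does not exist until sess.run is called. Truth-testing it raises, while c is not None is only an identity check, which is true for any tensor object, so foo would always take the first branch regardless of what is fed later.

```python
# Hypothetical stand-in for a TF 1.x graph-mode tensor: its value is unknown
# until a session runs the graph, so asking for its truth value is an error.
class SymbolicBool:
    def __init__(self, name):
        self.name = name

    def __bool__(self):
        # Mirrors the error quoted in the question
        raise TypeError("Using a tf.Tensor as a Python bool is not allowed.")


def foo(c):
    if c:  # calls c.__bool__() while the graph is being built
        return 10
    return 0


def foo_is_not_none(c):
    if c is not None:  # identity check: true for any tensor object
        return 10
    return 0


a = SymbolicBool("a")
try:
    foo(a)
except TypeError as e:
    print("foo:", e)
print("foo_is_not_none:", foo_is_not_none(a))  # always 10, whatever is fed later
```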

Update: as @nessuno and @nemo pointed out, tf.cond must be used instead of if/else. The solution is to redesign my function like this:

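import tensorflow as tf

def func1():  # branch for True ("heavy code here")
  return tf.constant(10)

def func2():  # branch for False ("different code here")
  return tf.constant(0)

def foo(c):
  # tf.cond picks func1 or func2 at run time, based on the fed value of c
  return tf.cond(c, func1, func2)

a = tf.placeholder(tf.bool)  # placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict={a: True})
sess.close()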

Answer

nessuno · Apr 6, 2017

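You have to use tf.cond to define a conditional operation within the graph, and thus change the flow of the tensors at run time. A Python if is evaluated only once, while the graph is being built, when the placeholder does not hold a value yet.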

import tensorflow as tf

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
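# tf.cond chooses a branch when the graph runs; since a already holds a boolean,
# tf.cond(a, ...) would also work without the tf.equal comparison
b = tf.cond(tf.equal(a, tf.constant(True)), lambda: tf.constant(10), lambda: tf.constant(0))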
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()
print(res)

10
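As a rough mental model, here is a toy pure-Python sketch (not TensorFlow's implementation) of TF 1.x graph mode. The helpers Node, placeholder, constant, cond and run are all hypothetical: building the graph only records operations, and values appear only when run is called with a feed dict. The cond helper mirrors one property of tf.cond that matters for the original foo: both branch functions (named func1 and func2 here, as in the update) are called once while the graph is built, so Python side effects such as the print calls in foo happen at build time, not every time the graph runs.

```python
# Toy model of TF 1.x graph mode (NOT TensorFlow's implementation):
# building the graph only records nodes; values exist only inside run().
class Node:
    def __init__(self, evaluate):
        self.evaluate = evaluate  # function: feed_dict -> value


def placeholder():
    node = Node(None)
    node.evaluate = lambda feed: feed[node]  # value must be fed at run time
    return node


def constant(value):
    return Node(lambda feed: value)


def cond(pred, true_fn, false_fn):
    # Like tf.cond: both branch functions are called once, now, at build time;
    # only the choice between their outputs is deferred to run time.
    true_node = true_fn()
    false_node = false_fn()
    return Node(lambda feed: true_node.evaluate(feed) if pred.evaluate(feed)
                else false_node.evaluate(feed))


def run(node, feed_dict):
    return node.evaluate(feed_dict)


calls = []  # records Python-side calls to the branch functions


def func1():
    calls.append("func1")  # Python side effect: happens during graph building
    return constant(10)


def func2():
    calls.append("func2")
    return constant(0)


a = placeholder()
b = cond(a, func1, func2)
print(calls)               # ['func1', 'func2']: both built before any run
print(run(b, {a: True}))   # 10
print(run(b, {a: False}))  # 0
```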