My question is in two connected parts:
How do I calculate the max along a certain axis of a tensor? For example, if I have
x = tf.constant([[1,220,55],[4,3,-1]])
I want something like
x_max = tf.max(x, axis=1)
print sess.run(x_max)
output: [220,4]
I know there is a tf.argmax and a tf.maximum, but neither gives the maximum value along an axis of a single tensor. For now I have a workaround:
x_max = tf.slice(x, begin=[0,0], size=[-1,1])
for a in range(1,3):  # visit the remaining columns (x has 3 columns)
    x_max = tf.maximum(x_max, tf.slice(x, begin=[0,a], size=[-1,1]))
But it looks less than optimal. Is there a better way to do this?
Given the indices of an argmax of a tensor, how do I index into another tensor using those indices? Using the example of x above, how do I do something like the following:
ind_max = tf.argmax(x, dimension=1) #output is [1,0]
y = tf.constant([[1,2,3], [6,5,4]])
y_ = y[:, ind_max] #y_ should be [2,6]
I know slicing, like the last line, does not exist in TensorFlow yet (#206).
My question is: what is the best workaround for my specific case (maybe using other methods like gather, select, etc.)?
Additional information: I know x and y are going to be two-dimensional tensors only!
The tf.reduce_max() operator provides exactly this functionality. By default it computes the global maximum of the given tensor, but you can specify a list of reduction_indices, which has the same meaning as axis in NumPy. To complete your example:
x = tf.constant([[1, 220, 55], [4, 3, -1]])
x_max = tf.reduce_max(x, reduction_indices=[1])
print sess.run(x_max) # ==> "array([220, 4], dtype=int32)"
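To make the NumPy comparison concrete, here is a minimal NumPy sketch (not TensorFlow code) of what the axis argument selects on the same data. The names global_max, row_max, and col_max are illustrative only and do not come from the question.

```python
import numpy as np

x = np.array([[1, 220, 55], [4, 3, -1]])

# No axis given: reduce over every element, like the default global maximum.
global_max = x.max()

# axis=1: collapse the columns, leaving one maximum per row.
row_max = x.max(axis=1)

# axis=0: collapse the rows, leaving one maximum per column.
col_max = x.max(axis=0)

print(global_max)  # 220
print(row_max)     # [220   4]
print(col_max)     # [  4 220  55]
```

The axis=1 case is the one that corresponds to reduction_indices=[1] in the TensorFlow example.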
If you compute the argmax using tf.argmax(), you could obtain the values from a different tensor y by flattening y using tf.reshape(), converting the argmax indices into vector indices as follows, and using tf.gather() to extract the appropriate values:
ind_max = tf.argmax(x, dimension=1)
y = tf.constant([[1, 2, 3], [6, 5, 4]])
flat_y = tf.reshape(y, [-1]) # Reshape to a vector.
# N.B. Handles 2-D case only.
flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)
y_ = tf.gather(flat_y, flat_ind_max)
print sess.run(y_) # ==> "array([2, 6], dtype=int32)"
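The index arithmetic may be easier to follow in plain NumPy. The sketch below mirrors the same steps under the same 2-D assumption. It is an illustration rather than TensorFlow code, and row_offsets, n_rows, and n_cols are names introduced here for clarity. The idea is that element (i, c) of a row-major matrix with n_cols columns sits at position i * n_cols + c in the flattened vector.

```python
import numpy as np

x = np.array([[1, 220, 55], [4, 3, -1]])
y = np.array([[1, 2, 3], [6, 5, 4]])

# Column index of the maximum in each row of x.
ind_max = x.argmax(axis=1)                  # [1, 0]

n_rows, n_cols = y.shape

# Offset of the start of each row in the flattened vector.
row_offsets = np.arange(n_rows) * n_cols    # [0, 3]

# Flat position of the chosen element in each row.
flat_ind_max = ind_max + row_offsets        # [1, 3]

# Flatten y and pick out one element per row.
y_ = y.reshape(-1)[flat_ind_max]

print(y_)  # [2 6]
```

In NumPy itself, y[np.arange(n_rows), ind_max] gives the same result directly; the flattening step is only needed where that kind of indexing is unavailable.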