In TensorFlow, what is the argument 'axis' in the function 'tf.one_hot'?

user919426 · Jan 3, 2018

Could anyone help with an explanation of what axis is in TensorFlow's one_hot function?

According to the documentation:

axis: The axis to fill (default: -1, a new inner-most axis)

The closest I came to an answer on SO was an explanation relevant to Pandas, but I'm not sure that context is just as applicable here.


Maxim · Jan 3, 2018

Here's an example:

x = tf.constant([0, 1, 2])

... is the input tensor, and N=4 is the depth (each index is transformed into a vector of length 4).


Computing one_hot_1 = tf.one_hot(x, 4).eval() yields a (3, 4) tensor:

[[ 1.  0.  0.  0.]
 [ 0.  1.  0.  0.]
 [ 0.  0.  1.  0.]]

... where the last dimension is one-hot encoded: each row contains a single 1 at the position given by the index. This corresponds to the default axis=-1, i.e. the last axis.
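
To make the default concrete, here is a minimal pure-Python sketch of what the encoding computes for a 1-D input. The helper name one_hot_last_axis is hypothetical and not part of TensorFlow; it only mimics the result of tf.one_hot(x, 4) with the default axis for in-range integer indices.

```python
def one_hot_last_axis(indices, depth):
    """Hypothetical helper (not a TensorFlow function).

    Builds one row per index; each row has length `depth` and holds
    a single 1.0 at the position given by that index.
    """
    return [[1.0 if j == i else 0.0 for j in range(depth)] for i in indices]


rows = one_hot_last_axis([0, 1, 2], 4)
for row in rows:
    print(row)
```

The result has 3 rows of 4 entries, matching the (3, 4) shape above: the new dimension of size 4 is the innermost one.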


Now, computing one_hot_2 = tf.one_hot(x, 4, axis=0).eval() yields a (4, 3) tensor, which is not immediately recognizable as one-hot encoded:

[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]

This is because the one-hot encoding is done along axis 0, so one has to transpose the matrix to see the previous encoding. The situation becomes more complicated when the input is higher-dimensional, but the idea is the same: the only difference is where the extra dimension used for the one-hot encoding is placed.
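
The placement rule can be sketched without TensorFlow. Both helpers below, one_hot_shape and one_hot_axis0, are hypothetical names used only for illustration. The first inserts the depth dimension into the input shape at position axis, and the second builds the axis=0 encoding directly for a 1-D input so that the transpose relationship can be checked. Treat this as a sketch that assumes in-range integer indices and a non-negative axis (or -1), not as the library's implementation.

```python
def one_hot_shape(input_shape, depth, axis=-1):
    """Hypothetical helper: output shape of a one-hot encoding."""
    shape = list(input_shape)
    # axis=-1 appends a new innermost axis; otherwise insert at `axis`.
    position = len(shape) if axis == -1 else axis
    shape.insert(position, depth)
    return tuple(shape)


def one_hot_axis0(indices, depth):
    """Hypothetical helper: axis=0 encoding of a 1-D list of indices.

    Row j marks which input positions hold the index j, so each
    *column* (not each row) is one-hot.
    """
    return [[1.0 if i == j else 0.0 for i in indices] for j in range(depth)]


print(one_hot_shape((3,), 4))            # (3, 4)
print(one_hot_shape((3,), 4, axis=0))    # (4, 3)
print(one_hot_shape((2, 5), 4))          # (2, 5, 4)
print(one_hot_shape((2, 5), 4, axis=1))  # (2, 4, 5)
print(one_hot_shape((2, 5), 4, axis=0))  # (4, 2, 5)

encoded = one_hot_axis0([0, 1, 2], 4)
# Transposing the (4, 3) result recovers the row-wise (3, 4) encoding.
transposed = [list(col) for col in zip(*encoded)]
print(transposed)
```

For a 2-D input of shape (2, 5), the three possible axis values therefore only move the size-4 dimension between the front, the middle, and the end of the shape.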