Could anyone help with an explanation of what axis means in TensorFlow's one_hot function?
According to the documentation:
axis: The axis to fill (default: -1, a new inner-most axis)
The closest I came to an answer on SO was an explanation relevant to Pandas; I'm not sure whether that context applies here as well.
Here's an example:
x = tf.constant([0, 1, 2])
... is the input tensor and N=4
(each index is transformed into a 4-dimensional vector).
axis=-1
Computing one_hot_1 = tf.one_hot(x, 4).eval()
yields a (3, 4)
tensor:
[[ 1. 0. 0. 0.]
[ 0. 1. 0. 0.]
[ 0. 0. 1. 0.]]
... where the last dimension is one-hot encoded (clearly visible). This corresponds to the default axis=-1
, i.e. the last one.
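To make the axis behavior concrete without a TensorFlow session, here is a minimal NumPy stand-in for tf.one_hot (an illustration of the semantics, not the real API): build the encoding along the last axis, then move the new depth axis wherever axis asks.

```python
import numpy as np

def one_hot(indices, depth, axis=-1):
    """NumPy sketch of tf.one_hot's axis semantics (illustration only)."""
    # Row i of the identity matrix is the one-hot vector for index i,
    # so fancy-indexing np.eye places the depth axis inner-most...
    encoded = np.eye(depth)[np.asarray(indices)]
    # ...and moveaxis relocates that new axis to the requested position.
    return np.moveaxis(encoded, -1, axis)

x = [0, 1, 2]
print(one_hot(x, 4))  # shape (3, 4): each row is one-hot, as above
```

With the default axis=-1, moveaxis is a no-op and the result matches the (3, 4) tensor shown above.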
axis=0
Now, computing one_hot_2 = tf.one_hot(x, 4, axis=0).eval()
yields a (4, 3)
tensor, which is not immediately recognizable as one-hot encoded:
[[ 1. 0. 0.]
[ 0. 1. 0.]
[ 0. 0. 1.]
[ 0. 0. 0.]]
This is because the one-hot encoding is done along the 0-axis, so one has to transpose the matrix to recover the previous encoding. The situation becomes more complicated when the input is higher-dimensional, but the idea is the same: the only difference is the placement of the extra dimension used for one-hot encoding.
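The transpose relationship and the higher-dimensional case can be checked with a small NumPy stand-in for tf.one_hot (an illustrative sketch, not the real API):

```python
import numpy as np

def one_hot(indices, depth, axis=-1):
    # NumPy stand-in for tf.one_hot (illustration only): encode along
    # the last axis, then move the depth axis to the requested position.
    encoded = np.eye(depth)[np.asarray(indices)]
    return np.moveaxis(encoded, -1, axis)

x = [0, 1, 2]
# For 1-D input, axis=0 is exactly the transpose of the default result.
assert np.array_equal(one_hot(x, 4, axis=0), one_hot(x, 4).T)

# For a 2-D input of shape (2, 2) with depth 3, the extra axis of
# size 3 lands wherever `axis` says:
x2 = [[0, 2], [1, 0]]
print(one_hot(x2, 3).shape)          # (2, 2, 3) -- default, inner-most
print(one_hot(x2, 3, axis=0).shape)  # (3, 2, 2)
print(one_hot(x2, 3, axis=1).shape)  # (2, 3, 2)
```

In every case the encoded values are the same; only the position of the size-3 depth axis changes.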