How does numpy.swapaxes work?

phoenix · Feb 18, 2017

I created a sample array:

import numpy as np
a = np.arange(18).reshape(9,2)

On printing, I get this as output:

[[ 0  1]
 [ 2  3]
 [ 4  5]
 [ 6  7]
 [ 8  9]
 [10 11]
 [12 13]
 [14 15]
 [16 17]]

Then I reshape it and swap axes 0 and 2:

b = a.reshape(2,3,3).swapaxes(0,2)

I get:

[[[ 0  9]
  [ 3 12]
  [ 6 15]]

 [[ 1 10]
  [ 4 13]
  [ 7 16]]

 [[ 2 11]
  [ 5 14]
  [ 8 17]]]

I went through this question, but it did not solve my problem:

Reshape an array in NumPy

The documentation is not useful either.

https://docs.scipy.org/doc/numpy/reference/generated/numpy.swapaxes.html

I need to know how the swapping works (which axis is the x-axis, y-axis, z-axis?). A diagrammatic explanation would be most helpful.

Answer

GoingMyWay · Mar 29, 2018

Here is my understanding of swapaxes.

Suppose you have an array:

In [1]: arr = np.arange(16).reshape((2, 2, 4))

In [2]: arr
Out[2]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11],
        [12, 13, 14, 15]]])

The shape of arr is (2, 2, 4), and you can get the value 7 with

In [3]: arr[0, 1, 3]
Out[3]: 7
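Because arr was built with np.arange, each value is also its position in the flattened, row-major array, so np.unravel_index can map a value back to its (axis 0, axis 1, axis 2) index. A quick check, as a sketch:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))

# value 7 sits at flat position 7; unravel that position using arr's shape
idx = np.unravel_index(7, arr.shape)
print(tuple(int(i) for i in idx))  # (0, 1, 3)
print(arr[idx])                    # 7
```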

There are three axes: 0, 1 and 2. Now swap axes 0 and 2:

In [4]: arr_swap = arr.swapaxes(0, 2)

In [5]: arr_swap
Out[5]: 
array([[[ 0,  8],
        [ 4, 12]],

       [[ 1,  9],
        [ 5, 13]],

       [[ 2, 10],
        [ 6, 14]],

       [[ 3, 11],
        [ 7, 15]]])

As you might guess, the index of 7 becomes (3, 1, 0): the indices along axes 0 and 2 have traded places, while the index along axis 1 is unchanged:

In [6]: arr_swap[3, 1, 0]
Out[6]: 7
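For a 3-D array, swapping axes 0 and 2 gives the same result as transpose(2, 1, 0). Like transpose, swapaxes on an ndarray returns a view: no data is copied, and NumPy only swaps the corresponding entries of the shape and strides. A sketch that checks this:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))
arr_swap = arr.swapaxes(0, 2)

print(arr_swap.shape)                                    # (4, 2, 2)
print(np.array_equal(arr_swap, arr.transpose(2, 1, 0)))  # True

# same underlying memory; only shape and strides are reordered
print(np.shares_memory(arr, arr_swap))                   # True
s = arr.strides
print(arr_swap.strides == (s[2], s[1], s[0]))            # True

arr_swap[3, 1, 0] = 100  # writing through the view...
print(arr[0, 1, 3])      # ...changes arr too: 100
```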

So, from the perspective of indexing, swapping axes just swaps the corresponding positions in each value's index. For example:

In [7]: arr[0, 0, 1]
Out[7]: 1

In [8]: arr_swap[1, 0, 0]
Out[8]: 1

In [9]: arr[0, 1, 2]
Out[9]: 6

In [10]: arr_swap[2, 1, 0]
Out[10]: 6
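A short loop over every index confirms that the pattern holds for all elements, not only the few checked above:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))
arr_swap = arr.swapaxes(0, 2)

# np.ndindex yields every (i, j, k) index tuple for arr's shape
for i, j, k in np.ndindex(arr.shape):
    assert arr_swap[k, j, i] == arr[i, j, k]
print("all", arr.size, "elements match")  # all 16 elements match
```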

So, if you find the swapped-axis array hard to picture, just swap the indices instead: arr_swap[2, 1, 0] == arr[0, 1, 2], and in general arr_swap[k, j, i] == arr[i, j, k].
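To connect this back to the question, here is a hypothetical helper, swap_0_2, that builds the swapped array with explicit loops using the rule out[k, j, i] = x[i, j, k]. It only illustrates what the result contains; NumPy itself does not copy elements this way but returns a view with swapped shape and strides. Applied to the question's array, it reproduces b:

```python
import numpy as np

def swap_0_2(x):
    """Hypothetical illustration: copy a 3-D array so that out[k, j, i] = x[i, j, k]."""
    n0, n1, n2 = x.shape
    out = np.empty((n2, n1, n0), dtype=x.dtype)  # axes 0 and 2 trade lengths
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                out[k, j, i] = x[i, j, k]
    return out

a = np.arange(18).reshape(9, 2)
r = a.reshape(2, 3, 3)
b = r.swapaxes(0, 2)

print(np.array_equal(swap_0_2(r), b))  # True
print(b[1, 2, 0], r[0, 2, 1])          # 7 7
```

In the question's terms: the first index of b selects a column of r's 3x3 blocks, the second selects a row, and the last selects which of the two blocks the value came from.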