After learning how to use einsum, I am now trying to understand how np.tensordot works.
However, I am a little lost, especially regarding the various possibilities for the axes parameter.
To understand it, as I have never practiced tensor calculus, I use the following example:
A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))
In this case, what are the different possible np.tensordot calls, and how would you compute each of them manually?
The idea with tensordot is pretty simple: we pass in the arrays and the respective axes along which the sum-reductions are intended. Each axis listed for the first array is paired with the axis in the same position for the second array, so paired axes must have the same length. The axes that take part in the sum-reduction are removed from the output, and all of the remaining axes from the input arrays are laid out as separate axes in the output, keeping the order in which the input arrays are fed.
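As a rough sketch of that shape rule, the snippet below predicts the output shape by dropping the reduced axes and concatenating what is left. The helper `tensordot_shape` is a hypothetical name introduced here for illustration; it is not part of NumPy.

```python
import numpy as np

def tensordot_shape(a_shape, b_shape, a_axes, b_axes):
    """Hypothetical helper: predict the output shape of np.tensordot."""
    # Keep every axis that does not take part in the sum-reduction.
    a_keep = [n for i, n in enumerate(a_shape) if i not in a_axes]
    b_keep = [n for i, n in enumerate(b_shape) if i not in b_axes]
    # Remaining axes of the first input come first, then those of the second.
    return tuple(a_keep + b_keep)

A = np.random.randint(2, size=(2, 6, 5))
B = np.random.randint(2, size=(3, 2, 4))

predicted = tensordot_shape(A.shape, B.shape, a_axes=(0,), b_axes=(1,))
actual = np.tensordot(A, B, axes=((0,), (1,))).shape
print(predicted, actual)
```

Both values should agree, since the helper only mirrors the rule stated above; it does not check that the paired axes have matching lengths.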
Let's look at a few sample cases with one and two axes of sum-reduction, and also swap the order of the inputs to see how that order is kept in the output.
Inputs (one axis of sum-reduction):
In [7]: A = np.random.randint(2, size=(2, 6, 5))
...: B = np.random.randint(2, size=(3, 2, 4))
...:
Case #1:
In [9]: np.tensordot(A, B, axes=((0),(1))).shape
Out[9]: (6, 5, 3, 4)
A : (2, 6, 5) -> reduction of axis=0
B : (3, 2, 4) -> reduction of axis=1
Output : `(2, 6, 5)`, `(3, 2, 4)` ===(2 gone)==> `(6,5)` + `(3,4)` => `(6,5,3,4)`
Case #2 (same as case #1 but the inputs are fed swapped):
In [8]: np.tensordot(B, A, axes=((1),(0))).shape
Out[8]: (3, 4, 6, 5)
B : (3, 2, 4) -> reduction of axis=1
A : (2, 6, 5) -> reduction of axis=0
Output : `(3, 2, 4)`, `(2, 6, 5)` ===(2 gone)==> `(3,4)` + `(6,5)` => `(3,4,6,5)`.
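To compute the one-axis case by hand, one possible approach is to loop over the shared axis and accumulate outer products. The sketch below also shows the equivalent einsum call; the index letters and variable names are illustrative choices.

```python
import numpy as np

A = np.random.randint(2, size=(2, 6, 5))
B = np.random.randint(2, size=(3, 2, 4))

out = np.tensordot(A, B, axes=((0,), (1,)))

# einsum equivalent: the label k is shared by axis 0 of A and axis 1 of B,
# and it is absent from the output, so it is summed over.
out_einsum = np.einsum('kij,lkm->ijlm', A, B)

# Manual version: for each k, take the outer product of A[k] (shape (6, 5))
# with B[:, k, :] (shape (3, 4)) and add it to the running total.
manual = np.zeros((6, 5, 3, 4), dtype=out.dtype)
for k in range(2):
    manual += A[k][:, :, None, None] * B[:, k, :][None, None, :, :]

# Swapping the inputs only reorders the output axes.
swapped = np.tensordot(B, A, axes=((1,), (0,)))

print(np.array_equal(out, out_einsum))
print(np.array_equal(out, manual))
print(np.array_equal(swapped, out.transpose(2, 3, 0, 1)))
```

The last check illustrates that case #2 holds the same numbers as case #1, with the axes of B placed first.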
Inputs (two axes of sum-reduction):
In [11]: A = np.random.randint(2, size=(2, 3, 5))
...: B = np.random.randint(2, size=(3, 2, 4))
...:
Case #1:
In [12]: np.tensordot(A, B, axes=((0,1),(1,0))).shape
Out[12]: (5, 4)
A : (2, 3, 5) -> reduction of axis=(0,1)
B : (3, 2, 4) -> reduction of axis=(1,0)
Output : `(2, 3, 5)`, `(3, 2, 4)` ===(2,3 gone)==> `(5)` + `(4)` => `(5,4)`
Case #2:
In [14]: np.tensordot(B, A, axes=((1,0),(0,1))).shape
Out[14]: (4, 5)
B : (3, 2, 4) -> reduction of axis=(1,0)
A : (2, 3, 5) -> reduction of axis=(0,1)
Output : `(3, 2, 4)`, `(2, 3, 5)` ===(2,3 gone)==> `(4)` + `(5)` => `(4,5)`
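The two-axis case can be checked the same way. The sketch below compares tensordot against einsum, explicit loops, and a reshape followed by a matrix product; the last one is only an illustration of the idea and is not claimed to be how NumPy implements it internally.

```python
import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Axis 0 of A (length 2) pairs with axis 1 of B,
# and axis 1 of A (length 3) pairs with axis 0 of B.
out = np.tensordot(A, B, axes=((0, 1), (1, 0)))

# einsum equivalent: a and b are both summed over.
out_einsum = np.einsum('abi,baj->ij', A, B)

# Manual version: out[i, j] = sum over a, b of A[a, b, i] * B[b, a, j].
manual = np.zeros((5, 4), dtype=out.dtype)
for a in range(2):
    for b in range(3):
        manual += np.outer(A[a, b, :], B[b, a, :])

# Same contraction as a matrix product: flatten the two reduced axes
# into one axis of length 6, in the same (a, b) order on both sides.
A2 = A.reshape(6, 5)                      # rows indexed by (a, b)
B2 = B.transpose(1, 0, 2).reshape(6, 4)   # rows indexed by (a, b)
as_matmul = A2.T @ B2

print(out.shape)
print(np.array_equal(out, out_einsum))
print(np.array_equal(out, manual))
print(np.array_equal(out, as_matmul))
```

Note that the order inside each axes tuple matters: `axes=((0, 1), (0, 1))` would try to pair lengths 2 with 3 and 3 with 2 for these shapes, which raises a ValueError.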
We can extend this to as many axes as needed, as long as each pair of reduced axes has the same length.
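As a small illustration of that extension, the sketch below reduces over all three axes, which leaves no axes in the output, and also shows the integer form of `axes`, which sums over the last N axes of the first array and the first N axes of the second. The arrays `C` and `D` are extra examples introduced here, not part of the question.

```python
import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
C = np.random.randint(2, size=(3, 2, 5))   # illustrative extra array
D = np.random.randint(2, size=(5, 4))      # illustrative extra array

# Three paired axes: nothing is left over, so the result is 0-dimensional.
full = np.tensordot(A, C, axes=((0, 1, 2), (1, 0, 2)))
by_hand = np.sum(A * C.transpose(1, 0, 2))

# Integer form: axes=1 pairs the last axis of A with the first axis of D.
last_first = np.tensordot(A, D, axes=1)

print(full.shape, full == by_hand)
print(last_first.shape)
```

With `axes=1` and a 2-D second array, the result matches `A @ D` for these shapes.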