As I was going through the PyTorch documentation, I came across the term layout = torch.strided in many of the functions. Can anyone help me understand where it is used and how? The description says it's the desired layout of the returned Tensor. What does layout mean, and how many types of layout are there?
torch.rand(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
A stride is the number of steps (or jumps) needed to go from one element to the next element in a given dimension. In computer memory, the data is stored linearly in a contiguous block of memory. What we view is just a (re)presentation.
Let's take an example tensor for understanding this:
# a 2D tensor
In [62]: tensor = torch.arange(1, 16).reshape(3, 5)
In [63]: tensor
Out[63]:
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]])
With this tensor in place, the strides are:
# get the strides
In [64]: tensor.stride()
Out[64]: (5, 1)
What this resultant tuple (5, 1) says is:
to go from 1 to 6 (one position along dimension 0), we should take 5 steps (or jumps);
to go from 7 to 8 (one position along dimension 1), we should take 1 step (or jump).
The order (or index) of 5 & 1 in the tuple represents the dimension/axis. You can also pass the dimension for which you want the stride as an argument:
# get stride for axis 0
In [65]: tensor.stride(0)
Out[65]: 5
# get stride for axis 1
In [66]: tensor.stride(1)
Out[66]: 1
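To make the "steps in memory" idea concrete, here is a minimal pure-Python sketch of how a strided layout could map a multi-dimensional index to a position in flat storage. The names storage and element_at are hypothetical and only for illustration; this is not PyTorch's actual implementation, just the arithmetic that strides imply.

```python
# Flat storage, as the values would sit in memory (row-major order).
storage = list(range(1, 16))

# Same values that tensor.stride() reported for the 3x5 tensor above.
strides = (5, 1)


def element_at(storage, strides, index):
    """Look up an element by multiplying each index by its stride."""
    offset = sum(i * s for i, s in zip(index, strides))
    return storage[offset]


first = element_at(storage, strides, (0, 0))  # top-left element
below = element_at(storage, strides, (1, 0))  # one position along axis 0
right = element_at(storage, strides, (1, 2))  # row 1, column 2
print(first, below, right)  # prints: 1 6 8

# Swapping the strides reads the same storage as if it were transposed,
# without moving any data: index (2, 1) now refers to the old (1, 2).
transposed = element_at(storage, (1, 5), (2, 1))
print(transposed)  # prints: 8
```

The last two lines hint at why strides are useful: a different view of the same data only needs a different stride tuple, not a copy.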
With that understanding, we might ask why this extra parameter is needed when we create tensors. The answer is efficiency: how can we store, read, and access the elements of a (sparse) tensor most efficiently? torch.strided is the default layout and describes ordinary dense tensors; torch.sparse_coo is the layout used for sparse tensors in coordinate format.
In a sparse tensor (a tensor where most of the elements are zero), we don't want to store all those zeros, so we store only the non-zero values and their indices. Given the desired shape, the remaining positions are treated as zero, yielding the desired sparse tensor.
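As a rough illustration of storing only the non-zero values and their indices, the sketch below builds a small sparse tensor with torch.sparse_coo_tensor and converts it back to a dense one. The particular values and positions are made up for the example.

```python
import torch

# Two non-zero entries of a 3x5 matrix (illustrative values):
# 7 at position (0, 2) and 9 at position (2, 4).
indices = torch.tensor([[0, 2],   # row index of each non-zero entry
                        [2, 4]])  # column index of each non-zero entry
values = torch.tensor([7, 9])

# Only the indices, the values, and the shape are stored.
sparse = torch.sparse_coo_tensor(indices, values, size=(3, 5))
print(sparse.layout)  # torch.sparse_coo

# Converting to dense fills every other position with zero.
dense = sparse.to_dense()
print(dense.layout)   # torch.strided
print(dense)
```

Checking the layout attribute on each tensor shows the two layouts side by side: the sparse tensor reports torch.sparse_coo, and its dense counterpart reports torch.strided.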
P.S.: I guess there's a typo in the torch.layout documentation, which says "Strides are a list of integers ...". The composite data type returned by tensor.stride() is a tuple, not a list.