I have the following code:
import torch
a = torch.randint(0, 10, [3, 3, 3, 3])
b = torch.LongTensor([1, 1, 1, 1])
I have a multi-dimensional index b and want to use it to select a single cell in a. If b weren't a tensor, I could do:
a[1,1,1,1]
Which returns the correct cell, but:
a[b]
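Doesn't work, because b is treated as a list of indices into the first dimension only, so it just selects a[1] four times (giving a tensor of shape [4, 3, 3, 3]).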
How can I do this? Thanks
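To reproduce the behavior described above, here is a minimal sketch. It assumes a recent PyTorch, where `torch.tensor` with Python ints gives the same int64 dtype as the legacy `torch.LongTensor` constructor. It checks the shape of `a[b]` and that every slice of the result equals `a[1]`:

```python
import torch

a = torch.randint(0, 10, (3, 3, 3, 3))
b = torch.tensor([1, 1, 1, 1])  # int64, same dtype as torch.LongTensor

# An integer tensor used as a single index is advanced indexing on dim 0:
# each entry of b picks the slice a[1], and the results are stacked.
wrong = a[b]
print(wrong.shape)                  # torch.Size([4, 3, 3, 3])
print(torch.equal(wrong[0], a[1]))  # True
```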
A more elegant (and simpler) solution is to convert b to a tuple, so that each element indexes one dimension, exactly as in a[1,1,1,1]:
a[tuple(b)]
Out[10]: tensor(5.)
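The sketch below checks that the tuple trick selects the same cell as `a[1,1,1,1]` and shows an equivalent using plain Python ints. The last part is an assumed extension, not part of the question: selecting several cells at once, where each row of the hypothetical `idx` tensor is one full index and `unbind(1)` splits it into one index tensor per dimension.

```python
import torch

a = torch.randint(0, 10, (3, 3, 3, 3))
b = torch.tensor([1, 1, 1, 1])

# tuple(b) -> (tensor(1), tensor(1), tensor(1), tensor(1)):
# one 0-d index per dimension, like the literal a[1, 1, 1, 1].
cell = a[tuple(b)]

# Same cell, but with Python ints instead of 0-d tensors.
cell_int = a[tuple(b.tolist())]

# Hypothetical batch of indices: one row per cell to select.
idx = torch.tensor([[1, 1, 1, 1],
                    [0, 2, 1, 0]])
# unbind(1) returns one tensor per column, i.e. one index tensor per dimension.
cells = a[idx.unbind(1)]  # shape (2,)
print(cell, cells)
```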
I was curious how this works with regular NumPy, and found a related article that explains it quite well here.
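For comparison, here is a small NumPy sketch, assuming NumPy's standard integer-array indexing rules. An integer array used as the index is fancy indexing on axis 0, while a tuple supplies one index per axis:

```python
import numpy as np

arr = np.random.randint(0, 10, size=(3, 3, 3, 3))
idx = np.array([1, 1, 1, 1])

# Integer array index: picks arr[1] four times along axis 0.
print(arr[idx].shape)  # (4, 3, 3, 3)

# Tuple index: one entry per axis, selecting a single element.
print(arr[tuple(idx)] == arr[1, 1, 1, 1])  # True
```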