Indexing a multi-dimensional tensor with a tensor in PyTorch

Chum-Chum Scarecrows · Aug 30, 2018

I have the following code:

import torch
a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

b holds a multi-dimensional index, and I want to use it to select a single cell of a. If the index weren't stored in a tensor, I could write:

a[1,1,1,1]

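That returns the correct cell, but

a[b]

doesn't work, because an integer tensor used as an index applies only to the first dimension: it selects a[1] four times and returns a tensor of shape [4, 3, 3, 3] instead of a single value.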
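A minimal sketch that reproduces the problem and prints the shapes involved. It assumes a recent PyTorch install; torch.tensor([1, 1, 1, 1]) is used in place of torch.LongTensor, which produces the same int64 index. The tensor values are random, so only shapes and equalities are shown.

```python
import torch

# Same setup as the question.
a = torch.randint(0, 10, [3, 3, 3, 3])
b = torch.tensor([1, 1, 1, 1])  # int64, like torch.LongTensor([1, 1, 1, 1])

# An integer tensor index is applied to the FIRST dimension only:
# each of the four entries of b picks a[1], and the results are stacked.
rows = a[b]
print(rows.shape)  # torch.Size([4, 3, 3, 3])

# Every slice of the result is the same sub-tensor a[1].
print(all(torch.equal(r, a[1]) for r in rows))  # True

# Four separate integers, one per dimension, select a single cell instead.
cell = a[1, 1, 1, 1]
print(cell.shape)  # torch.Size([])
```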

How can I do this? Thanks

Answer

dennlinger · Aug 30, 2018

A more elegant (and simpler) solution is to convert b to a tuple, so that each of its elements indexes one dimension:

a[tuple(b)]
Out[10]: tensor(5.)
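A short sketch of why the tuple works, assuming a recent PyTorch version: iterating over the 1-D tensor b yields four 0-d tensors, and a tuple of indices is interpreted as one index per dimension, exactly like a[1, 1, 1, 1]. The batched lookup at the end is an extension not taken from the answer, and the index values in idx are arbitrary examples.

```python
import torch

a = torch.randint(0, 10, [3, 3, 3, 3])
b = torch.tensor([1, 1, 1, 1])

# tuple(b) == (tensor(1), tensor(1), tensor(1), tensor(1)):
# one index per dimension, same as a[1, 1, 1, 1].
cell = a[tuple(b)]
print(torch.equal(cell, a[1, 1, 1, 1]))  # True
print(cell.shape)  # torch.Size([]): a single element

# .item() converts the 0-d tensor to a plain Python number if needed.
value = cell.item()

# Extension: to look up several cells at once, store one multi-index per row
# of an (N, 4) tensor and transpose it, so tuple(...) yields one index
# tensor per dimension (each of length N).
idx = torch.tensor([[1, 1, 1, 1],
                    [0, 2, 1, 0]])
cells = a[tuple(idx.t())]
print(cells.shape)  # torch.Size([2])
```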

I was curious how this works with "regular" NumPy, and found a related article that explains it quite well here.
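For comparison, a minimal NumPy sketch of the same behaviour (it does not reproduce the linked article, and assumes NumPy is installed): an integer array index applies to the first axis only, while a tuple supplies one index per axis. np.ravel_multi_index is shown as an alternative that converts the multi-index into a flat offset.

```python
import numpy as np

arr = np.random.randint(0, 10, size=(3, 3, 3, 3))
idx = np.array([1, 1, 1, 1])

# Integer-array (advanced) indexing uses only the first axis.
print(arr[idx].shape)  # (4, 3, 3, 3)

# A tuple gives one index per axis and selects a single element.
print(arr[tuple(idx)] == arr[1, 1, 1, 1])  # True

# Alternative: turn the multi-index into a flat (C-order) offset.
# For shape (3, 3, 3, 3): 1*27 + 1*9 + 1*3 + 1 = 40.
flat = np.ravel_multi_index(tuple(idx), arr.shape)
print(flat)  # 40
print(arr.ravel()[flat] == arr[1, 1, 1, 1])  # True
```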