How to get the index of a specific value in a PyTorch tensor

Han Bing · Dec 18, 2017 · Viewed 49.5k times

With a Python list, we can use list.index(somevalue). How can this be done in PyTorch?
For example:

    a=[1,2,3]
    print(a.index(2))

This prints 1. How can I do the same with a PyTorch tensor without converting it to a Python list?

Answer

Manuel Lagunas · Dec 18, 2017

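I don't think there is a direct PyTorch equivalent of list.index(). However, you can get a similar result in two steps: compare the tensor with the value (tensor == number), which yields a boolean mask, and then call nonzero() on that mask to get the indices of the matching elements. For example: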

    import torch

    t = torch.Tensor([1, 2, 3])
    # t == 2 is a boolean mask; nonzero() returns the indices where it is True
    print((t == 2).nonzero())

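This prints a LongTensor with one row per match, here holding the single index 1:

    1
    [torch.LongTensor of size 1x1]

Current PyTorch releases display the same result as tensor([[1]]).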
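If you want something that behaves more like list.index() (a plain int for the first occurrence, and a ValueError when the value is missing), the comparison-plus-nonzero() approach can be wrapped in a small function. This is a minimal sketch, assuming a 1-D tensor and a PyTorch version that supports nonzero(as_tuple=True); the name tensor_index is hypothetical, not part of PyTorch.

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper that mimics list.index() for a 1-D tensor.

    Returns the position of the first element equal to `value` and
    raises ValueError if there is none, as list.index() does.
    """
    if t.dim() != 1:
        raise ValueError("tensor_index expects a 1-D tensor")
    # Boolean mask of matches; as_tuple=True gives one index tensor per dim.
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        raise ValueError(f"{value} is not in tensor")
    # nonzero() returns indices in ascending order, so the first entry
    # is the first occurrence.
    return int(matches[0])


a = torch.tensor([1, 2, 3, 2])
print(tensor_index(a, 2))  # → 1

try:
    tensor_index(a, 5)
except ValueError as err:
    print(err)  # → 5 is not in tensor
```

If you need every occurrence rather than the first, keep the whole matches tensor instead of taking matches[0].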