How to efficiently retrieve the indices of maximum values in a Torch tensor?

Chris · Nov 8, 2018

Assume I have a torch tensor, for example of the following shape:

import torch

x = torch.rand(20, 1, 120, 120)

What I would like now is to get the indices of the maximum value of each 120x120 matrix. To simplify the problem, I would first call x.squeeze() to work with shape [20, 120, 120]. I would then like to get a torch tensor of indices with shape [20, 2], i.e. one (row, column) pair per matrix.

How can I do this fast?
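To make the desired [20, 2] output concrete, here is a straightforward loop-based reference. It is a sketch, not a fast solution: `argmax_2d_loop` is a hypothetical helper name, and it relies on `torch.argmax` with no `dim` returning an index into the flattened matrix, which is then split into (row, column).

```python
import torch

def argmax_2d_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: for each matrix x[i] of shape [H, W],
    return the (row, col) of its maximum. Output shape: [N, 2]."""
    n, h, w = x.shape
    out = torch.empty(n, 2, dtype=torch.long)
    for i in range(n):
        # With no dim, argmax indexes into the flattened H*W matrix.
        flat_idx = int(torch.argmax(x[i]))
        out[i, 0] = flat_idx // w  # row
        out[i, 1] = flat_idx % w   # column
    return out

# squeeze(1) drops only the singleton channel dim: [20, 1, 120, 120] -> [20, 120, 120]
x = torch.rand(20, 1, 120, 120).squeeze(1)
idx = argmax_2d_loop(x)
print(idx.shape)  # torch.Size([20, 2])
```

The Python loop over the batch is what makes this slow; a vectorized version handles all 20 matrices in one call.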

Answer

tejasvi88 · Apr 29, 2020

torch.topk() is what you are looking for. From the docs,

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

  • If dim is not given, the last dimension of the input is chosen.

  • If largest is False then the k smallest elements are returned.

  • A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

  • The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
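Since topk works along a single dimension, getting one (row, column) pair per 120x120 matrix takes an extra step: flatten each matrix, take the top k=1 along the flattened dimension, and convert the flat index back to 2D. The following is a minimal sketch of that, assuming PyTorch 1.8 or later (for the `rounding_mode` argument of `torch.div`). For k=1, `torch.argmax(flat, dim=1)` would give the same indices.

```python
import torch

x = torch.rand(20, 1, 120, 120).squeeze(1)  # shape [20, 120, 120]
n, h, w = x.shape

# topk works along one dimension, so flatten each 120x120 matrix
# into a row of 120*120 = 14400 elements: shape [20, 14400].
flat = x.reshape(n, h * w)

# k=1 keeps only the maximum of each row; dim=1 is also the last dim (the default).
values, flat_idx = torch.topk(flat, k=1, dim=1)  # both shape [20, 1]
flat_idx = flat_idx.squeeze(1)                   # shape [20]

# Convert flat indices back to (row, col) within each 120x120 matrix.
rows = torch.div(flat_idx, w, rounding_mode="floor")
cols = flat_idx % w
indices = torch.stack((rows, cols), dim=1)       # shape [20, 2]
print(indices.shape)  # torch.Size([20, 2])
```

With k > 1 the same reshape-then-convert pattern yields the positions of the k largest entries of each matrix, giving `indices` of shape [20, k, 2] if you stack along the last dimension instead.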