Assume I have a torch tensor, for example of the following shape:
x = torch.rand(20, 1, 120, 120)
What I would like now is to get the indices of the maximum value of each 120x120 matrix. To simplify the problem, I would first call x.squeeze() to work with shape [20, 120, 120]. I would then like to get a torch tensor that is a list of indices with shape [20, 2].
How can I do this fast?
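For reference, a straightforward but slow version of what is being asked can be written as a Python loop over the matrices. It is a minimal sketch assuming a 4-D input shaped like x above; the function name argmax_2d_loop is hypothetical. It uses squeeze(1) rather than squeeze() so that a batch of size 1 is not squeezed away as well. A loop like this is mainly useful for checking a faster, vectorized answer.

```python
import torch

def argmax_2d_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: (row, col) of the max of each matrix in x.

    Assumes x has shape [N, 1, H, W]; returns a LongTensor of shape [N, 2].
    """
    x = x.squeeze(1)  # [N, 1, H, W] -> [N, H, W]; keeps N even when N == 1
    result = []
    for m in x:  # m has shape [H, W]
        # Without a dim argument, argmax returns an index into the flattened matrix
        flat_idx = int(torch.argmax(m))
        row, col = divmod(flat_idx, m.shape[1])  # convert flat index to (row, col)
        result.append([row, col])
    return torch.tensor(result, dtype=torch.long)  # shape [N, 2]
```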
torch.topk() is what you are looking for, applied to each matrix flattened into a single dimension. From the docs:
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

If dim is not given, the last dimension of the input is chosen.

If largest is False then the k smallest elements are returned.

A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
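Because topk works along a single dimension, one way to get the maximum of each whole 120x120 matrix is to flatten each matrix to length 120*120, call topk with k=1 along that dimension, and convert the returned flat index back to (row, col). The following is a minimal sketch under that assumption; the function name argmax_2d is hypothetical, and torch.div with rounding_mode assumes PyTorch 1.8 or newer.

```python
import torch

def argmax_2d(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (row, col) of the max of each matrix in x.

    Accepts shape [N, 1, H, W] or [N, H, W]; returns a LongTensor of shape [N, 2].
    """
    if x.dim() == 4:
        x = x.squeeze(1)  # [N, 1, H, W] -> [N, H, W]; squeeze(1) keeps N even when N == 1
    n, h, w = x.shape
    flat = x.reshape(n, h * w)  # each matrix becomes one row of length H*W
    # k=1 keeps only the largest element; dim=1 searches within each flattened matrix
    _, idx = torch.topk(flat, k=1, dim=1)  # idx: [N, 1] flat indices
    idx = idx.squeeze(1)  # [N]
    rows = torch.div(idx, w, rounding_mode="floor")  # flat index -> row
    cols = idx % w  # flat index -> column
    return torch.stack((rows, cols), dim=1)  # [N, 2]

x = torch.rand(20, 1, 120, 120)
print(argmax_2d(x).shape)  # torch.Size([20, 2])
```

Since only the single maximum is needed, flat.argmax(dim=1) would return the same flat indices without the extra k dimension (the two may differ only in which index is reported when there are exact ties). topk becomes the more useful choice when you want the positions of the k largest values per matrix.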