Python NumPy: argmax to max in a multidimensional array

Rayyan Riaz · Feb 28, 2017 · Viewed 8.3k times

I have the following code:

import numpy as np
sample = np.random.random((10,10,3))
argmax_indices = np.argmax(sample, axis=2)

That is, I take the argmax along axis=2, which gives me a (10,10) matrix of indices. Now I want to set the elements at those indices to 0. For this, I need to index the sample array. I tried:

max_values = sample[argmax_indices]

but it doesn't work. I want something like

max_values = sample[argmax_indices]
sample[argmax_indices] = 0

To validate, I check that max_values - np.max(sample, axis=2) gives a zero matrix of shape (10,10). Any help will be appreciated.

Answer

Divakar · Feb 28, 2017

Here's one approach, using np.ogrid together with advanced indexing:

m,n = sample.shape[:2]
I,J = np.ogrid[:m,:n]
max_values = sample[I,J, argmax_indices]
sample[I,J, argmax_indices] = 0
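
To see why this works where sample[argmax_indices] does not, it can help to look at the shapes involved. The following is a minimal sketch, not part of the original answer; the seeded generator and the names wrong_shape and all_match are only illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen only to make the sketch reproducible
sample = rng.random((10, 10, 3))
argmax_indices = np.argmax(sample, axis=2)  # shape (10, 10), values in {0, 1, 2}

# A single integer index array is applied to axis 0 only, so each entry of
# argmax_indices selects a whole (10, 3) slab instead of one element.
wrong_shape = sample[argmax_indices].shape  # (10, 10, 10, 3)

m, n = sample.shape[:2]
I, J = np.ogrid[:m, :n]  # I has shape (m, 1), J has shape (1, n)

# The three index arrays broadcast to (m, n), so position [i, j] of the result
# is sample[i, j, argmax_indices[i, j]].
max_values = sample[I, J, argmax_indices]
all_match = np.array_equal(max_values, sample.max(axis=2))
```

The open grids from np.ogrid keep I and J small ((m, 1) and (1, n)) and leave it to broadcasting to pair every row index with every column index.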

Sample step-by-step run

1) Sample input array:

In [261]: a = np.random.randint(0,9,(2,2,3))

In [262]: a
Out[262]: 
array([[[8, 4, 6],
        [7, 6, 2]],

       [[1, 8, 1],
        [4, 6, 4]]])

2) Get the argmax indices along axis=2:

In [263]: idx = a.argmax(axis=2)

3) Get the shape and the open-grid arrays for indexing into the first two dimensions:

In [264]: m,n = a.shape[:2]

In [265]: I,J = np.ogrid[:m,:n]

4) Use advanced indexing with I, J and idx to extract the max values:

In [267]: max_values = a[I,J,idx]

In [268]: max_values
Out[268]: 
array([[8, 7],
       [8, 6]])

5) Verify that subtracting np.max(a, axis=2) from max_values gives an all-zeros array:

In [306]: max_values - np.max(a, axis=2)
Out[306]: 
array([[0, 0],
       [0, 0]])

6) Use advanced indexing again to set those positions to zero, then verify visually:

In [269]: a[I,J,idx] = 0

In [270]: a
Out[270]: 
array([[[0, 4, 6], # <=== Compare this against the original version
        [0, 6, 2]],

       [[1, 0, 1],
        [4, 0, 4]]])
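
As an alternative to the ogrid approach above (not the method used in this answer), NumPy 1.15 and later provide np.take_along_axis and np.put_along_axis, which do the same gather and scatter along one axis. A minimal sketch, assuming such a NumPy version is available and reusing the sample array from the run above:

```python
import numpy as np

# Same input as in the step-by-step run.
a = np.array([[[8, 4, 6],
               [7, 6, 2]],
              [[1, 8, 1],
               [4, 6, 4]]])

# The *_along_axis functions need an index array with the same number of
# dimensions as a, so keep a length-1 axis where the argmax was taken.
idx = np.expand_dims(a.argmax(axis=2), axis=2)  # shape (2, 2, 1)

# Gather the max values, then drop the length-1 axis to get shape (2, 2).
max_values = np.take_along_axis(a, idx, axis=2).squeeze(axis=2)

# Write 0 into the argmax positions; this modifies a in place.
np.put_along_axis(a, idx, 0, axis=2)
```

After this runs, max_values and a should match the outputs of steps 4 and 6. When an element appears more than once as the maximum along axis 2, argmax returns only the first occurrence, so only that position is zeroed in either approach.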