Filling gaps in a numpy array

Jose · Apr 5, 2011

I just want to interpolate, in the simplest possible terms, a 3D dataset. Linear interpolation, nearest neighbour, all that would suffice (this is to start off some algorithm, so no accurate estimate is required).

In newer scipy versions, things like griddata would be useful, but currently I only have scipy 0.8. So I have a "cube" array (data[:,:,:], of shape Ni x Nj x Nk) and an array of flags (flags[:,:,:], True or False) of the same shape. I want to interpolate my data for the elements of data where the corresponding element of flags is False, using e.g. the nearest valid datapoint in data, or some linear combination of "close by" points.

There can be large gaps in the dataset in at least two dimensions. Other than coding a full-blown nearest neighbour algorithm using kdtrees or similar, I can't really find a generic, N-dimensional nearest-neighbour interpolator.

Answer

Juh_ · Feb 13, 2012

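Using scipy.ndimage, your problem can be solved with nearest-neighbour interpolation in two lines (invalid_cell_mask is True where the data must be replaced, i.e. the inverse of your flags array):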

from scipy import ndimage as nd

ind = nd.distance_transform_edt(invalid_cell_mask, return_distances=False, return_indices=True)
data = data[tuple(ind)]
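To see what these two lines do, here is a minimal sketch on a tiny 2D array. The array values and the names `invalid` and `filled` are made up for illustration. With return_indices=True, distance_transform_edt returns, for every cell, the index of the nearest cell where the mask is False, so indexing data with it copies the nearest valid value into each gap.

```python
import numpy as np
from scipy import ndimage as nd

# Toy data (illustrative values): NaN marks the gaps to fill.
data = np.array([[1.0, np.nan, np.nan],
                 [np.nan, np.nan, 9.0]])
invalid = np.isnan(data)  # True where a value must be replaced

# ind has shape (data.ndim,) + data.shape: one index array per axis,
# pointing at the nearest cell where `invalid` is False.
ind = nd.distance_transform_edt(invalid,
                                return_distances=False,
                                return_indices=True)

# Fancy indexing with the per-axis index arrays picks the nearest valid value.
filled = data[tuple(ind)]
print(filled)
```

Valid cells point at themselves, so their values are left unchanged.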

Now, in the form of a function:

import numpy as np
from scipy import ndimage as nd

def fill(data, invalid=None):
    """
    Replace the value of invalid 'data' cells (indicated by 'invalid')
    with the value of the nearest valid data cell.

    Input:
        data:    numpy array of any dimension
        invalid: a boolean array of the same shape as 'data'.
                 Data values are replaced where invalid is True.
                 If None (default), use: invalid = np.isnan(data)

    Output: 
        Return a filled array. 
    """    
    if invalid is None: invalid = np.isnan(data)

    ind = nd.distance_transform_edt(invalid, 
                                    return_distances=False, 
                                    return_indices=True)
    return data[tuple(ind)]
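For the layout described in the question, where flags is True for valid cells, the mask just needs to be inverted before calling fill. The following is a sketch under assumptions: the array shape, the random test data and the 30% gap fraction are arbitrary choices for illustration, and fill is repeated so the snippet runs on its own.

```python
import numpy as np
from scipy import ndimage as nd

def fill(data, invalid=None):
    """Same function as above, repeated so this snippet is self-contained."""
    if invalid is None:
        invalid = np.isnan(data)
    ind = nd.distance_transform_edt(invalid,
                                    return_distances=False,
                                    return_indices=True)
    return data[tuple(ind)]

# Hypothetical 3D cube and flags array (True = valid), as in the question.
rng = np.random.RandomState(0)
data = rng.rand(6, 7, 8)
flags = rng.rand(6, 7, 8) > 0.3   # roughly 30% of the cells are gaps
flags[0, 0, 0] = True             # make sure at least one valid cell exists

# fill() expects True where values must be replaced, so invert the flags.
filled = fill(data, invalid=~flags)
```

Valid cells keep their original values, and every gap receives a value copied from some valid cell, which is enough for initialising an algorithm as the question asks.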

Example of use:

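def test_fill(s, d):
    # s is the size of each dimension, d is the number of dimensions
    data = np.arange(s**d).reshape((s,)*d)
    seed = np.zeros(data.shape, dtype=bool)
    # mark a few random cells as valid "seeds"
    seed.flat[np.random.randint(0, seed.size, data.size // 20**d)] = True

    # every non-seed cell is invalid; use ~ because unary minus
    # is not supported on boolean arrays in current numpy
    return fill(data, ~seed), seed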

import matplotlib.pyplot as plt
data,seed  = test_fill(500,2)
data[nd.binary_dilation(seed,iterations=2)] = 0   # draw (dilated) seeds in black
plt.imshow(np.mod(data,42))                       # show cluster

Result: [image not included]