I just want to interpolate a 3D dataset in the simplest possible way. Linear interpolation, nearest neighbour, anything like that would suffice (this is only to initialise an algorithm, so no accurate estimate is required).
In newer SciPy versions, something like `griddata` would be useful, but currently I only have SciPy 0.8. So I have a "cube" `data[:,:,:]` of shape (Ni, Nj, Nk), and an array of flags `flags[:,:,:]` (`True` or `False`) of the same size. I want to interpolate my data at the elements of `data` where the corresponding element of `flags` is `False`, using e.g. the nearest valid data point in `data`, or some linear combination of "close by" points.
There can be large gaps in the dataset in at least two dimensions. Short of coding a full-blown nearest-neighbour algorithm using k-d trees or similar, I can't find a generic, N-dimensional nearest-neighbour interpolator.
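For comparison, the k-d tree route is not much code if `scipy.spatial.cKDTree` is available in the installed SciPy (an assumption for 0.8). A minimal sketch, assuming `flags` is `True` on valid cells as in the question; `fill_kdtree` is a hypothetical name. Distances are Euclidean in index space, and ties between equally distant valid cells are broken arbitrarily.

```python
import numpy as np
from scipy.spatial import cKDTree


def fill_kdtree(data, flags):
    """Hypothetical helper: replace cells where flags is False by the value
    of the nearest cell where flags is True (any number of dimensions)."""
    valid_pts = np.argwhere(flags)     # (n_valid, ndim) index coordinates
    invalid_pts = np.argwhere(~flags)  # (n_invalid, ndim) cells to fill
    filled = data.copy()
    if len(invalid_pts) == 0:
        return filled
    tree = cKDTree(valid_pts)
    # For each invalid cell, 'nearest' is a row index into valid_pts
    _, nearest = tree.query(invalid_pts)
    # data[flags] and np.argwhere(flags) use the same (C) ordering
    filled[~flags] = data[flags][nearest]
    return filled
```

The tree is built only on the valid cells, so each query costs roughly logarithmic time in their number, but building and querying still involves Python-level arrays of all coordinates, which is heavier than a single `ndimage` call.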
Using `scipy.ndimage`, your problem can be solved with nearest-neighbor interpolation in two lines:
```python
from scipy import ndimage as nd

# invalid_cell_mask: boolean array, True where data must be replaced
ind = nd.distance_transform_edt(invalid_cell_mask, return_distances=False, return_indices=True)
data = data[tuple(ind)]
```
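To see what `return_indices=True` produces, here is a small 1-D sketch (the array values are made up for illustration). `ind` holds one index array per dimension; for every element it gives the position of the nearest element where the mask is `False`, and valid elements point to themselves.

```python
import numpy as np
from scipy import ndimage as nd

data = np.array([1.0, np.nan, np.nan, 4.0, np.nan])
invalid = np.isnan(data)

# Shape (data.ndim,) + data.shape, i.e. (1, 5) here
ind = nd.distance_transform_edt(invalid, return_distances=False,
                                return_indices=True)
# ind[0] is [0, 0, 3, 3, 3]: element 2 is closer to index 3 than to index 0
filled = data[tuple(ind)]
# filled is [1.0, 1.0, 4.0, 4.0, 4.0]
```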
Now, in the form of a function:
```python
import numpy as np
from scipy import ndimage as nd

def fill(data, invalid=None):
    """
    Replace the value of invalid 'data' cells (indicated by 'invalid')
    by the value of the nearest valid data cell.

    Input:
        data:    numpy array of any dimension
        invalid: a boolean array of the same shape as 'data'.
                 Data values are replaced where invalid is True.
                 If None (default), use: invalid = np.isnan(data)

    Output:
        Return a filled array.
    """
    if invalid is None:
        invalid = np.isnan(data)

    # For each invalid cell, get the indices of the nearest valid cell
    # (valid cells map to themselves)
    ind = nd.distance_transform_edt(invalid,
                                    return_distances=False,
                                    return_indices=True)
    return data[tuple(ind)]
```
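Since there can be large gaps, it may be useful to know how far each filled value was copied from. `distance_transform_edt` can return the distances and the indices in one call, and its `sampling` argument sets the grid spacing along each axis if the cube is not isotropic. A sketch, assuming `flags` is `True` on valid cells as in the question (so the mask passed in is `~flags`); `fill_with_distance` is a hypothetical name.

```python
import numpy as np
from scipy import ndimage as nd


def fill_with_distance(data, flags, sampling=None):
    """Hypothetical variant of fill(): 'flags' is True on valid cells.

    Returns the filled array and, for every cell, the distance to the valid
    cell whose value it received (0 for cells that were already valid).
    'sampling' is the grid spacing per axis, e.g. (dz, dy, dx).
    """
    invalid = ~flags
    dist, ind = nd.distance_transform_edt(invalid, sampling=sampling,
                                          return_distances=True,
                                          return_indices=True)
    return data[tuple(ind)], dist
```

The distance array could, for example, be used to flag filled values that came from too far away to be a sensible starting estimate.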
Example of use:
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as nd

def test_fill(s, d):
    # s is the size of one dimension, d is the number of dimensions
    data = np.arange(s**d).reshape((s,)*d)
    # A few random cells are valid "seeds"; every other cell is invalid
    seed = np.zeros(data.shape, dtype=bool)
    seed.flat[np.random.randint(0, seed.size, int(data.size/20**d))] = True
    # Use ~ (not unary minus) to negate a boolean array
    return fill(data, ~seed), seed

data, seed = test_fill(500, 2)
data[nd.binary_dilation(seed, iterations=2)] = 0   # draw (dilated) seeds in black
plt.imshow(np.mod(data, 42))                       # show the clusters
plt.show()
```
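If a linear rather than nearest-neighbour fill is wanted and a newer SciPy is available (`scipy.interpolate.griddata` is not in 0.8, as the question notes), a sketch along these lines could work. It assumes `flags` is `True` on valid cells; `fill_griddata` is a hypothetical name. `griddata` triangulates the valid points, so for a large cube it can be much slower and more memory-hungry than the `ndimage` approach.

```python
import numpy as np
from scipy.interpolate import griddata


def fill_griddata(data, flags, method="linear"):
    """Hypothetical helper: interpolate cells where flags is False from the
    cells where flags is True, using scipy.interpolate.griddata."""
    filled = data.astype(float)            # astype returns a new array
    if flags.all():
        return filled
    valid_pts = np.argwhere(flags)         # (n_valid, ndim) index coordinates
    invalid_pts = np.argwhere(~flags)      # (n_invalid, ndim)
    known = filled[flags]                  # same C order as np.argwhere(flags)
    values = griddata(valid_pts, known, invalid_pts, method=method)
    # 'linear' gives NaN outside the convex hull of the valid points;
    # fall back to the nearest valid value there.
    outside = np.isnan(values)
    if outside.any():
        values[outside] = griddata(valid_pts, known, invalid_pts[outside],
                                   method="nearest")
    filled[~flags] = values
    return filled
```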