Speedup scipy griddata for multiple interpolations between two irregular grids

s_haskey · Jan 4, 2014

I have several values that are defined on the same irregular grid (x, y, z) that I want to interpolate onto a new grid (x1, y1, z1). i.e., I have f(x, y, z), g(x, y, z), h(x, y, z) and I want to calculate f(x1, y1, z1), g(x1, y1, z1), h(x1, y1, z1).

At the moment I am doing this using scipy.interpolate.griddata and it works well. However, because I have to perform each interpolation separately and there are many points, it is quite slow, and a great deal of the calculation is duplicated (e.g., finding which points are closest, setting up the grids, etc.).

Is there a way to speed up the calculation and reduce the duplicated work? For example, something along the lines of defining the two grids once and then changing only the values being interpolated?

Answer

Jaime · Jan 5, 2014

There are several things going on every time you make a call to scipy.interpolate.griddata:

  1. First, a call to sp.spatial.qhull.Delaunay is made to triangulate the irregular grid coordinates.
  2. Then, for each point in the new grid, the triangulation is searched to find the triangle in which it lies (actually the simplex, which in your 3D case is a tetrahedron).
  3. The barycentric coordinates of each new grid point with respect to the vertices of the enclosing simplex are computed.
  4. An interpolated value is computed for that grid point, using the barycentric coordinates and the values of the function at the vertices of the enclosing simplex.
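
To make the four steps above concrete, here is a minimal sketch that walks through them by hand for a single query point. It is only an illustration under stated assumptions: a small 2D grid is used to keep it short, the names points, values and query are made up for the example, and the public scipy.spatial.Delaunay class is used for the triangulation.

```python
import numpy as np
from scipy.spatial import Delaunay
from scipy.interpolate import griddata

rng = np.random.default_rng(0)
points = rng.random((50, 2))   # irregular source grid (2D to keep it small)
values = rng.random(50)        # function samples on that grid
query = points.mean(axis=0)    # centroid of the points, so it lies inside the convex hull
d = points.shape[1]

# Step 1: triangulate the source points.
tri = Delaunay(points)

# Step 2: find the simplex (a triangle in 2D) that contains the query point.
s = tri.find_simplex(query[None, :])[0]

# Step 3: barycentric coordinates. The first d rows of tri.transform[s] map an
# offset from the simplex's last vertex (stored in row d) to the first d
# barycentric coordinates; the last coordinate makes them sum to one.
T = tri.transform[s]
b = T[:d] @ (query - T[d])
weights = np.append(b, 1.0 - b.sum())

# Step 4: weighted sum of the values at the vertices of the enclosing simplex.
result = weights @ values[tri.simplices[s]]

# griddata (linear by default) performs the same four steps internally.
reference = griddata(points, values, query[None, :])[0]
print(result, reference)
```

Only step 4 touches the function values; everything before it depends on the two grids alone.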

The first three steps are identical for all your interpolations, so if you could store, for each new grid point, the indices of the vertices of the enclosing simplex and the weights for the interpolation, you would greatly reduce the amount of computation. This is unfortunately not easy to do directly with the functionality available, although it is indeed possible:

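import itertools

import numpy as np
import scipy.interpolate as spint
from scipy.spatial import Delaunay

def interp_weights(xyz, uvw):
    d = xyz.shape[1]  # number of spatial dimensions
    # Steps 1 and 2: triangulate the old grid, then locate each new grid point
    tri = Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    # Step 3: barycentric coordinates from the per-simplex affine transform
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    # Step 4: weighted sum of the values at the enclosing simplex's vertices
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)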

The function interp_weights does the calculations for the first three steps I listed above. Then the function interpolate uses those calculated values to do step 4 very fast:

m, n, d = 35000, 3000, 3  # sizes must be integers
# make sure no new grid point is extrapolated
bounding_cube = np.array(list(itertools.product([0, 1], repeat=d)))
xyz = np.vstack((bounding_cube,
                 np.random.rand(m - len(bounding_cube), d)))
f = np.random.rand(m)
g = np.random.rand(m)
uvw = np.random.rand(n, d)

In [2]: vtx, wts = interp_weights(xyz, uvw)

In [3]: np.allclose(interpolate(f, vtx, wts), spint.griddata(xyz, f, uvw))
Out[3]: True

In [4]: %timeit spint.griddata(xyz, f, uvw)
1 loops, best of 3: 2.81 s per loop

In [5]: %timeit interp_weights(xyz, uvw)
1 loops, best of 3: 2.79 s per loop

In [6]: %timeit interpolate(f, vtx, wts)
10000 loops, best of 3: 66.4 us per loop

In [7]: %timeit interpolate(g, vtx, wts)
10000 loops, best of 3: 67 us per loop

So first, it gives the same result as griddata, which is good. Second, setting up the interpolation, i.e. computing vtx and wts, takes roughly as long as a call to griddata. But third, you can now interpolate different values on the same grid in virtually no time.
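
If several fields such as f, g and h are available at once, the same stored vertices and weights can be applied to all of them in a single call. The sketch below is an extension, not part of the functions above: interpolate_stacked is a hypothetical helper, the fields are assumed to be stacked as columns of one (m, k) array, and interp_weights is repeated only so that the block runs on its own. The result is cross-checked against scipy.interpolate.LinearNDInterpolator, which also accepts values with extra trailing dimensions.

```python
import numpy as np
from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

def interp_weights(xyz, uvw):
    # Repeated here so the block is self-contained.
    d = xyz.shape[1]
    tri = Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = tri.simplices[simplex]
    temp = tri.transform[simplex]
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate_stacked(values, vtx, wts):
    # Hypothetical helper: values has shape (m, k), one column per field.
    # values[vtx] has shape (n, d + 1, k); contract over the simplex vertices.
    return np.einsum('njk,nj->nk', values[vtx], wts)

rng = np.random.default_rng(0)
m, n, d = 2000, 300, 3
xyz = rng.random((m, d))
uvw = 0.25 + 0.5 * rng.random((n, d))  # targets kept well inside the unit cube
fgh = rng.random((m, 3))               # columns play the role of f, g and h

vtx, wts = interp_weights(xyz, uvw)
out = interpolate_stacked(fgh, vtx, wts)         # shape (n, 3)

# Library-only cross-check: one interpolator built on the stacked values.
reference = LinearNDInterpolator(xyz, fgh)(uvw)  # shape (n, 3)
print(np.allclose(out, reference))
```

When all fields are known up front, the LinearNDInterpolator call on its own may already be enough; the stored weights pay off mainly when new fields keep arriving for the same pair of grids.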

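The only thing griddata does that is not covered here is assigning fill_value to points that would have to be extrapolated. For such points find_simplex returns -1, so the weights end up being computed against an unrelated simplex and at least one of them comes out negative. You can therefore handle them by checking for points with at least one negative weight, e.g.: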

def interpolate(values, vtx, wts, fill_value=np.nan):
    ret = np.einsum('nj,nj->n', np.take(values, vtx), wts)
    ret[np.any(wts < 0, axis=1)] = fill_value
    return ret
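
A possible variant, sketched here as an assumption rather than as part of the answer's code, is to record which points fall outside the triangulation at set-up time, using the -1 that find_simplex reports for them, instead of inferring it later from the sign of the weights. The names interp_weights_masked and interpolate_masked are hypothetical.

```python
import numpy as np
from scipy.spatial import Delaunay
from scipy.interpolate import griddata

def interp_weights_masked(xyz, uvw):
    # Same set-up as interp_weights, plus a boolean mask of outside points.
    d = xyz.shape[1]
    tri = Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    outside = simplex < 0  # find_simplex returns -1 outside the convex hull
    vertices = tri.simplices[simplex]
    temp = tri.transform[simplex]
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))
    return vertices, weights, outside

def interpolate_masked(values, vtx, wts, outside, fill_value=np.nan):
    ret = np.einsum('nj,nj->n', values[vtx], wts)
    ret[outside] = fill_value  # overwrite the meaningless extrapolated values
    return ret

rng = np.random.default_rng(2)
xyz = rng.random((500, 3))
inside_pts = 0.3 + 0.4 * rng.random((20, 3))  # well inside the unit cube
outside_pts = 2.0 + rng.random((5, 3))        # clearly outside it
uvw = np.vstack((inside_pts, outside_pts))
f = rng.random(500)

vtx, wts, outside = interp_weights_masked(xyz, uvw)
result = interpolate_masked(f, vtx, wts, outside)
reference = griddata(xyz, f, uvw)  # fills outside points with nan by default
print(np.allclose(result, reference, equal_nan=True))
```

The mask costs one extra boolean per target point and does not depend on rounding in the weights for points that sit very close to a simplex face.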