np.ndarray with Periodic Boundary conditions

Alexander McFarlane · Jun 28, 2016 · Viewed 7.8k times

Problem

I want to impose periodic boundary conditions on an np.ndarray, as laid out below.

Details

  • Wrap the indexing of a python np.ndarray around the boundaries in n-dimensions
  • This is a periodic boundary condition forming an n-dimensional torus
  • Wrapping only occurs in the case that the value returned is scalar (a single point).
  • Slices will be treated as normal and will not be wrapped

An example and counterexample are given below:

import numpy as np

a = np.arange(27).reshape(3,3,3)
b = Periodic_Lattice(a) # A theoretical class

# example: returning a scalar that shouldn't be accessible
print(b[3,3,3] == b[0,0,0]) # returns a scalar so invokes wrapping condition
try: a[3,3,3] # the value is out of bounds in the original np.ndarray
except IndexError: print('error')

# counter example: returning a slice
try: b[3,3] # this returns a slice and so shouldn't invoke the wrap
except IndexError: print('error')

which should give the output:

True
error
error

I anticipate that I should override __getitem__ and __setitem__ in a subclass of np.ndarray, but how to proceed is not entirely clear, and many of the implementations on SO fail for a number of test cases.

Answer

Alexander McFarlane · Jun 28, 2016

Wrap function

A simple wrapping function can be written with Python's modulo operator, %, and generalised to operate on an n-dimensional index tuple for a given lattice shape.

def latticeWrapIdx(index, lattice_shape):
    """returns periodic lattice index 
    for a given iterable index
    
    Required Inputs:
        index :: iterable :: one integer for each axis
        lattice_shape :: the shape of the lattice to index to
    """
    if not hasattr(index, '__iter__'): return index         # single integer: leave unwrapped
    if len(index) != len(lattice_shape): return index       # index count differs from axes: not a scalar lookup
    if any(isinstance(i, slice) for i in index): return index   # slices are not wrapped
    # one integer per axis: wrap each index; Python's % already returns
    # a value in [0, s) for positive s, even when i is negative
    return tuple(i % s for i, s in zip(index, lattice_shape))

This is tested as:

arr = np.array([[ 11.,  12.,  13.,  14.],
                [ 21.,  22.,  23.,  24.],
                [ 31.,  32.,  33.,  34.],
                [ 41.,  42.,  43.,  44.]])
test_vals = [[(1,1), 22.], [(3,3), 44.], [( 4, 4), 11.], # [index, expected value]
             [(3,4), 41.], [(4,3), 14.], [(10,10), 33.]]

passed = all([arr[latticeWrapIdx(idx, (4,4))] == act for idx, act in test_vals])
print("Iterating test values. Result: {}".format(passed))

which gives the output:

Iterating test values. Result: True
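The same arithmetic also covers negative indices, because Python's % returns a result with the sign of the divisor. The following is a minimal sketch of that behaviour; `wrap_index` is a hypothetical helper used only for illustration and is not part of the original answer.

```python
import numpy as np

def wrap_index(index, shape):
    """Hypothetical helper: wrap each integer of `index` into [0, size)."""
    return tuple(i % s for i, s in zip(index, shape))

arr = np.arange(16).reshape(4, 4)

# Indices past the upper boundary wrap back to the start
assert wrap_index((4, 4), arr.shape) == (0, 0)

# Negative indices wrap from the other side, with no extra "+ s" step
assert wrap_index((-1, -1), arr.shape) == (3, 3)
assert wrap_index((-5, 4), arr.shape) == (3, 0)
assert arr[wrap_index((-5, 4), arr.shape)] == arr[3, 0]
```

This is why the `(i%s + s)%s` form, which is needed in languages where the remainder can be negative, reduces to `i % s` in Python.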

Subclassing Numpy

The wrapping function can be incorporated into a subclassed np.ndarray, as described in the NumPy subclassing documentation (linked in the __array_finalize__ docstring below):

class Periodic_Lattice(np.ndarray):
    """Creates an n-dimensional ring that joins on boundaries w/ numpy
    
    Required Inputs
        array :: np.array :: n-dim numpy array to use wrap with
    
    Only currently supports single point selections wrapped around the boundary
    """
    def __new__(cls, input_array, lattice_spacing=None):
        """__new__ is called by numpy when an explicit constructor is used:
        obj = MySubClass(params); otherwise we must rely on __array_finalize__
        """
        # Input array is an already formed ndarray instance
        # We first cast to be our class type
        obj = np.asarray(input_array).view(cls)
        
        # add the new attribute to the created instance
        obj.lattice_shape = obj.shape   # use obj, so list inputs also work
        obj.lattice_dim = obj.ndim
        obj.lattice_spacing = lattice_spacing
        
        # Finally, we must return the newly created object:
        return obj
    
    def __getitem__(self, index):
        index = self.latticeWrapIdx(index)
        return super(Periodic_Lattice, self).__getitem__(index)
    
    def __setitem__(self, index, item):
        index = self.latticeWrapIdx(index)
        return super(Periodic_Lattice, self).__setitem__(index, item)
    
    def __array_finalize__(self, obj):
        """ ndarray.__new__ passes __array_finalize__ the new object, 
        of our own class (self) as well as the object from which the view has been taken (obj). 
        See
        http://docs.scipy.org/doc/numpy/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
        for more info
        """
        # ``self`` is a new object resulting from
        # ndarray.__new__(Periodic_Lattice, ...), therefore it only has
        # attributes that the ndarray.__new__ constructor gave it -
        # i.e. those of a standard ndarray.
        #
        # We could have got to the ndarray.__new__ call in 3 ways:
        # From an explicit constructor - e.g. Periodic_Lattice():
        #   1. obj is None
        #       (we're in the middle of the Periodic_Lattice.__new__
        #       constructor, and the lattice attributes will be set when
        #       we return to Periodic_Lattice.__new__)
        if obj is None: return
        #   2. From view casting - e.g arr.view(Periodic_Lattice):
        #       obj is arr
        #       (type(obj) can be Periodic_Lattice)
        #   3. From new-from-template - e.g lattice[:3]
        #       type(obj) is Periodic_Lattice
        # 
        # Note that it is here, rather than in the __new__ method,
        # that we set the default lattice attributes, because this
        # method sees all creation of default objects - with the
        # Periodic_Lattice.__new__ constructor, but also with
        # arr.view(Periodic_Lattice).
        #
        # These are in effect the default values from these operations
        self.lattice_shape = getattr(obj, 'lattice_shape', obj.shape)
        self.lattice_dim = getattr(obj, 'lattice_dim', len(obj.shape))
        self.lattice_spacing = getattr(obj, 'lattice_spacing', None)
    
    def latticeWrapIdx(self, index):
        """returns periodic lattice index 
        for a given iterable index
        
        Required Inputs:
            index :: iterable :: one integer for each axis
        
        This is NOT compatible with slicing
        """
        if not hasattr(index, '__iter__'): return index         # single integer: leave unwrapped
        if len(index) != len(self.lattice_shape): return index  # index count differs from axes: not a scalar lookup
        if any(isinstance(i, slice) for i in index): return index   # slices are not wrapped
        # one integer per axis: wrap each index; Python's % already returns
        # a value in [0, s) for positive s, even when i is negative
        return tuple(i % s for i, s in zip(index, self.lattice_shape))

Testing demonstrates that the lattice wraps scalar indices correctly:

arr = np.array([[ 11.,  12.,  13.,  14.],
                [ 21.,  22.,  23.,  24.],
                [ 31.,  32.,  33.,  34.],
                [ 41.,  42.,  43.,  44.]])
test_vals = [[(1,1), 22.], [(3,3), 44.], [( 4, 4), 11.], # [index, expected value]
             [(3,4), 41.], [(4,3), 14.], [(10,10), 33.]]

periodic_arr  = Periodic_Lattice(arr)
passed = bool((periodic_arr == arr).all())
passed = passed and all(periodic_arr[idx] == act for idx, act in test_vals)
print("Iterating test values. Result: {}".format(passed))

which gives the output:

Iterating test values. Result: True
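The test above only reads values, so it may help to see that assignment wraps in the same way. The sketch below uses `WrappedArray`, a condensed and hypothetical stand-in for Periodic_Lattice written only for this illustration. It assumes the array's own `shape` is used instead of a stored `lattice_shape`, and that only tuples of plain integers are wrapped.

```python
import numpy as np

class WrappedArray(np.ndarray):
    """Condensed, hypothetical stand-in for Periodic_Lattice (illustration only)."""

    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def _wrap(self, index):
        # Only wrap a tuple holding exactly one integer per axis;
        # anything else (slices, single integers, masks) passes through.
        if (isinstance(index, tuple) and len(index) == self.ndim
                and all(isinstance(i, (int, np.integer)) for i in index)):
            return tuple(i % s for i, s in zip(index, self.shape))
        return index

    def __getitem__(self, index):
        return super().__getitem__(self._wrap(index))

    def __setitem__(self, index, value):
        super().__setitem__(self._wrap(index), value)

lattice = WrappedArray(np.zeros((3, 3)))
lattice[3, 4] = 7.0              # wraps to lattice[0, 1]
assert lattice[0, 1] == 7.0
assert lattice[-3, -2] == 7.0    # (-3 % 3, -2 % 3) == (0, 1)

try:
    lattice[3]                   # a row selection is not wrapped
    raised = False
except IndexError:
    raised = True
assert raised
```

Reading the shape from the array itself means a sliced view wraps against its own extent; whether that or the parent lattice's shape is the desired behaviour depends on the application.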

Finally, using the code provided in the initial problem we obtain:

True
error
error
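For comparison, NumPy also offers wrap-around lookups without subclassing, through the `mode='wrap'` option of `np.ravel_multi_index` and `np.take`. This is a different technique from the subclass above, sketched here on the same test array; it wraps on request only and does not change how `arr[...]` itself behaves.

```python
import numpy as np

arr = np.array([[11., 12., 13., 14.],
                [21., 22., 23., 24.],
                [31., 32., 33., 34.],
                [41., 42., 43., 44.]])

# ravel_multi_index converts an n-d index into a flat one; with
# mode='wrap' each component is taken modulo the axis length first.
flat = np.ravel_multi_index((10, 10), arr.shape, mode='wrap')
assert arr.flat[flat] == 33.0

# np.take with mode='wrap' wraps an index along a single axis
row = np.take(arr, 4, axis=0, mode='wrap')   # same row as arr[0]
assert (row == arr[0]).all()
```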