Random number generation in PySpark

zero323 · Aug 9, 2015

Let's start with a simple function that returns a random integer:

import numpy as np

def f(x):
    return np.random.randint(1000)

and an RDD filled with zeros and mapped using f:

rdd = sc.parallelize([0] * 10).map(f)

Since the above RDD is not persisted, I expect to get a different output every time I collect:

> rdd.collect()
[255, 512, 512, 512, 255, 512, 255, 512, 512, 255]

If we ignore the fact that the distribution of values doesn't really look random, this is more or less what happens. The problem starts when we take only the first element:

assert len(set(rdd.first() for _ in xrange(100))) == 1

or

assert len(set(tuple(rdd.take(1)) for _ in xrange(100))) == 1

It seems to return the same number each time. I've been able to reproduce this behavior on two different machines with Spark 1.2, 1.3, and 1.4. Here I am using np.random.randint, but it behaves the same way with random.randint.
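For reference, a minimal standard-library variant (g is my naming, not from the original code; random.randint is inclusive on both ends, so the range matches np.random.randint(1000)):

import random

def g(x):
    return random.randint(0, 999)

rdd = sc.parallelize([0] * 10).map(g)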

This issue, like the not-exactly-random results from collect, seems to be Python-specific; I couldn't reproduce it using Scala:

def f(x: Int) = scala.util.Random.nextInt(1000)

val rdd = sc.parallelize(List.fill(10)(0)).map(f)
(1 to 100).map(x => rdd.first).toSet.size

rdd.collect()

Did I miss something obvious here?

Edit:

It turns out the source of the problem is the Python RNG implementation. To quote the official documentation:

The functions supplied by this module are actually bound methods of a hidden instance of the random.Random class. You can instantiate your own instances of Random to get generators that don’t share state.
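The behavior the documentation describes is easy to see in isolation; a minimal illustration:

import random

a = random.Random(42)
b = random.Random(42)
# separate instances seeded identically produce identical streams...
assert a.randint(0, 999) == b.randint(0, 999)
# ...but they share no state with each other or with the module-level functions
c = random.Random()  # seeded from OS entropy (or time) at construction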

I assume NumPy works the same way, and rewriting f to use a RandomState instance as follows

import os
import binascii

def f(x, seed=None):
    # use the provided seed, otherwise draw 4 bytes from OS entropy
    seed = (
        seed if seed is not None
        else int(binascii.hexlify(os.urandom(4)), 16))
    # a private generator instance, so no state is shared between calls
    rs = np.random.RandomState(seed)
    return rs.randint(1000)

makes it slower but solves the problem.
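As a quick sanity check (run in the same session; with 100 independent draws from a range of 1000, seeing a single distinct value would be practically impossible):

rdd = sc.parallelize([0] * 10).map(f)
assert len(set(rdd.first() for _ in xrange(100))) > 1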

While the above explains the non-random results from collect, I still don't understand how it affects first / take(1) across multiple actions.

Answer

zero323 · Jun 21, 2016

So the actual problem here is relatively simple. Each subprocess in Python inherits its RNG state from its parent:

import random

len(set(sc.parallelize(range(4), 4).map(lambda _: random.getstate()).collect()))
# 1

Since the parent's state has no reason to change in this particular scenario and workers have a limited lifespan, the state of every child will be exactly the same on each run.
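One way around this, sketched here under the assumption that per-task seeding is acceptable, is to draw a fresh seed from OS entropy inside each task so the state forked from the parent is never used, for example with mapPartitionsWithIndex:

import os
import random
import binascii

def reseeded(split, iterator):
    # reseed inside the task; the state inherited from the parent is discarded
    random.seed(int(binascii.hexlify(os.urandom(4)), 16))
    return (random.randint(0, 999) for _ in iterator)

rdd = sc.parallelize([0] * 10, 4).mapPartitionsWithIndex(reseeded)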