Let's start with a simple function that generates a random integer:

import numpy as np

def f(x):
    return np.random.randint(1000)
and an RDD filled with zeros and mapped using f:

rdd = sc.parallelize([0] * 10).map(f)
Since the above RDD is not persisted, I expect to get a different output every time I collect:
rdd.collect()
[255, 512, 512, 512, 255, 512, 255, 512, 512, 255]
If we ignore the fact that the distribution of values doesn't really look random, this is more or less what happens. The problem starts when we take only the first element:
assert len(set(rdd.first() for _ in xrange(100))) == 1
or
assert len(set(tuple(rdd.take(1)) for _ in xrange(100))) == 1
It seems to return the same number each time. I've been able to reproduce this behavior on two different machines with Spark 1.2, 1.3 and 1.4. Here I am using np.random.randint, but it behaves the same way with random.randint.
This issue, like the non-exactly-random results from collect, seems to be Python-specific; I couldn't reproduce it using Scala:
def f(x: Int) = scala.util.Random.nextInt(1000)
val rdd = sc.parallelize(List.fill(10)(0)).map(f)
(1 to 100).map(x => rdd.first).toSet.size
rdd.collect()
Did I miss something obvious here?
Edit:
It turns out the source of the problem is Python's RNG implementation. To quote the official documentation:
The functions supplied by this module are actually bound methods of a hidden instance of the random.Random class. You can instantiate your own instances of Random to get generators that don’t share state.
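As a minimal illustration of that point (the seed 42 is arbitrary), two separately constructed Random instances carry fully independent state, unlike the module-level functions:

import random

# Each instance has its own state; the module-level functions
# all share one hidden instance instead
rng_a = random.Random(42)
rng_b = random.Random(42)
assert [rng_a.randint(0, 9) for _ in range(5)] == \
       [rng_b.randint(0, 9) for _ in range(5)]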
I assume NumPy works the same way, and rewriting f to use a RandomState instance as follows
import os
import binascii
import numpy as np

def f(x, seed=None):
    # Draw a fresh seed from the OS entropy pool unless one is supplied
    seed = (
        seed if seed is not None
        else int(binascii.hexlify(os.urandom(4)), 16))
    rs = np.random.RandomState(seed)
    return rs.randint(1000)
makes it slower but solves the problem.
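As a quick sanity check (assuming the same SparkContext sc as above), repeated actions now yield different outputs:

rdd = sc.parallelize([0] * 10).map(f)
# Each action reruns f with fresh urandom-based seeds, so with
# overwhelming probability the collected tuples differ between runs
assert len(set(tuple(rdd.collect()) for _ in xrange(3))) > 1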
While the above explains the non-random results from collect, I still don't understand how it affects first / take(1) between multiple actions.
So the actual problem here is relatively simple. Each Python subprocess inherits its RNG state from its parent:
import random

# Every partition reports the same inherited state, so the set has one element
len(set(sc.parallelize(range(4), 4).map(lambda _: random.getstate()).collect()))
# 1
Since the parent's state has no reason to change in this particular scenario, and workers have a limited lifespan, the state of every child will be exactly the same on each run.
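One way to sidestep the inherited state, sketched below under the assumption of a live SparkContext sc, is to re-seed inside each task from the OS entropy pool, for example via mapPartitions:

import os
import binascii
import random

def random_ints(iterator):
    # Discard the state inherited from the parent process by
    # re-seeding from os.urandom inside the worker
    random.seed(int(binascii.hexlify(os.urandom(4)), 16))
    for _ in iterator:
        yield random.randint(0, 999)

rdd = sc.parallelize([0] * 10, 4).mapPartitions(random_ints)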