I'm using a broadcast variable that is about 100 MB when pickled, which I'm approximating with:
>>> data = list(range(int(10*1e6)))
>>> import cPickle as pickle
>>> len(pickle.dumps(data))
98888896
I'm running on a cluster with 3 c3.2xlarge executors and an m3.large driver, using the following command to launch the interactive session:
IPYTHON=1 pyspark --executor-memory 10G --driver-memory 5G --conf spark.driver.maxResultSize=5g
In an RDD, if I persist a reference to this broadcast variable, memory usage explodes. With 100 references to a 100 MB variable, even if it were copied 100 times, I'd expect total usage of no more than about 10 GB, well under the 30 GB available across the 3 nodes. However, I get out-of-memory errors when I run the following test:
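As a quick check of that arithmetic, here is a sketch using only the numbers above. It assumes, as the question does, that each in-memory copy costs roughly its pickled size, and it treats 1 GB as 10^9 bytes; the variable names are just for illustration.

```python
# Back-of-the-envelope estimate from the figures in the question.
# Assumption: each deserialized copy costs about as much as its pickled size.
pickled_size = 98888896   # bytes, from len(pickle.dumps(data))
copies = 100              # worst case: one copy per RDD record
executors = 3             # three c3.2xlarge executors
executor_memory_gb = 10   # --executor-memory 10G, treated as 10 * 10**9 bytes

worst_case_gb = pickled_size * copies / 1e9
total_gb = executors * executor_memory_gb

print("worst case: {:.2f} GB of {} GB".format(worst_case_gb, total_gb))
# prints "worst case: 9.89 GB of 30 GB"
```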
data = list(range(int(10*1e6)))
metadata = sc.broadcast(data)
ids = sc.parallelize(zip(range(100), range(100)))
joined_rdd = ids.mapValues(lambda _: metadata.value)
joined_rdd.persist()
print('count: {}'.format(joined_rdd.count()))
The stack trace:
TaskSetManager: Lost task 17.3 in stage 0.0 (TID 75, 10.22.10.13):
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 111, in main
process()
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 106, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/usr/lib/spark/python/pyspark/rdd.py", line 2355, in pipeline_func
return func(split, prev_func(split, iterator))
File "/usr/lib/spark/python/pyspark/rdd.py", line 2355, in pipeline_func
return func(split, prev_func(split, iterator))
File "/usr/lib/spark/python/pyspark/rdd.py", line 317, in func
return f(iterator)
File "/usr/lib/spark/python/pyspark/rdd.py", line 1006, in <lambda>
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
File "/usr/lib/spark/python/pyspark/rdd.py", line 1006, in <genexpr>
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 139, in load_stream
yield self._read_with_length(stream)
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 164, in _read_with_length
return self.loads(obj)
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 422, in loads
return pickle.loads(obj)
MemoryError
at org.apache.spark.api.python.PythonRDD$$anon$1.read(PythonRDD.scala:138)
at org.apache.spark.api.python.PythonRDD$$anon$1.<init>(PythonRDD.scala:179)
at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:97)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:297)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:264)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)
16/05/25 23:57:15 ERROR TaskSetManager: Task 17 in stage 0.0 failed 4 times; aborting job
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-1-7a262fdfa561> in <module>()
7 joined_rdd.persist()
8 print('persist called')
----> 9 print('count: {}'.format(joined_rdd.count()))
/usr/lib/spark/python/pyspark/rdd.py in count(self)
1004 3
1005 """
-> 1006 return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
1007
1008 def stats(self):
/usr/lib/spark/python/pyspark/rdd.py in sum(self)
995 6.0
996 """
--> 997 return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
998
999 def count(self):
/usr/lib/spark/python/pyspark/rdd.py in fold(self, zeroValue, op)
869 # zeroValue provided to each partition is unique from the one provided
870 # to the final reduce call
--> 871 vals = self.mapPartitions(func).collect()
872 return reduce(op, vals, zeroValue)
873
/usr/lib/spark/python/pyspark/rdd.py in collect(self)
771 """
772 with SCCallSiteSync(self.context) as css:
--> 773 port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
774 return list(_load_from_socket(port, self._jrdd_deserializer))
775
/usr/lib/spark/python/lib/py4j-0.8.2.1-src.zip/py4j/java_gateway.py in __call__(self, *args)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:379)
at py4j.Gateway.invoke(Gateway.java:259)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:207)
at java.lang.Thread.run(Thread.java:745)
I've seen previous threads about the memory usage of pickle deserialization being an issue. However, I would expect a broadcast variable to be deserialized (and loaded into memory on an executor) only once, with subsequent references to .value pointing to that same in-memory object. That doesn't seem to be the case. Am I missing something?
The examples I've seen with broadcast variables use them as dictionaries, applied once to transform a set of data (e.g. replacing airport acronyms with airport names). My motivation for persisting them here is to create objects that know about a broadcast variable and how to interact with it, persist those objects, and perform multiple computations using them (with Spark taking care of holding them in memory).
What are some tips for using large (100 MB+) broadcast variables? Is persisting a broadcast variable misguided? Is this an issue that is possibly specific to PySpark?
Thank you! Your help is appreciated.
Note: I've also posted this question on the Databricks forums.
Edit - follow-up question:
It was suggested that the default Spark serializer has a batch size of 65536. Objects serialized in different batches are not identified as the same and are assigned different memory addresses, examined here via the built-in id function. However, even with a larger broadcast variable that would in theory take 256 batches to serialize, I still see only 2 distinct copies. Shouldn't I see many more? Is my understanding of how batch serialization works incorrect?
>>> sc.serializer.bestSize
65536
>>> import cPickle as pickle
>>> broadcast_data = {k: v for (k, v) in enumerate(range(int(1e6)))}
>>> len(pickle.dumps(broadcast_data))
16777786
>>> len(pickle.dumps({k: v for (k, v) in enumerate(range(int(1e6)))})) / sc.serializer.bestSize
256
>>> bd = sc.broadcast(broadcast_data)
>>> rdd = sc.parallelize(range(100), 1).map(lambda _: bd.value)
>>> rdd.map(id).distinct().count()
1
>>> rdd.cache().count()
100
>>> rdd.map(id).distinct().count()
2
Well, the devil is in the details. To understand why this may happen, we'll have to take a closer look at the PySpark serializers. First, let's create a SparkContext with default settings:
from pyspark import SparkContext
sc = SparkContext("local", "foo")
and check what the default serializer is:
sc.serializer
## AutoBatchedSerializer(PickleSerializer())
sc.serializer.bestSize
## 65536
It tells us three different things:
- it is an AutoBatchedSerializer
- it uses PickleSerializer to perform the actual job
- the bestSize of a serialized batch is 65536 bytes
A quick glance at the source code will show you that this serializer adjusts the number of records serialized at a time at runtime and tries to keep the batch size below 10 * bestSize. The important point is that not all records in a single partition are serialized at the same time.
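To make the batching rule concrete, here is a stdlib-only toy model. auto_batched_roundtrip is a hypothetical helper, and the doubling/halving rule is a simplified reading of the behavior described above, not PySpark's actual code. It shows that when many records reference one shared object, each batch ends up with its own copy after unpickling. Because it keeps every restored object alive at once, it is not expected to reproduce the exact counts printed by Spark.

```python
import itertools
import pickle


def auto_batched_roundtrip(records, best_size=65536):
    """Pickle records in adaptively sized batches, then unpickle them.

    Simplified model (an assumption, not PySpark's implementation):
    start with one record per batch, double the batch while a pickled
    batch is smaller than best_size, halve it when a batch exceeds
    10 * best_size.
    """
    it = iter(records)
    batch = 1
    payloads = []     # one pickled byte string per batch
    batch_sizes = []  # number of records in each batch
    while True:
        chunk = list(itertools.islice(it, batch))
        if not chunk:
            break
        data = pickle.dumps(chunk)
        payloads.append(data)
        batch_sizes.append(len(chunk))
        if len(data) < best_size:
            batch *= 2
        elif len(data) > best_size * 10 and batch > 1:
            batch //= 2
    # Each batch is unpickled independently, so shared references
    # survive only within a batch.
    restored = [obj for data in payloads for obj in pickle.loads(data)]
    return restored, batch_sizes


shared = {}  # stands in for bd.value
out, sizes = auto_batched_roundtrip([shared] * 10)
print(sizes)                      # [1, 2, 4, 3]
print(len({id(o) for o in out}))  # 4: one distinct copy per batch
```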
We can check that experimentally as follows:
from operator import add
bd = sc.broadcast({})
rdd = sc.parallelize(range(10), 1).map(lambda _: bd.value)
rdd.map(id).distinct().count()
## 1
rdd.cache().count()
## 10
rdd.map(id).distinct().count()
## 2
As you can see, even in this simple example we get two distinct objects after serialization and deserialization. You can observe similar behavior working directly with pickle:
import pickle
# Python 3 syntax (extended iterable unpacking)
v = {}
vs = [v, v, v, v]
v1, *_, v4 = pickle.loads(pickle.dumps(vs))
v1 is v4
## True
(v1_, v2_), (v3_, v4_) = (
    pickle.loads(pickle.dumps(vs[:2])),
    pickle.loads(pickle.dumps(vs[2:]))
)
v1_ is v4_
## False
v3_ is v4_
## True
After unpickling, values serialized in the same batch reference the same object, while values from different batches point to different objects.
In practice, Spark provides multiple serializers and different serialization strategies. You can, for example, use batches of infinite size:
from pyspark.serializers import BatchedSerializer, PickleSerializer
rdd_ = (sc.parallelize(range(10), 1).map(lambda _: bd.value)
._reserialize(BatchedSerializer(PickleSerializer())))
rdd_.cache().count()
rdd_.map(id).distinct().count()
## 1
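The same effect can be reproduced without Spark. In this sketch, fixed_batched_roundtrip is a hypothetical helper (not a PySpark API) that pickles records in fixed-size batches; batch_size=None plays the role of an unlimited batch, loosely mirroring batchSize=-1.

```python
import pickle


def fixed_batched_roundtrip(records, batch_size=None):
    """Pickle records in fixed-size batches; None means one unlimited batch."""
    records = list(records)
    if batch_size is None:
        batch_size = max(len(records), 1)
    restored = []
    for start in range(0, len(records), batch_size):
        chunk = records[start:start + batch_size]
        # Each chunk is pickled and unpickled on its own, like one batch.
        restored.extend(pickle.loads(pickle.dumps(chunk)))
    return restored


shared = {}
unlimited = fixed_batched_roundtrip([shared] * 10)
per_record = fixed_batched_roundtrip([shared] * 10, batch_size=1)
print(len({id(o) for o in unlimited}))   # 1
print(len({id(o) for o in per_record}))  # 10
```

With a single batch every record shares one restored object; with one record per batch each record gets its own copy.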
You can change the serializer by passing the serializer and/or batchSize parameters to the SparkContext constructor:
sc.stop()  # only one SparkContext can be active at a time, so stop the existing one first
sc = SparkContext(
    "local", "bar",
    serializer=PickleSerializer(),  # Default serializer
    # Unlimited batch size -> BatchedSerializer instead of AutoBatchedSerializer
    batchSize=-1
)
sc.serializer
## BatchedSerializer(PickleSerializer(), -1)
Choosing different serializers and batching strategies results in different trade-offs (speed, ability to serialize arbitrary objects, memory requirements, etc.).
You should also remember that broadcast variables in Spark are not shared between executor threads, so multiple deserialized copies can exist on the same worker at the same time.
Moreover, you'll see similar behavior if you execute a transformation that requires shuffling.