I have recently been working on a project that uses a neural network for virtual robot control. I used TensorFlow to code it up and it runs smoothly. So far I have used sequential simulations to evaluate how good the neural network is; however, I want to run several simulations in parallel to reduce the time it takes to collect data.
To do this I am importing Python's multiprocessing package. Initially I was passing the sess variable (sess=tf.Session()) to a function that would run the simulation. However, once I get to any statement that uses this sess variable, the process quits without a warning. After searching around for a bit I found these two posts:
"Tensorflow: Passing a session to a python multiprocess" and "Running multiple tensorflow sessions concurrently".
While they are highly related, I haven't been able to figure out how to make it work. I tried creating a session for each individual process and assigning the neural net's weights to its trainable parameters, without success. I've also tried saving the session to a file and then loading it within each process, but no luck there either.
Has someone been able to pass a session (or clones of sessions) to several processes?
Thanks.
You can't use Python multiprocessing to pass a TensorFlow Session into a multiprocessing.Pool in the straightforward way, because the Session object can't be pickled (it's fundamentally not serializable, since it may manage GPU memory and similar state).
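To see that failure mode without TensorFlow installed, here is a minimal stdlib sketch. `FakeSession` is a hypothetical stand-in for any object that owns a live OS-level resource, with a `threading.Lock` playing the role of the GPU/device handle. Plain weight values pickle fine while the resource-holding object does not, which is what `multiprocessing.Pool` runs into when it pickles task arguments to send them to workers.

```python
import pickle
import threading


class FakeSession:
    """Hypothetical stand-in for an object that owns a live OS-level resource."""

    def __init__(self):
        # A lock, like a GPU handle or device context, cannot be serialized.
        self._lock = threading.Lock()


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


print(is_picklable([0.5, -1.2, 3.0]))  # plain weights: True
print(is_picklable(FakeSession()))     # resource-holding object: False
```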
I'd suggest parallelizing the code using actors, which are essentially the parallel-computing analog of "objects" and are used to manage state in a distributed setting.
Ray is a good framework for doing this. You can define a Python class which manages the TensorFlow Session and exposes a method for running your simulation.
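```python
import ray
# TensorFlow 1.x API. Under TensorFlow 2.x, use tf.compat.v1.Session and call
# tf.compat.v1.disable_eager_execution() first.
import tensorflow as tf

ray.init()


@ray.remote
class Simulator(object):
    def __init__(self):
        # Each actor lives in its own worker process and creates its own
        # Session there, so no Session ever has to be pickled.
        self.sess = tf.Session()
        self.simple_model = tf.constant([1.0])

    def simulate(self):
        # Runs inside the actor's process, using that actor's Session.
        return self.sess.run(self.simple_model)


# Create two actors.
simulators = [Simulator.remote() for _ in range(2)]
# Run two simulations in parallel.
results = ray.get([s.simulate.remote() for s in simulators])
```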
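The actor idea itself does not depend on Ray. As a rough, stdlib-only analog (not how Ray is implemented), each actor below is a `multiprocessing.Process` that builds its own state once and then serves requests over a `Pipe`. `SimulatorState` and `actor_loop` are hypothetical names; in a real program the actor would create its `tf.Session` at the point where `SimulatorState` is constructed, inside the child process.

```python
import multiprocessing as mp


class SimulatorState:
    """Hypothetical per-actor state; a real actor would own a tf.Session here."""

    def __init__(self, gain):
        self.gain = gain
        self.calls = 0

    def simulate(self, x):
        self.calls += 1
        return self.gain * x


def actor_loop(conn, gain):
    """Runs in the child process: build state once, then serve requests until None."""
    state = SimulatorState(gain)  # created inside the process, never pickled
    while True:
        msg = conn.recv()
        if msg is None:
            break
        conn.send(state.simulate(msg))
    conn.close()


if __name__ == "__main__":
    actors = []
    for gain in (1, 10):
        parent_end, child_end = mp.Pipe()
        proc = mp.Process(target=actor_loop, args=(child_end, gain))
        proc.start()
        actors.append((parent_end, proc))

    # Send one request to each actor so both compute in parallel, then collect.
    for parent_end, _ in actors:
        parent_end.send(3)
    print([parent_end.recv() for parent_end, _ in actors])  # [3, 30]

    for parent_end, proc in actors:
        parent_end.send(None)  # ask the actor to shut down
        proc.join()
```

Ray adds what this sketch lacks: scheduling across machines, futures via `ray.get`, and failure handling.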
Here are a few more examples of parallelizing TensorFlow with Ray.
See the Ray documentation. Note that I'm one of the Ray developers.