Recently I have been looking into the Dataset API in TensorFlow, and there is a method dataset.shard()
which is meant for distributed computation.
This is what is stated in TensorFlow's documentation:
Creates a Dataset that includes only 1/num_shards of this dataset.
d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
This method is said to return a portion of the original dataset. If I have two workers, am I supposed to do:
d_0 = d.shard(FLAGS.num_workers, worker_0)
d_1 = d.shard(FLAGS.num_workers, worker_1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()
for worker_id in workers:
    with tf.device(worker_id):
        if worker_id == 0:
            data = iterator_0.get_next()
        else:
            data = iterator_1.get_next()
......
Because the documentation does not specify how these calls should be made across workers, I am a bit confused here.
Thanks!
You should take a look at the tutorial on Distributed TensorFlow first to better understand how it works.
You have multiple workers that each run the same code, but with one small difference: each worker will have a different FLAGS.worker_index.
When you use tf.data.Dataset.shard, you supply this worker index and the data is split equally between the workers.
Here is an example with 3 workers.
dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)
iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()
# Suppose you have 3 workers in total
with tf.Session() as sess:
    for i in range(2):
        print(sess.run(res))
We will have the output:
0, 3 on worker 0
1, 4 on worker 1
2, 5 on worker 2
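To make the "same code, different index" point concrete, here is a minimal sketch of a complete worker script (the flag names num_workers and worker_index are placeholder assumptions, not a required convention). Every worker runs this exact script and only the command-line value of --worker_index changes; you do not build d_0 and d_1 inside a single graph as in the question.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("num_workers", 3, "Total number of workers.")
flags.DEFINE_integer("worker_index", 0, "0-based index of this worker.")
FLAGS = flags.FLAGS

def main(_):
    # Every worker builds the exact same pipeline; only FLAGS.worker_index differs.
    dataset = tf.data.Dataset.range(6)
    # shard(n, i) keeps the elements whose position modulo n equals i.
    dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    with tf.Session() as sess:
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break

if __name__ == "__main__":
    tf.app.run()

You would then launch it once per worker, e.g. python worker.py --num_workers=3 --worker_index=1, and each process only ever sees its own shard of the data.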