How to use dataset.shard in tensorflow?

Jiang Wenbo · Feb 13, 2018

Recently I have been looking into the Dataset API in TensorFlow, and there is a method dataset.shard() that is intended for distributed computation.

This is what TensorFlow's documentation says:

Creates a Dataset that includes only 1/num_shards of this dataset.

d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)

This method is said to return a portion of the original dataset. If I have two workers, am I supposed to do:

d_0 = d.shard(FLAGS.num_workers, worker_0)
d_1 = d.shard(FLAGS.num_workers, worker_1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()

for worker_id in workers:
    with tf.device(worker_id):
        if worker_id == 0:
            data = iterator_0.get_next()
        else:
            data = iterator_1.get_next()
        ......

Because the documentation did not specify how to make subsequent calls, I am a bit confused here.

Thanks!

Answer

Olivier Moindrot · Feb 20, 2018

You should take a look at the tutorial on Distributed TensorFlow first to better understand how it works.

You have multiple workers that each run the same code, with one small difference: each worker has a different FLAGS.worker_index.
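As a rough illustration of "same code, different index", the sketch below reads the two values from the command line with argparse. The flag names mirror FLAGS.num_workers and FLAGS.worker_index from the question; the parse_worker_flags helper, the train.py script name and the launch commands are assumptions made for this example and are not part of TensorFlow.

```python
import argparse


def parse_worker_flags(argv=None):
    """Parse the per-worker settings (hypothetical helper for illustration)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_workers", type=int, required=True)
    parser.add_argument("--worker_index", type=int, required=True)
    flags = parser.parse_args(argv)
    # The shard index must identify one of the num_workers shards.
    if not 0 <= flags.worker_index < flags.num_workers:
        parser.error("worker_index must be in [0, num_workers)")
    return flags


# Every worker is launched with the same script but its own index, e.g.
# (train.py is a hypothetical script name):
#   python train.py --num_workers=3 --worker_index=0
#   python train.py --num_workers=3 --worker_index=1
#   python train.py --num_workers=3 --worker_index=2
FLAGS = parse_worker_flags(["--num_workers=3", "--worker_index=1"])
print(FLAGS.num_workers, FLAGS.worker_index)
```

The point is that no worker needs to know about the others: it only needs the total number of workers and its own position among them.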

When you call tf.data.Dataset.shard, you pass in this worker index, and the data is split evenly between the workers: each worker keeps every num_workers-th element, starting at the position given by its index.

Here is an example with 3 workers.

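import tensorflow as tf  # TensorFlow 1.x API

# FLAGS is assumed to be defined elsewhere (e.g. with tf.app.flags), as in the question.
# Every worker runs this same script; only FLAGS.worker_index differs.
# Suppose you have 3 workers in total, so FLAGS.num_workers == 3.
dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)

iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

with tf.Session() as sess:
    # Each worker's shard holds 6 / 3 = 2 elements.
    for i in range(2):
        print(sess.run(res))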

We will have the output:

  • 0, 3 on worker 0
  • 1, 4 on worker 1
  • 2, 5 on worker 2
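This output follows from the rule that shard(num_shards, index) keeps the elements whose position modulo num_shards equals index. The plain-Python sketch below mimics that rule without TensorFlow, so that all three workers can be simulated in one process; shard_like is a hypothetical helper written for this illustration, not a TensorFlow function.

```python
from itertools import islice


def shard_like(elements, num_shards, index):
    """Mimic Dataset.shard: keep items at positions index, index + num_shards, ..."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return list(islice(elements, index, None, num_shards))


num_workers = 3
shards = [shard_like(range(6), num_workers, i) for i in range(num_workers)]
for worker_index, shard in enumerate(shards):
    print("worker", worker_index, "->", shard)
# worker 0 -> [0, 3]
# worker 1 -> [1, 4]
# worker 2 -> [2, 5]

# If the dataset size is not a multiple of num_shards,
# the shard lengths differ by at most one element.
uneven = [shard_like(range(7), num_workers, i) for i in range(num_workers)]
```

Because the shards are disjoint and together cover the whole dataset, each worker only has to make its own shard call; it does not need to build the shards of the other workers.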