Proper way to iterate tf.data.Dataset in session for 2.0

leonard · May 31, 2019

I have downloaded some *.tfrecord data from the YouTube-8M project. You can download a 'small' portion of the data with this command:

curl data.yt8m.org/download.py | shard=1,100 partition=2/video/train mirror=us python

I am trying to get an idea of how to use the new tf.data API, and I would like to become familiar with the typical ways people iterate through datasets. I have been using the guide on the TF website and Derek Murray's slides.

Here is how I define the dataset:

import tensorflow as tf

# Feature spec for the video-level YouTube-8M records (assumed; adjust to your shards).
feature_map = {
    'id': tf.io.FixedLenFeature([], tf.string),
    'mean_rgb': tf.io.FixedLenFeature([1024], tf.float32),
}

# Use interleave() and prefetch() to read many files concurrently.
files = tf.data.Dataset.list_files("./youtube_vids/*.tfrecord")
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100),
                           cycle_length=8)

# Use num_parallel_calls to parallelize map().
dataset = dataset.map(lambda record: tf.io.parse_single_example(record, feature_map),
                      num_parallel_calls=2)

# Put the data in (x, y) output form.
dataset = dataset.map(lambda x: (x['mean_rgb'], x['id']))

# Shuffle.
dataset = dataset.shuffle(10000)

# One epoch.
dataset = dataset.repeat(1)
dataset = dataset.batch(200)

# Use prefetch() to overlap the producer and consumer.
dataset = dataset.prefetch(10)
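To try the pipeline above without the YouTube-8M download, here is a minimal sketch that writes two tiny TFRecord shards of synthetic records and runs the same interleave/parse/shuffle/batch/prefetch chain over them. The temporary paths, the record count, and the feature names and shapes ('id' as a string, 'mean_rgb' as 1024 floats) are assumptions standing in for the real video-level data, and the sketch assumes TF 2.x with eager execution on.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the downloaded shards: two small TFRecord files whose
# records carry an 'id' string and a 1024-float 'mean_rgb' vector (assumed layout).
tmp_dir = tempfile.mkdtemp()
for shard in range(2):
    path = os.path.join(tmp_dir, "train%d.tfrecord" % shard)
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(5):
            example = tf.train.Example(features=tf.train.Features(feature={
                "id": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b"vid%d_%d" % (shard, i)])),
                "mean_rgb": tf.train.Feature(
                    float_list=tf.train.FloatList(value=np.random.rand(1024).tolist())),
            }))
            writer.write(example.SerializeToString())

feature_map = {
    "id": tf.io.FixedLenFeature([], tf.string),
    "mean_rgb": tf.io.FixedLenFeature([1024], tf.float32),
}

# Same shape of pipeline as above, with smaller buffers for the toy data.
files = tf.data.Dataset.list_files(os.path.join(tmp_dir, "*.tfrecord"))
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(10),
                           cycle_length=2)
dataset = dataset.map(lambda r: tf.io.parse_single_example(r, feature_map),
                      num_parallel_calls=2)
dataset = dataset.map(lambda x: (x["mean_rgb"], x["id"]))
dataset = dataset.shuffle(100).batch(4).prefetch(1)

# With eager execution (the TF 2.x default) the batches can be consumed directly.
batch_shapes = [tuple(x.shape) for x, _ in dataset]
print(batch_shapes)
```

Because `batch()` comes after `shuffle()`, the order of records varies between runs, but the batch sizes (4, 4, then the remaining 2 of the 10 records) do not.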

Now, I know that in eager execution mode I can simply write:

for x, y in dataset:
    print(x.shape, y.shape)

However, when I attempt to create an iterator as follows:

# A one-shot iterator automatically initializes itself on first use.
iterator = dataset.make_one_shot_iterator()

# The return value of get_next() matches the dataset element type.
images, labels = iterator.get_next()

and run it in a session:

with tf.Session() as sess:

    # Loop until all elements have been consumed.
    try:
        while True:
            r = sess.run(images)
    except tf.errors.OutOfRangeError:
        pass

I get the warning:

Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
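If you really do need a session (for example, to keep old graph-mode code running on TF 2.x), the warning's own "last resort" can be sketched as below. This assumes TF 2.x, where `tf.compat.v1.disable_eager_execution()` must be called once at program start and switches the whole process to graph mode; the small NumPy dataset stands in for the YouTube-8M pipeline and is purely illustrative.

```python
import numpy as np
import tensorflow as tf

# Graph mode must be enabled before any ops are created; it affects the whole process.
tf.compat.v1.disable_eager_execution()

# Toy stand-in for the (mean_rgb, id) pipeline: 10 rows, batched by 4.
features = np.random.rand(10, 1024).astype(np.float32)
labels = np.arange(10)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

# The compat.v1 replacement for dataset.make_one_shot_iterator().
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
images, ids = iterator.get_next()

seen = []
with tf.compat.v1.Session() as sess:
    try:
        while True:
            _, batch_ids = sess.run([images, ids])
            seen.extend(batch_ids.tolist())
    except tf.errors.OutOfRangeError:
        pass  # one full pass over the dataset is complete

print(seen)
```

Fetching `images` and `ids` in the same `sess.run` call matters: each run of `get_next()` advances the iterator, so running them separately would pair features and labels from different batches.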

So, here is my question:

What is the proper way to iterate through a dataset within a session? Is it just a matter of v1 and v2 differences?

Also, the advice to pass the dataset directly to an estimator implies that the input function also has an iterator defined as in Derek Murray's slides above, correct?

Answer

Sharky · May 31, 2019

As for the Estimator API: no, you don't have to define an iterator. Just return the dataset object from your input function and pass that function to the estimator.

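import tensorflow as tf

def input_fn(filename, batch_size=200):
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.shuffle(buffer_size=10000).repeat()
    # parse_func (your own) must return a (features, labels) tuple.
    dataset = dataset.map(parse_func)
    dataset = dataset.batch(batch_size)
    # Return the Dataset itself; the Estimator builds the iterator internally.
    return dataset

# "train.tfrecord" is a placeholder path; steps bounds the infinite repeat().
estimator.train(input_fn=lambda: input_fn("train.tfrecord"), steps=1000)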

In TF 2.0, a tf.data.Dataset is a Python iterable (with eager execution, which is the default), so, just as the warning message says, you can use:

for x, y in dataset:
    print(x.shape, y.shape)
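A minimal TF 2.x sketch of the iteration styles that take the place of the session loop, using a small in-memory dataset (an assumption, in place of the YouTube-8M pipeline): a plain `for` loop, an explicit `iter()`/`next()` pair (the eager analogue of calling `get_next()` in a session), and the same `for` loop inside a `tf.function`, where AutoGraph turns it into graph ops.

```python
import tensorflow as tf

# Toy dataset: 10 feature rows with integer labels 0..9, batched by 4.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([10, 1024]), tf.range(10))).batch(4)

# 1) Plain Python loop: each element is a tuple of eager tensors.
batch_sizes = [int(x.shape[0]) for x, y in dataset]

# 2) Explicit iterator: next() raises StopIteration once the data runs out.
it = iter(dataset)
first_x, first_y = next(it)

# 3) The same loop inside tf.function; AutoGraph compiles it into the graph.
@tf.function
def total_label_sum():
    total = tf.constant(0)
    for _, y in dataset:
        total += tf.reduce_sum(y)
    return total

print(batch_sizes, first_y.numpy(), int(total_label_sum()))
```

In all three styles the end of the data is handled for you (the loop simply stops), so there is no `OutOfRangeError` to catch as in the session version.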