How to iterate a dataset several times using TensorFlow's Dataset API?

void · Nov 2, 2017 · Viewed 19.7k times

How can I output the values in a dataset several times? (The dataset is created with TensorFlow's Dataset API.)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Error message:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

How to make this work?

Answer

melgor89 · Nov 10, 2017

First of all, I advise you to read the Dataset guide, which describes the Dataset API in detail.

Your question is about iterating over the data several times. Here are two solutions for that:

  1. Iterate over all epochs at once by repeating the dataset. This gives no information about where each individual epoch ends:
import tensorflow as tf

epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print("Num Batch: ", num_batch)
  2. Use an initializable iterator and re-initialize it for every epoch. This tells you when each epoch ends, so you can, for example, check the validation loss:
import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(epoch):
    print("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print("Num Batch: ", num_batch)
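
To see why the loop in the question fails and why both options work, here is a rough plain-Python analogy. It is not TensorFlow code: `make_iterator` is a hypothetical stand-in for a one-shot iterator, `itertools.chain` plays the role of `repeat(epoch)`, and creating a fresh iterator per epoch plays the role of running `iterator.initializer`. It is only a sketch of the control flow, under the assumption that a one-shot iterator behaves like an ordinary Python iterator that can be consumed once.

```python
import itertools


def make_iterator(n=100):
    """Hypothetical stand-in for a one-shot iterator: yields 0..n-1 once."""
    return iter(range(n))


epochs = 10

# The question's situation: a single iterator is exhausted after 100 reads,
# so the 101st read fails (StopIteration here, OutOfRangeError in TensorFlow).
it = make_iterator()
for _ in range(100):
    next(it)
try:
    next(it)
    exhausted = False
except StopIteration:
    exhausted = True

# Option 1 analogy: one long stream of all epochs, with no marker between them.
stream = itertools.chain.from_iterable(make_iterator() for _ in range(epochs))
count_repeat = sum(1 for _ in stream)

# Option 2 analogy: a fresh iterator per epoch, so each epoch's end is visible.
count_reinit = 0
epoch_ends = 0
for e in range(epochs):
    for j, value in enumerate(make_iterator()):
        assert j == value
        count_reinit += 1
    epoch_ends += 1  # reached only after this epoch's iterator is exhausted

print("Exhausted after one pass:", exhausted)
print("Elements (repeat):", count_repeat)
print("Elements (re-init):", count_reinit, "epoch ends seen:", epoch_ends)
```

Both patterns visit the same number of elements; the difference is only whether your loop gets a hook at the end of each epoch.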