How to cache data during the first epoch correctly (TensorFlow, dataset)?

Maosi Chen · May 25, 2018

I'm trying to use the cache transformation for a dataset. Here is my current code (simplified):

dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=1)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=5000, count=1))
dataset = dataset.map(_parser_a, num_parallel_calls=12)
dataset = dataset.padded_batch(
    20, 
    padded_shapes=padded_shapes,
    padding_values=padding_values
)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.cache()

After the first epoch, I received the following error message:

The calling iterator did not fully read the dataset we were attempting to cache. In order to avoid unexpected truncation of the sequence, the current [partially cached] sequence will be dropped. This can occur if you have a sequence similar to dataset.cache().take(k).repeat(). Instead, swap the order (i.e. dataset.take(k).cache().repeat())

The code then kept running, but it still read data from the hard drive instead of from the cache. So where should I place dataset.cache() to avoid the error? Thanks.

Answer

mrry · May 25, 2018

The implementation of the Dataset.cache() transformation is fairly simple: it builds up a list of the elements that pass through it as you iterate over it completely the first time, and it returns elements from that list on subsequent attempts to iterate over it. If the first pass only covers part of the data, the list is incomplete, and TensorFlow doesn't try to use the cached data: it doesn't know whether the remaining elements will be needed, and in general it might need to reprocess all the preceding elements to compute the remaining ones.
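To make that behavior concrete, here is a toy pure-Python model of it. It is a simplified sketch, not TensorFlow's actual implementation; ToyCache, source and reads are hypothetical names used only for illustration. In this model a pass counts as complete only when the consumer asks for an element past the end and the source reports that it is exhausted, which mirrors reading until tf.errors.OutOfRangeError.

```python
import itertools


class ToyCache:
    """Hypothetical, simplified model of the caching behavior described above.

    This is NOT TensorFlow's implementation; it only mimics the idea of
    "keep the list if the first pass finished, drop it otherwise".
    """

    def __init__(self, make_source):
        # make_source: zero-argument callable returning a fresh iterator over
        # the upstream elements (stands in for reading and parsing from disk).
        self._make_source = make_source
        self._cache = None  # set to a complete list after one full pass

    @property
    def is_complete(self):
        return self._cache is not None

    def __iter__(self):
        if self._cache is not None:
            # A previous pass reached the end: serve elements from memory.
            yield from self._cache
            return
        partial = []
        for element in self._make_source():
            partial.append(element)
            yield element
        # Only reached when the consumer requests an element past the end.
        # If the consumer stops early, this line never runs and the partial
        # list is simply discarded.
        self._cache = partial


# Usage: count how often the "expensive" upstream source is actually read.
reads = []


def source():
    for i in range(5):
        reads.append(i)  # record one upstream read
        yield i * i


cached = ToyCache(source)

first_two = list(itertools.islice(cached, 2))  # partial pass, similar to take(2)
reads_after_partial = len(reads)

full = list(cached)  # full pass: runs until the source is exhausted
reads_after_full = len(reads)

again = list(cached)  # served from the cache, no new upstream reads
reads_after_again = len(reads)
```

In this model, even list(itertools.islice(cached, 5)) would leave the cache empty: islice stops after five elements without asking for a sixth, so the end of the source is never observed. Only a consumer that runs into the end of the sequence completes the cache.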

If you modify your program to consume the entire dataset, iterating over it until tf.errors.OutOfRangeError is raised, the cache will hold a complete list of the elements in the dataset, and it will be used on all subsequent iterations.
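A minimal sketch of the "consume the whole dataset" pattern, assuming TensorFlow 2.x in eager mode rather than the TF 1.x tf.contrib API used in the question. In eager mode a Python loop over a dataset (here via as_numpy_iterator()) runs to the end and turns the underlying tf.errors.OutOfRangeError into StopIteration, so a single full loop is enough to fill the cache. The _square function and the calls counter are hypothetical, used only to show that the second pass does not recompute the mapped elements.

```python
import tensorflow as tf

# Counts how many times the "expensive" map function actually runs.
calls = {"n": 0}


def _square(x):
    # Runs in Python for every element that is computed (not served from cache).
    calls["n"] += 1
    return x * x


dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: tf.py_function(_square, [x], Tout=tf.int64))
dataset = dataset.cache()

# First epoch: read until the iterator is exhausted (the eager-mode
# equivalent of running until tf.errors.OutOfRangeError is raised).
first_epoch = [int(v) for v in dataset.as_numpy_iterator()]
calls_after_first = calls["n"]

# Second epoch: elements come from the in-memory cache, so _square
# is not called again.
second_epoch = [int(v) for v in dataset.as_numpy_iterator()]
calls_after_second = calls["n"]
```

If the first loop stopped early instead (for example by breaking out of it, or by reading through something like dataset.cache().take(k)), the partially built cache would be dropped, which is the situation the warning in the question describes.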