I'm trying to use the cache transformation for a dataset. Here is my current code (simplified):
dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=1)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=5000, count=1))
dataset = dataset.map(_parser_a, num_parallel_calls=12)
dataset = dataset.padded_batch(
    20,
    padded_shapes=padded_shapes,
    padding_values=padding_values
)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.cache()
After the first epoch, I received the following error message:
The calling iterator did not fully read the dataset we were attempting to cache. In order to avoid unexpected truncation of the sequence, the current [partially cached] sequence will be dropped. This can occur if you have a sequence similar to dataset.cache().take(k).repeat(). Instead, swap the order (i.e. dataset.take(k).cache().repeat()).
Then, the code proceeded and still read data from the hard drive instead of the cache. So, where should I place dataset.cache() to avoid the error?
Thanks.
The implementation of the Dataset.cache() transformation is fairly simple: it builds up a list of the elements that pass through it as you iterate over it completely the first time, and it returns elements from that list on subsequent attempts to iterate over it. If the first iteration reads only part of the data, the list is incomplete, and TensorFlow does not try to use the cached data, because it does not know whether the remaining elements will be needed, and in general it might have to reprocess all the preceding elements to compute the remaining ones.
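As a minimal sketch of the ordering mentioned in the error message (using a toy range dataset, not your pipeline):

import tensorflow as tf

ds = tf.data.Dataset.range(100)

# Partial pass through the cache: take(4) stops iteration early, so the
# cache never sees elements 4..99 and is dropped with the warning above.
bad = ds.cache().take(4).repeat()

# Full pass through the cache: take(4) runs first, so the cached sequence
# (elements 0..3) is read completely on every epoch and can be reused.
good = ds.take(4).cache().repeat()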
If you modify your program to consume the entire dataset, iterating over it until tf.errors.OutOfRangeError is raised, the cache will contain a complete list of the elements in the dataset, and it will be used on all subsequent iterations.
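As a rough sketch of how this could look for your pipeline (moving cache() before the shuffle/batch stages and running a full pass per epoch are my suggestions, not the only option; filenames, _parser_a, padded_shapes, padding_values are your names, and num_epochs is a placeholder):

import tensorflow as tf

dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=1)
dataset = dataset.map(_parser_a, num_parallel_calls=12)
dataset = dataset.cache()  # cache the parsed elements, before shuffling/batching
dataset = dataset.apply(
    tf.contrib.data.shuffle_and_repeat(buffer_size=5000, count=1))
dataset = dataset.padded_batch(
    20,
    padded_shapes=padded_shapes,
    padding_values=padding_values)
dataset = dataset.prefetch(buffer_size=1)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for _ in range(num_epochs):  # num_epochs is assumed, not from the question
        sess.run(iterator.initializer)
        # Consume the whole epoch, until OutOfRangeError is raised, so the
        # cache is fully populated on the first pass and reused afterwards.
        while True:
            try:
                sess.run(next_element)
            except tf.errors.OutOfRangeError:
                break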