How to make tf.data.Dataset return all of the elements in one call?

Milad · Jan 6, 2018

Is there an easy way to get the entire set of elements in a tf.data.Dataset? That is, I want to set the batch size of the Dataset to the size of my dataset without explicitly passing it the number of elements. This would be useful for a validation dataset, where I want to measure accuracy on the entire dataset in one go. I'm surprised there isn't a method to get the size of a tf.data.Dataset.

Answer

muskrat · Jan 6, 2018

In short, there is no good way to get the size/length: tf.data.Dataset is built for data pipelines, so it has an iterator structure (in my understanding, and according to my reading of the Dataset ops code). From the programmer's guide:

A tf.data.Iterator provides the main way to extract elements from a dataset. The operation returned by Iterator.get_next() yields the next element of a Dataset when executed, and typically acts as the interface between input pipeline code and your model.
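
A minimal sketch of that iterator interface, assuming TensorFlow 2.x with eager execution (where a Dataset can be driven with Python's built-in iter() and next() rather than a session-based get_next() op). The toy dataset is made up for illustration. Elements are pulled one at a time, and the end of the data is only discovered by trying to read past it:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager execution enabled (the default)

# Toy dataset for illustration: 0..4 grouped into batches of 2 -> [0, 1], [2, 3], [4]
dataset = tf.data.Dataset.range(5).batch(2)

iterator = iter(dataset)  # a tf.data.Iterator over the dataset
seen = []
while True:
    try:
        batch = next(iterator)  # pull the next element, analogous to get_next()
    except StopIteration:  # there is no length to check; we only learn the end by hitting it
        break
    seen.append(batch.numpy().tolist())

print(seen)  # [[0, 1], [2, 3], [4]]
```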

And, by their nature, iterators do not have a convenient notion of size/length; see the Stack Overflow question "Getting number of elements in an iterator in Python".
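
The same point in plain Python (a generic illustration, not tf.data specific): the only general way to count an iterator's elements is to consume it, after which nothing is left to use.

```python
# A generator is a typical iterator: it produces values lazily and has no len()
gen = (x * x for x in range(4))

n = sum(1 for _ in gen)  # counting walks through every element
leftover = list(gen)  # the iterator is now exhausted

print(n, leftover)  # 4 []
```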

More generally, though, why does this problem arise? If you are calling batch, you are still getting a tf.data.Dataset, so whatever you run on a batch you should be able to run on the whole dataset: iterate through all the elements until the dataset is exhausted and accumulate the validation accuracy as you go. Put differently, I don't think you actually need the size/length to do what you want.
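
A minimal sketch of that idea, written in plain Python so it stays framework-agnostic. The function name streaming_accuracy and the (labels, predictions) batch layout are hypothetical, chosen for illustration. It keeps running counts of correct and total examples and never needs the number of batches up front. With a tf.data.Dataset in TensorFlow 2.x, one could, under the assumption that each element is already a (labels, predictions) pair, pass dataset.as_numpy_iterator(); in practice the predictions would come from running the model on each batch.

```python
from typing import Iterable, Sequence, Tuple

# Hypothetical batch layout: (true labels, predicted labels)
Batch = Tuple[Sequence[int], Sequence[int]]


def streaming_accuracy(batches: Iterable[Batch]) -> float:
    """Accuracy over every batch the iterable yields; its length is never needed."""
    correct = 0
    total = 0
    for labels, preds in batches:  # runs until the source is exhausted
        correct += sum(int(y == p) for y, p in zip(labels, preds))
        total += len(labels)
    if total == 0:
        raise ValueError("no examples were seen")
    return correct / total


# Toy validation data split into uneven batches; the count is never passed in.
val_batches = iter([
    ([1, 0, 1], [1, 0, 0]),
    ([0, 1], [0, 1]),
    ([1], [0]),
])
acc = streaming_accuracy(val_batches)
print(acc)  # 4 of 6 correct
```

Because the counts are accumulated per example, uneven batch sizes (such as a smaller final batch) do not skew the result the way averaging per-batch accuracies would.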