tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?

nessuno · Jun 7, 2018 · Viewed 41k times

Let's say I have defined a dataset in this way:

filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))

How can I get the number of elements in the dataset (that is, the number of individual elements that make up an epoch)?

I know that tf.data.Dataset already knows the dimension of the dataset, because the repeat() method allows repeating the input pipeline for a specified number of epochs. So there must be a way to get this information.

Answer

markemus · May 29, 2019

len(list(dataset)) works in eager mode, although that's obviously not a good general solution, since it iterates over the entire dataset just to count it.
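For reference, here is a minimal sketch of that approach next to two alternatives. It assumes TensorFlow 2.3 or later with eager execution enabled; the helper names (count_by_iteration, count_by_reduce, static_cardinality) are made up for illustration and are not part of TensorFlow. A small Dataset.range stands in for the file listing from the question.

```python
import tensorflow as tf


def count_by_iteration(dataset):
    """Count elements by consuming the whole dataset (eager mode only)."""
    return len(list(dataset))


def count_by_reduce(dataset):
    """Count elements with Dataset.reduce, without building a Python list."""
    total = dataset.reduce(tf.constant(0, tf.int64), lambda n, _: n + 1)
    return int(total.numpy())


def static_cardinality(dataset):
    """Return the size if tf.data can determine it without iterating, else None.

    Dataset.cardinality() returns a negative sentinel value when the size
    is infinite or unknown (for example after filter()).
    """
    n = int(dataset.cardinality().numpy())
    return n if n >= 0 else None


# Stand-in dataset (assumption: any finite dataset behaves the same way).
ds = tf.data.Dataset.range(10)
evens = ds.filter(lambda x: x % 2 == 0)

print(count_by_iteration(ds))     # iterates all elements
print(count_by_reduce(ds))        # also iterates, but inside tf.data
print(static_cardinality(ds))     # no iteration needed
print(static_cardinality(evens))  # None: size unknown after filter()
```

The first two helpers read every element, so they can be slow on large or expensive pipelines and never finish on an infinite one. The cardinality check is cheap, but it only helps when the size can be inferred from the pipeline definition.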