What do batch, repeat, and shuffle do with a TensorFlow Dataset?

blue · Nov 28, 2018 · Viewed 16.5k times

I'm currently learning TensorFlow, but I'm confused by this code:

dataset = dataset.shuffle(buffer_size=10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

I know the dataset initially holds all the data, but what do shuffle(), repeat(), and batch() do to it? Please explain with an example.

Answer

Vlad-HC · Nov 28, 2018

Imagine you have a dataset: [1, 2, 3, 4, 5, 6]. Then:

How ds.shuffle() works

dataset.shuffle(buffer_size=3) will allocate a buffer of size 3 for picking random entries. This buffer is connected to the source dataset. We can picture it like this:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Let's assume that entry 2 was taken from the random buffer. The free slot is filled by the next element from the source dataset, which is 4:

2 <= [1,3,4] <= [5,6]

We continue reading till nothing is left:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]     <= []
4 <= []      <= []
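The walkthrough above can be sketched in plain Python. This is only a simulation of the idea, not TensorFlow's actual implementation; the name shuffle_with_buffer and its rng parameter are made up for illustration, and it assumes buffer_size is at least 1.

```python
import itertools
import random

_END = object()  # sentinel marking an exhausted source


def shuffle_with_buffer(source, buffer_size, rng=None):
    """Yield items in the order a fixed-size shuffle buffer would emit them."""
    rng = rng or random.Random()
    source = iter(source)
    # Fill the buffer with the first `buffer_size` elements of the source.
    buffer = list(itertools.islice(source, buffer_size))
    while buffer:
        # Take a random entry out of the buffer...
        yield buffer.pop(rng.randrange(len(buffer)))
        # ...and fill the free slot with the next source element, if any.
        nxt = next(source, _END)
        if nxt is not _END:
            buffer.append(nxt)


order = list(shuffle_with_buffer([1, 2, 3, 4, 5, 6], buffer_size=3))
print(order)  # some permutation of 1..6; the first entry is always 1, 2, or 3
```

Because an element can only be emitted once it has entered the buffer, a small buffer gives only a "local" shuffle. A buffer at least as large as the dataset gives a full uniform shuffle.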

How ds.repeat() works

Once all the entries have been read from the dataset, trying to read the next element raises an error (tf.errors.OutOfRangeError when using get_next() as in the question). That's where ds.repeat() comes into play: it restarts the dataset, so it again looks like this:

[1,2,3] <= [4,5,6]
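As a rough analogy in plain Python (not TensorFlow code), repeat can be thought of as chaining fresh passes over the source. The names one_epoch and repeat below are hypothetical stand-ins used only for this sketch.

```python
def one_epoch():
    """Hypothetical stand-in for a single pass over the dataset."""
    yield from [1, 2, 3, 4, 5, 6]


def repeat(make_epoch, count):
    """Chain `count` fresh passes, restarting the source for each one."""
    for _ in range(count):
        yield from make_epoch()


single = one_epoch()
first_pass = list(single)                # reads all six entries
exhausted = next(single, None) is None   # nothing left to read without repeat

two_epochs = list(repeat(one_epoch, count=2))
print(len(first_pass), exhausted, len(two_epochs))  # 6 True 12
```

In the question's pipeline, shuffle comes before repeat, so each restarted pass is shuffled again from a freshly filled buffer.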

What will ds.batch() produce

ds.batch() takes the next batch_size entries and groups them into a batch. So, with the shuffled order from above, a batch size of 3 for our example dataset will produce two batches:

[2,1,5]
[3,6,4]

Because ds.repeat() comes before the batch, data generation continues after the first pass, but the order of the elements will differ because of ds.shuffle(). Keep in mind that 6 will never appear in the first batch: with a random buffer of size 3, only the elements 1 through 5 can have entered the buffer by the time the first three entries are drawn.
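To try the whole chain yourself, here is a minimal sketch that assumes TensorFlow 2.x with eager execution (the make_one_shot_iterator() call in the question is the TensorFlow 1.x style). The helper name make_batches is made up for this example, and Dataset.range(1, 7) stands in for the toy dataset [1, 2, 3, 4, 5, 6].

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled


def make_batches(num_epochs=2, batch_size=3, buffer_size=3):
    """Build shuffle -> repeat -> batch over the toy dataset and collect the batches."""
    ds = tf.data.Dataset.range(1, 7)          # elements 1..6
    ds = ds.shuffle(buffer_size=buffer_size)  # random picks from a 3-element buffer
    ds = ds.repeat(num_epochs)                # restart the shuffled dataset each epoch
    ds = ds.batch(batch_size)                 # group consecutive entries into batches
    return [batch.numpy().tolist() for batch in ds]


batches = make_batches()
for b in batches:
    print(b)  # four batches of three; the exact order varies from run to run
```

With these settings, the first and third batches are the first batch of each epoch, so neither of them should ever contain 6.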