tensorflow dataset shuffle then batch or batch then shuffle

Lim Kaizhuo · May 20, 2018

I recently began learning TensorFlow.

I am unsure whether there is a difference between

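import numpy as np
import tensorflow as tf

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(buffer_size=4)  # returns a new dataset, so reassign it
dataset = dataset.batch(4)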

and

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.batch(4)
dataset = dataset.shuffle(buffer_size=4)

Also, I am not sure why I cannot use

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)

as it gives the error

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'

Thank you!

Answer

mrry · May 21, 2018

TL;DR: Yes, there is a difference. Almost always, you will want to call Dataset.shuffle() before Dataset.batch(). There is no shuffle_batch() method on the tf.data.Dataset class, and you must call the two methods separately to shuffle and batch a dataset.


The transformations of a tf.data.Dataset are applied in the same sequence that they are called. Dataset.batch() combines consecutive elements of its input into a single, batched element in the output. We can see the effect of the order of operations by considering the following two datasets:

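import tensorflow as tf

# Enable eager execution to simplify the example code. Required in TF 1.x;
# in TF 2.x eager execution is already the default and this call does nothing.
tf.compat.v1.enable_eager_execution()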

# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)

for elem in dataset:
  print(elem)

# Prints the three batches in a random order, e.g.:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)

# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)

for elem in dataset:
  print(elem)

# Prints (batch contents vary from run to run), e.g.:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)
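
For readers without TensorFlow at hand, the two pipelines can be modelled with plain Python lists. This is a minimal sketch: `batch` and `shuffle` are hypothetical stand-ins for Dataset.batch() and Dataset.shuffle(), and `shuffle` assumes a buffer at least as large as the dataset (i.e. a full uniform shuffle). Like Dataset.batch() with its default settings, `batch` keeps a shorter final batch.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def batch(elements: Sequence[T], batch_size: int) -> List[List[T]]:
    """Group consecutive elements into lists of `batch_size` (the last may be shorter)."""
    return [list(elements[i:i + batch_size]) for i in range(0, len(elements), batch_size)]


def shuffle(elements: Sequence[T], rng: random.Random) -> List[T]:
    """Return a uniformly shuffled copy (models a buffer holding the whole dataset)."""
    shuffled = list(elements)
    rng.shuffle(shuffled)
    return shuffled


data = [0, 0, 0, 1, 1, 1, 2, 2, 2]
rng = random.Random(0)  # fixed seed so the sketch is reproducible

# Batch before shuffle: each batch stays a run of consecutive inputs;
# only the order of the batches changes.
batch_then_shuffle = shuffle(batch(data, 3), rng)

# Shuffle before batch: each batch draws its elements from the whole dataset.
shuffle_then_batch = batch(shuffle(data, rng), 3)

print(batch_then_shuffle)
print(shuffle_then_batch)
```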

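In the first version (batch before shuffle), each batch contains 3 consecutive elements from the input and only the order of the batches is randomized; in the second version (shuffle before batch), the elements of each batch are randomly sampled from the input. (Here the shuffle buffer holds all 9 elements, so the shuffle is uniform; with a buffer_size smaller than the dataset, Dataset.shuffle() samples from a fixed-size buffer that is refilled from the input, so the result is not a uniform permutation.) Typically, when training with (some variant of) mini-batch stochastic gradient descent, the elements of each batch should be sampled as uniformly as possible from the whole input. Otherwise, the network may overfit to whatever structure (such as ordering) is present in the input data, and the resulting model may not achieve as high an accuracy.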
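
Dataset.shuffle() only approximates a uniform shuffle. According to its documentation, it fills a buffer with buffer_size elements and samples randomly from that buffer, replacing each chosen element with the next input element. The generator below is a minimal sketch of that idea in plain Python (a hypothetical `buffered_shuffle` helper, not TensorFlow's actual implementation). It shows, for example, why in the question's setup with buffer_size=4 and 5 elements, the last element can never be emitted first.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(elements: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Sketch of a fixed-size shuffle buffer (buffer_size must be >= 1).

    Fill the buffer, then for each new input emit a random buffered element
    and put the new input in its slot; finally drain the buffer randomly.
    """
    buffer: List[T] = []
    for elem in elements:
        if len(buffer) < buffer_size:
            buffer.append(elem)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and replace it with the new one.
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = elem
    # Input exhausted: emit whatever is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


data = list(range(10))
# Small buffer: no element can move earlier by more than buffer_size - 1 positions.
print(list(buffered_shuffle(data, 2, random.Random(0))))
# Buffer as large as the dataset: a full uniform shuffle.
print(list(buffered_shuffle(data, 10, random.Random(0))))
```

This is why a buffer_size at least as large as the dataset (when it fits in memory) gives the most uniform batches when shuffling before batching.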