How do I split Tensorflow datasets?

Lukas Hestermeyer picture Lukas Hestermeyer · Jul 1, 2018 · Viewed 22.6k times · Source

I have a tensorflow dataset based on one .tfrecord file. How do I split the dataset into test and train datasets? E.g. 70% Train and 30% test?

Edit:

My Tensorflow Version: 1.8 I've checked, there is no "split_v" function as mentioned in the possible duplicate. Also I am working with a tfrecord file.

Answer

ted picture ted · Jul 1, 2018

You may use Dataset.take() and Dataset.skip():

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)

For more generality, I gave an example using a 70/15/15 train/val/test split but if you don't need a test or a val set, just ignore the last 2 lines.

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.