How do I create padded batches in TensorFlow for tf.train.SequenceExample data using the Dataset API?

Marijn Huijbregts · Aug 30, 2017 · Viewed 13.3k times

To train an LSTM model in TensorFlow, I have structured my data in the tf.train.SequenceExample format and stored it in a TFRecord file. I would now like to use the new Dataset API to generate padded batches for training. The documentation has an example that uses padded_batch, but for my data I can't figure out what the value of padded_shapes should be.

To read the TFRecord file into batches, I have written the following Python code:

import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array

if len(sys.argv) != 2:
  print("Usage: createbatches.py [TFRecord file]")
  sys.exit(0)


vectorSize = 40
inFile = sys.argv[1]

def parse_function_dataset(example_proto):
  sequence_features = {
      'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                           dtype=tf.float32),
      'labels': tf.FixedLenSequenceFeature(shape=[],
                                           dtype=tf.int64)}

  _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)

  length = tf.shape(sequence['inputs'])[0]
  return sequence['inputs'], sequence['labels']

sess = tf.InteractiveSession()

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()

batch = iterator.get_next()

# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

print(sess.run(batch))

The code works well if I use dataset = dataset.batch(1) (no padding needed in that case), but when I use the padded_batch variant, I get the following error:

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .

Can you help me figure out what I should pass for the padded_shapes parameter?

(I know there is a lot of example code that uses threading and queues for this, but I'd rather use the new Dataset API for this project.)

Answer

Zaher Wanli · Aug 30, 2017

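You need to pass a tuple of shapes, one for each component returned by your map function. Each padded shape must have the same rank as its component: inputs has shape [time, vectorSize] and labels has shape [time]. In your case you should pass

dataset = dataset.padded_batch(4, padded_shapes=([None, vectorSize], [None]))

or, to pad every dimension to the largest size in the batch, try

dataset = dataset.padded_batch(4, padded_shapes=([None, None], [None]))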

Check this code for more details. I had to debug this method to figure out why it wasn't working for me.
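To make the tuple-of-shapes idea concrete, here is a minimal end-to-end sketch. It is not the code from the question: it assumes TensorFlow 2.x (eager execution, tf.data and tf.io instead of tf.contrib.data), writes a few synthetic SequenceExample records to a temporary file, and the helper names make_sequence_example, parse_fn and build_dataset are made up for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

VECTOR_SIZE = 40  # plays the role of vectorSize in the question


def make_sequence_example(num_steps):
    """Build one SequenceExample with num_steps frames of synthetic data."""
    inputs = np.random.rand(num_steps, VECTOR_SIZE).astype(np.float32)
    labels = np.arange(num_steps, dtype=np.int64)
    feature_lists = tf.train.FeatureLists(feature_list={
        "inputs": tf.train.FeatureList(feature=[
            tf.train.Feature(float_list=tf.train.FloatList(value=row.tolist()))
            for row in inputs]),
        "labels": tf.train.FeatureList(feature=[
            tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)]))
            for label in labels]),
    })
    return tf.train.SequenceExample(feature_lists=feature_lists)


def parse_fn(serialized):
    """Return (inputs, labels) with shapes [time, VECTOR_SIZE] and [time]."""
    sequence_features = {
        "inputs": tf.io.FixedLenSequenceFeature([VECTOR_SIZE], tf.float32),
        "labels": tf.io.FixedLenSequenceFeature([], tf.int64),
    }
    _, sequence = tf.io.parse_single_sequence_example(
        serialized, sequence_features=sequence_features)
    return sequence["inputs"], sequence["labels"]


def build_dataset(path, batch_size):
    """Read the TFRecord file and pad each batch to its longest sequence."""
    dataset = tf.data.TFRecordDataset([path]).map(parse_fn)
    # One padded shape per component, each with the component's rank.
    return dataset.padded_batch(
        batch_size, padded_shapes=([None, VECTOR_SIZE], [None]))


# Write four sequences of different lengths to a temporary TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for num_steps in (3, 5, 2, 7):
        writer.write(make_sequence_example(num_steps).SerializeToString())

inputs, labels = next(iter(build_dataset(path, batch_size=4)))
print(inputs.shape, labels.shape)  # batch shapes: (4, 7, 40) and (4, 7)
```

Shorter sequences are filled with zeros up to the longest one in the batch, so the model still needs the true sequence lengths (or a mask) if padding should be ignored during training.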