For training an LSTM model in TensorFlow, I have structured my data in the tf.train.SequenceExample format and stored it in a TFRecord file. I would now like to use the new Dataset API to generate padded batches for training. The documentation has an example of using padded_batch, but I can't figure out what the value of padded_shapes should be for my data.
To read the TFRecord file into batches, I have written the following Python code:
import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array
if len(sys.argv) != 2:
    print("Usage: createbatches.py [TFRecord file]")
    sys.exit(0)
vectorSize = 40
inFile = sys.argv[1]
def parse_function_dataset(example_proto):
    sequence_features = {
        'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                             dtype=tf.float32),
        'labels': tf.FixedLenSequenceFeature(shape=[],
                                             dtype=tf.int64)}
    _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)
    length = tf.shape(sequence['inputs'])[0]
    return sequence['inputs'], sequence['labels']
sess = tf.InteractiveSession()
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()
# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
print(sess.run(batch))
The code works well if I use dataset = dataset.batch(1) (no padding is needed in that case), but when I use the padded_batch variant, I get the following error:
TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .
Can you help me figure out what I should pass for the padded_shapes parameter?
(I know there is a lot of example code that uses threading and queues for this, but I'd rather use the new Dataset API for this project.)
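You need to pass a tuple of shapes. padded_shapes must have the same nested structure as the elements of the dataset, and your map function returns a tuple of two tensors, so a single list such as [None] does not match. The parsed inputs tensor has shape [sequence_length, vectorSize] and labels has shape [sequence_length], so in your case you should pass
dataset = dataset.padded_batch(4, padded_shapes=([None, vectorSize], [None]))
or, to have every dimension padded to the largest size in each batch, try
dataset = dataset.padded_batch(4, padded_shapes=([None, None], [None]))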
Check this code for more details. I had to debug this method to figure out why it wasn't working for me.
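To make the meaning of a per-component padded shape concrete, here is a minimal plain-Python sketch of what padding a batch of (inputs, labels) pairs amounts to. It is only an illustration, not TensorFlow's implementation: the helper names pad_inputs, pad_labels and padded_batch_sketch are hypothetical, the toy data is made up, and it assumes inputs has shape [sequence_length, vector_size] and labels has shape [sequence_length] with matching lengths.

```python
# Plain-Python illustration only; this is NOT how TensorFlow implements padded_batch.
# All function names and the toy data below are hypothetical.


def pad_labels(labels, length, pad_value=0):
    """Pad a 1-D label sequence (shape [None]) to `length` entries."""
    return list(labels) + [pad_value] * (length - len(labels))


def pad_inputs(frames, length, vector_size, pad_value=0.0):
    """Pad a 2-D sequence (shape [None, vector_size]) to `length` frames."""
    padding = [[pad_value] * vector_size for _ in range(length - len(frames))]
    return [list(frame) for frame in frames] + padding


def padded_batch_sketch(examples, vector_size):
    """Batch (inputs, labels) pairs, padding each to the longest sequence.

    Assumes inputs and labels of one example have the same length.
    """
    max_len = max(len(labels) for _, labels in examples)
    batch_inputs = [pad_inputs(inputs, max_len, vector_size)
                    for inputs, _ in examples]
    batch_labels = [pad_labels(labels, max_len) for _, labels in examples]
    return batch_inputs, batch_labels


# Two toy examples with vector_size = 2 and sequence lengths 1 and 3.
examples = [
    ([[1.0, 2.0]], [7]),
    ([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], [1, 2, 3]),
]
batch_inputs, batch_labels = padded_batch_sketch(examples, vector_size=2)
print(batch_inputs)
print(batch_labels)
```

Each component of the element gets its own padding rule, which is why padded_shapes needs one shape per tensor returned by the map function. The shorter example is extended with zero frames and zero labels until it matches the longest sequence in the batch.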