How to deal with batches with variable-length sequences in TensorFlow?

Seja Nair · Jan 8, 2016 · Viewed 43k times

I was trying to use an RNN (specifically, LSTM) for sequence prediction. However, I ran into an issue with variable sequence lengths. For example,

sent_1 = "I am flying to Dubain"
sent_2 = "I was traveling from US to Dubai"

I am trying to predict the next word after the current one with a simple RNN, based on this Benchmark for building a PTB LSTM model.

However, the num_steps parameter (used for unrolling the RNN over the previous hidden states) must stay the same across each of TensorFlow's epochs. Basically, batching sentences is not possible, as the sentences vary in length.

# Fixed-length unrolling from the PTB model: the input is split into exactly
# num_steps slices, so every batch has to use the same sequence length.
inputs = [tf.squeeze(input_, [1])
          for input_ in tf.split(1, num_steps, inputs)]
outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)

Here, num_steps would need to change for every sentence in my case. I have tried several hacks, but nothing seems to work.

Answer

Taras Sereda · Jan 8, 2016

You can use the ideas of bucketing and padding, which are described in:

    Sequence-to-Sequence Models

Also, the rnn function, which creates the RNN network, accepts a sequence_length parameter.

As an example, you can create buckets of sentences of the same size, pad them with the necessary number of zeros (or with placeholders that stand for a zero word), and afterwards feed them to the network along with seq_length set to each sentence's real (unpadded) length.
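For instance, the padding step for one bucket might look like the following sketch; pad_bucket and word_to_id are hypothetical names made up for this example (they are not from the PTB tutorial), and id 0 is assumed to be the padding / zero word:

import numpy as np

# Hypothetical helper: pad the tokenized sentences of one bucket to a fixed
# width and record each sentence's real (unpadded) length.
def pad_bucket(sentences, word_to_id, bucket_size):
    batch = np.zeros((len(sentences), bucket_size), dtype=np.int32)  # 0 = padding id
    lengths = np.zeros(len(sentences), dtype=np.int32)
    for i, words in enumerate(sentences):
        ids = [word_to_id[w] for w in words][:bucket_size]
        batch[i, :len(ids)] = ids      # remaining positions stay 0 (the zero word)
        lengths[i] = len(ids)          # real length, to be fed as seq_length
    return batch, lengths

# The two sentences from the question end up in the same bucket:
word_to_id = {"<pad>": 0, "i": 1, "am": 2, "flying": 3, "to": 4, "dubai": 5,
              "was": 6, "traveling": 7, "from": 8, "us": 9}
sents = [["i", "am", "flying", "to", "dubai"],
         ["i", "was", "traveling", "from", "us", "to", "dubai"]]
batch, lengths = pad_bucket(sents, word_to_id, bucket_size=10)
# batch has shape (2, 10); lengths == [5, 7], which is what goes into seq_length below.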

# Placeholder for the real (unpadded) length of each sequence in the batch.
seq_length = tf.placeholder(tf.int32)
outputs, states = rnn.rnn(cell, inputs, initial_state=initial_state,
                          sequence_length=seq_length)

sess = tf.Session()
feed = {
    seq_length: 20,  # real length of the sequences in this batch
    # other feeds
}
sess.run(outputs, feed_dict=feed)

Take a look at this Reddit thread as well:

   Tensorflow basic RNN example with 'variable length' sequences