Is RNN initial state reset for subsequent mini-batches?

VM_AI · Jul 18, 2016 · Viewed 15.1k times

Could someone please clarify whether the initial state of the RNN in TF is reset for subsequent mini-batches, or whether the last state of the previous mini-batch is used, as described in Ilya Sutskever et al., ICLR 2015?

Answer

danijar · Jul 19, 2016

The tf.nn.dynamic_rnn() and tf.nn.rnn() operations allow you to specify the initial state of the RNN via the initial_state parameter. If you don't pass this parameter, the hidden state is initialized to zero vectors at the beginning of each mini-batch.

In TensorFlow, you can wrap tensors in tf.Variable() to keep their values in the graph between session runs. Just make sure to mark them as non-trainable, because the optimizers tune all trainable variables by default.

# Placeholder for one mini-batch of sequences.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))

cell = tf.nn.rnn_cell.GRUCell(256)
# Non-trainable variable that persists the hidden state across session runs.
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

# Write the new state back into the variable before returning the output.
with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})

I haven't tested this code, but it should point you in the right direction. There is also tf.nn.state_saving_rnn(), to which you can pass a state saver object, but I haven't used it yet.
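To see why carrying the state over matters, here is a minimal NumPy sketch (not TensorFlow; the tanh cell, weights, and shapes are all illustrative assumptions) contrasting the default behaviour, where every mini-batch starts from zeros, with the stateful variant, where a batch starts from the previous batch's final state:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, steps, frame_size, hidden = 2, 3, 4, 5

# Hypothetical fixed weights for a plain tanh RNN cell.
W_x = rng.normal(size=(frame_size, hidden)) * 0.1
W_h = rng.normal(size=(hidden, hidden)) * 0.1

def run_rnn(batch, state):
    """Unroll a simple tanh RNN over `batch` starting from `state`;
    return the per-step outputs and the final hidden state."""
    outputs = []
    for t in range(batch.shape[1]):
        state = np.tanh(batch[:, t] @ W_x + state @ W_h)
        outputs.append(state)
    return np.stack(outputs, axis=1), state

zero_state = np.zeros((batch_size, hidden))
batch1 = rng.normal(size=(batch_size, steps, frame_size))
batch2 = rng.normal(size=(batch_size, steps, frame_size))

# Default behaviour: every mini-batch starts from zeros.
_, final1 = run_rnn(batch1, zero_state)
_, reset_final = run_rnn(batch2, zero_state)

# Stateful variant: batch 2 resumes from batch 1's final state.
_, carried_final = run_rnn(batch2, final1)

# The two regimes generally produce different hidden states.
print(np.allclose(reset_final, carried_final))
```

The `tf.Variable` plus `state.assign(new_state)` pattern in the answer implements the stateful variant inside the TF graph, so the carry-over happens automatically on every `sess.run`.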