Could someone please clarify whether the initial state of the RNN in TF is reset for subsequent mini-batches, or whether the last state of the previous mini-batch is used, as mentioned in Ilya Sutskever et al., ICLR 2015?
The tf.nn.dynamic_rnn() and tf.nn.rnn() operations let you specify the initial state of the RNN via the initial_state parameter. If you don't specify it, the hidden states are initialized to zero vectors at the beginning of each training batch.
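If you want the hidden state to carry over from one mini-batch to the next, one simple pattern is to keep it outside the graph: have tf.nn.dynamic_rnn() return the final state, feed it back through initial_state on the next run, and start from zeros only once. A minimal sketch of this feed-back pattern, assuming a GRU cell, made-up shapes, and a hypothetical batches iterable:
import numpy as np
import tensorflow as tf

batch_size, max_length, frame_size = 32, 100, 64  # made-up shapes

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell = tf.nn.rnn_cell.GRUCell(256)

# Feed the initial state explicitly instead of letting it default to zeros.
init_state = tf.placeholder(tf.float32, (batch_size, cell.state_size))
output, final_state = tf.nn.dynamic_rnn(cell, data, initial_state=init_state)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

# Carry the final state of each mini-batch into the next on the Python side.
state = np.zeros((batch_size, cell.state_size), np.float32)
for batch in batches:  # hypothetical iterable of input arrays
    out, state = sess.run([output, final_state],
                          {data: batch, init_state: state})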
Another option is to keep the state inside the graph: in TensorFlow, you can wrap tensors in tf.Variable() to keep their values in the graph between multiple session runs. Just make sure to mark them as non-trainable, because the optimizers tune all trainable variables by default.
import tensorflow as tf

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell = tf.nn.rnn_cell.GRUCell(256)

# Non-trainable variable that holds the state between session runs.
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

# Make fetching the output also write the new state back into the variable.
with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})
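With this setup, every run that fetches output also writes the new state back into the variable via the control dependency, so consecutive mini-batches continue from where the previous one left off. To start a fresh sequence, for example at the beginning of a new epoch, you can re-run the variable's initializer to reset it to the zero state; a usage sketch with hypothetical epoch/batch iterables:
for epoch in epochs:                # hypothetical iterable of epochs
    sess.run(state.initializer)     # reset the stored state to zeros
    for batch in epoch:             # hypothetical iterable of input arrays
        sess.run(output, {data: batch})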
I haven't tested this code, but it should point you in the right direction. There is also tf.nn.state_saving_rnn(), to which you can provide a state saver object, but I haven't used it yet.