When does keras reset an LSTM state?

Daniel Möller · May 10, 2017 · Viewed 18.5k times

I read all sorts of texts about it, and none seem to answer this very basic question. It's always ambiguous:

In a stateful = False LSTM layer, does keras reset states after:

  • Each sequence; or
  • Each batch?

Suppose I have X_train shaped as (1000,20,1), meaning 1000 sequences of 20 steps of a single value. If I make:

model.fit(X_train, y_train, batch_size=200, epochs=15)

Will it reset states for every single sequence (resets states 1000 times)?
Or will it reset states for every batch (resets states 5 times)?
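For the record, the arithmetic behind the two options (a trivial check, using the numbers from the question):

```python
# Setup from the question: X_train shaped (1000, 20, 1), batch_size=200.
n_sequences = 1000
batch_size = 200

batches_per_epoch = n_sequences // batch_size

# "Per sequence" would mean 1000 resets per epoch;
# "per batch" would mean 5 resets per epoch.
print(batches_per_epoch)  # 5
```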

Answer

Daniel Möller · Sep 20, 2017

Checking with some tests, I reached the following conclusion, which agrees with the documentation and with Nassim's answer:

First, there isn't a single state in a layer, but one state per sample in the batch. There are batch_size parallel states in such a layer.

Stateful=False

In the stateful=False case, all the states are reset together after each batch.

  • A batch with 10 sequences would create 10 states, and all 10 states are reset automatically after the batch is processed.

  • The next batch with 10 sequences will create 10 new states, which will also be reset after that batch is processed.

If all those sequences have length (timesteps) = 7, the practical result of these two batches is:

20 individual sequences, each with length 7

None of the sequences are related to each other. But of course: the weights (not the states) are shared by the whole layer, and represent what the layer has learned from all the sequences.

  • A state is: Where am I now inside a sequence? Which time step is it? How is this particular sequence behaving since its beginning up to now?
  • A weight is: What do I know about the general behavior of all sequences I've seen so far?
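The stateful=False behavior can be illustrated with a toy simulation — plain NumPy, a simple tanh RNN cell with made-up weights, not the Keras implementation itself. Each batch starts from fresh zero states, so the second batch's output is completely independent of the first:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, steps, units = 10, 7, 4

# Fixed "learned" weights, shared by every sequence (hypothetical values).
W_x = rng.normal(size=(1, units))
W_h = rng.normal(size=(units, units))

def run_batch(x, h):
    """Run a simple RNN cell over (batch, steps, 1) input, from state h."""
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h)
    return h

x1 = rng.normal(size=(batch, steps, 1))  # first batch: 10 sequences of 7 steps
x2 = rng.normal(size=(batch, steps, 1))  # second batch: 10 unrelated sequences

zeros = np.zeros((batch, units))  # one state vector per sample in the batch

# stateful=False: every batch starts from a fresh zero state.
h1 = run_batch(x1, zeros)
h2 = run_batch(x2, zeros)  # h1 was discarded; x2 does not continue x1
```

Carrying `h1` into the second batch instead of `zeros` would give a different `h2` — which is exactly what stateful=True does, as described below.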

Stateful=True

In this case, there is the same number of parallel states, but they are simply never reset.

  • A batch with 10 sequences will create 10 states that will remain as they are at the end of the batch.

  • The next batch with 10 sequences (it's required to be 10, since the first was 10) will reuse the same 10 states that were created before.

The practical result is: the 10 sequences in the second batch are just continuing the 10 sequences of the first batch, as if there had been no interruption at all.

If each sequence has length (timesteps) = 7, then the actual meaning is:

10 individual sequences, each with length 14
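This continuation property can be checked with the same kind of toy simulation (plain NumPy, a simple tanh RNN cell with made-up weights, not the Keras implementation): two 7-step batches with the state carried over produce exactly the same final states as a single 14-step pass:

```python
import numpy as np

rng = np.random.default_rng(1)
batch, units = 10, 4

W_x = rng.normal(size=(1, units))
W_h = rng.normal(size=(units, units))

def run(x, h):
    """Run a simple RNN cell over (batch, steps, 1) input, from state h."""
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h)
    return h

x = rng.normal(size=(batch, 14, 1))  # 10 sequences of length 14
h0 = np.zeros((batch, units))

# One pass over the full 14 steps:
h_full = run(x, h0)

# stateful=True: two batches of 7 steps, carrying the state between them.
h_a = run(x[:, :7, :], h0)
h_b = run(x[:, 7:, :], h_a)  # continues from where the first batch stopped
```

Here `h_b` equals `h_full`: the second batch really is a continuation of the first, with no interruption.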

When you reach the total length of your sequences, you call model.reset_states(), meaning you will no longer continue the previous sequences — from that point on you start feeding new ones.
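The effect of the reset can be sketched in the same toy simulation (plain NumPy stand-in, where zeroing the states plays the role of model.reset_states()): without the reset, the new sequences would wrongly "continue" the old ones:

```python
import numpy as np

rng = np.random.default_rng(2)
batch, steps, units = 10, 7, 4

W_x = rng.normal(size=(1, units))
W_h = rng.normal(size=(units, units))

def run(x, h):
    """Run a simple RNN cell over (batch, steps, 1) input, from state h."""
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h)
    return h

old = rng.normal(size=(batch, steps, 1))  # tail of the previous sequences
new = rng.normal(size=(batch, steps, 1))  # fresh, unrelated sequences

h_end = run(old, np.zeros((batch, units)))  # states after finishing the old ones

# Stand-in for model.reset_states(): throw those states away.
h_fresh = np.zeros((batch, units))

with_reset = run(new, h_fresh)    # new sequences start from scratch
without_reset = run(new, h_end)   # new sequences wrongly continue the old ones
```

The two results differ, which is why forgetting the reset between unrelated sequences silently corrupts a stateful model's inputs.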