Early stopping with tf.estimator, how?

Carl Thomé · Nov 6, 2017

I'm using tf.estimator in TensorFlow 1.4, and tf.estimator.train_and_evaluate is great, but I need early stopping. What's the preferred way of adding that?

I assume there is some tf.train.SessionRunHook somewhere for this. I saw that an old contrib package had a ValidationMonitor that seemed to support early stopping, but it no longer seems to be around in 1.4. Or will the preferred way in the future be to rely on tf.keras (where early stopping is really easy) instead of tf.estimator/tf.layers/tf.data?

Answer

Carl Thomé · Jul 11, 2018

Good news! Early stopping support for tf.estimator (tf.contrib.estimator.stop_if_no_decrease_hook) has landed on master, and it looks like it will ship in 1.10.

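import os

import tensorflow as tf  # TensorFlow 1.10 or later

# model_fn, model_dir, train_input_fn and eval_input_fn are assumed to be defined elsewhere.
estimator = tf.estimator.Estimator(model_fn, model_dir)

# The hook reads eval metrics from eval_dir(), which may not exist yet before the
# first evaluation, so create it up front (exist_ok=True avoids an error on reruns).
os.makedirs(estimator.eval_dir(), exist_ok=True)  # TODO This should not be expected IMO.

# Request a stop once the eval 'loss' has not decreased for 1000 training steps;
# never stop before global step 100.
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

# The hook is attached to training and periodically reads the eval metrics under eval_dir().
tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))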
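To make max_steps_without_decrease and min_steps concrete, here is a minimal pure-Python sketch of the stopping rule. It assumes the hook behaves as its parameters suggest: evaluations before min_steps are ignored, and a stop is requested once the best (lowest) metric value is at least max_steps_without_decrease steps old. The function name should_stop_no_decrease and the eval_history dict (global step mapped to eval metric) are hypothetical, made up for illustration. This is not TensorFlow's own code.

```python
from typing import Dict, Optional


def should_stop_no_decrease(
    eval_history: Dict[int, float],
    max_steps_without_decrease: int,
    min_steps: int = 0,
) -> bool:
    """Hypothetical sketch: True if the metric hasn't decreased for long enough.

    eval_history maps global step -> evaluated metric value (e.g. eval 'loss').
    """
    best_val: Optional[float] = None
    best_step: Optional[int] = None
    # Walk the evaluations in global-step order.
    for step in sorted(eval_history):
        if step < min_steps:
            continue  # evaluations before min_steps are ignored entirely
        val = eval_history[step]
        if best_val is None or val < best_val:
            best_val, best_step = val, step  # new best (lowest) value
        # Stop if the best value was recorded too many steps ago.
        if step - best_step >= max_steps_without_decrease:
            return True
    return False


history = {100: 0.9, 600: 0.7, 1100: 0.72, 1600: 0.71}
print(should_stop_no_decrease(history, max_steps_without_decrease=1000, min_steps=100))
```

With these values the best loss is 0.7 at step 600. The evaluation at step 1600 is 1000 steps later without an improvement, so the rule says stop. Drop the step-1600 entry and it says keep training.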