Skip to content

Commit

Permalink
Allow EarlyStopping to be reused between multiple fits.
Browse files Browse the repository at this point in the history
All values were already reset properly in `on_train_begin` except `best`.

Fixes keras-team#20521
  • Loading branch information
hertschuh committed Nov 22, 2024
1 parent a93828a commit 977e080
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
6 changes: 3 additions & 3 deletions keras/src/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,12 @@ def _set_monitor_op(self):
)
if self.monitor_op == ops.less:
self.min_delta *= -1
self.best = (
float("inf") if self.monitor_op == ops.less else -float("inf")
)

def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = None
self.best_weights = None
self.best_epoch = 0

Expand Down Expand Up @@ -210,4 +208,6 @@ def get_monitor_value(self, logs):
return monitor_value

def _is_improvement(self, monitor_value, reference_value):
if reference_value is None:
return True
return self.monitor_op(monitor_value - self.min_delta, reference_value)
13 changes: 7 additions & 6 deletions keras/src/callbacks/early_stopping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,17 @@ def test_early_stopping_reuse(self):
loss="mae",
metrics=["mse"],
)
weights = model.get_weights()
stopper = callbacks.EarlyStopping(monitor="mse", patience=patience)

# This should allow training to go for at least `patience` epochs
model.set_weights(weights)
history1 = model.fit(
data, labels, callbacks=[stopper], verbose=0, epochs=20
)
self.assertGreaterEqual(len(history1.epoch), patience)

stopper = callbacks.EarlyStopping(monitor="mse", patience=patience)
hist = model.fit(
history2 = model.fit(
data, labels, callbacks=[stopper], verbose=0, epochs=20
)
assert len(hist.epoch) >= patience
self.assertGreaterEqual(len(history2.epoch), patience)

@pytest.mark.requires_trainable_backend
def test_early_stopping_with_baseline(self):
Expand Down

0 comments on commit 977e080

Please sign in to comment.