Skip to content

Commit

Permalink
Add use_internal_timer parameter to Convergence_checker_optuna class
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Mar 18, 2024
1 parent 2992de9 commit 0c84b54
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions bnpm/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ class Convergence_checker_optuna:
max_duration (float):
Maximum number of seconds to run before stopping.
(Default is *600*)
use_internal_timer (bool):
If ``True``, uses the internal timer. \n
If ``False``, uses the optuna study timer.
verbose (bool):
If ``True``, print messages.
(Default is ``True``)
Expand Down Expand Up @@ -253,6 +256,7 @@ def __init__(
tol_frac: float = 0.05,
max_trials: int = 350,
max_duration: float = 60*10,
use_internal_timer: bool = False,
verbose: bool = True,
):
"""
Expand All @@ -267,7 +271,10 @@ def __init__(
self.num_trial = 0
self.converged = False
self.reason_converged = None
self.use_internal_timer = use_internal_timer
self.verbose = verbose

self.time_start = time.time()

def check(
self,
Expand All @@ -284,11 +291,14 @@ def check(
trial (optuna.trial.FrozenTrial):
Optuna trial object.
"""
dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete
if (dur_first is not None) and (dur_last is not None):
duration = (dur_last - dur_first).total_seconds()
if self.use_internal_timer:
duration = time.time() - self.time_start
else:
duration = 0
dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete
if (dur_first is not None) and (dur_last is not None):
duration = (dur_last - dur_first).total_seconds()
else:
duration = 0

if trial.value is not None:
if trial.value < self.best:
Expand All @@ -300,11 +310,11 @@ def check(
self.converged, self.reason_converged = True, 'tol_frac'
print(f'Stopping. Convergence reached. Best value ({self.best*10000}) over last ({self.n_patience}) trials fractionally changed less than ({self.tol_frac})') if self.verbose else None
study.stop()
if self.num_trial >= self.max_trials:
elif self.num_trial >= self.max_trials:
self.converged, self.reason_converged = True, 'max_trials'
print(f'Stopping. Trial number limit reached. num_trial={self.num_trial}, max_trials={self.max_trials}.') if self.verbose else None
study.stop()
if duration > self.max_duration:
elif duration > self.max_duration:
self.converged, self.reason_converged = True, 'max_duration'
print(f'Stopping. Duration limit reached. study.duration={duration}, max_duration={self.max_duration}.') if self.verbose else None
study.stop()
Expand Down

0 comments on commit 0c84b54

Please sign in to comment.