From 0c84b545d1bd4214a23484ce02d8e2e03eaf0cf1 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Mon, 18 Mar 2024 01:50:24 -0400 Subject: [PATCH] Add use_internal_timer parameter to Convergence_checker_optuna class --- bnpm/optimization.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/bnpm/optimization.py b/bnpm/optimization.py index 3d3d758..b0ffb6b 100644 --- a/bnpm/optimization.py +++ b/bnpm/optimization.py @@ -220,6 +220,9 @@ class Convergence_checker_optuna: max_duration (float): Maximum number of seconds to run before stopping. (Default is *600*) + use_internal_timer (bool): + If ``True``, uses the internal timer. \n + If ``False``, uses the optuna study timer. verbose (bool): If ``True``, print messages. (Default is ``True``) @@ -253,6 +256,7 @@ def __init__( tol_frac: float = 0.05, max_trials: int = 350, max_duration: float = 60*10, + use_internal_timer: bool = False, verbose: bool = True, ): """ @@ -267,7 +271,10 @@ def __init__( self.num_trial = 0 self.converged = False self.reason_converged = None + self.use_internal_timer = use_internal_timer self.verbose = verbose + + self.time_start = time.time() def check( self, @@ -284,11 +291,14 @@ def check( trial (optuna.trial.FrozenTrial): Optuna trial object. """ - dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete - if (dur_first is not None) and (dur_last is not None): - duration = (dur_last - dur_first).total_seconds() + if self.use_internal_timer: + duration = time.time() - self.time_start else: - duration = 0 + dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete + if (dur_first is not None) and (dur_last is not None): + duration = (dur_last - dur_first).total_seconds() + else: + duration = 0 if trial.value is not None: if trial.value < self.best: @@ -300,11 +310,11 @@ def check( self.converged, self.reason_converged = True, 'tol_frac' print(f'Stopping. Convergence reached. Best value ({self.best*10000}) over last ({self.n_patience}) trials fractionally changed less than ({self.tol_frac})') if self.verbose else None study.stop() - if self.num_trial >= self.max_trials: + elif self.num_trial >= self.max_trials: self.converged, self.reason_converged = True, 'max_trials' print(f'Stopping. Trial number limit reached. num_trial={self.num_trial}, max_trials={self.max_trials}.') if self.verbose else None study.stop() - if duration > self.max_duration: + elif duration > self.max_duration: self.converged, self.reason_converged = True, 'max_duration' print(f'Stopping. Duration limit reached. study.duration={duration}, max_duration={self.max_duration}.') if self.verbose else None study.stop()