Skip to content

Commit

Permalink
NEW duration_type arg. REMOVED use_internal_timer arg
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Mar 22, 2024
1 parent 8defbb6 commit f24e34b
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions bnpm/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,13 @@ class Convergence_checker_optuna:
max_duration (float):
Maximum number of seconds to run before stopping.
(Default is *600*)
use_internal_timer (bool):
If ``True``, uses the internal timer. \n
If ``False``, uses the optuna study timer.
duration_type (str):
Type of timer to use for duration:\n
* 'internal': Use the time difference between the initialization
time and when the checker is called.
* 'study': Use the time difference between the first trial and when
the checker is called.
* 'trials': Use the sum of the durations of the trials.
verbose (bool):
If ``True``, print messages.
(Default is ``True``)
Expand Down Expand Up @@ -256,7 +260,7 @@ def __init__(
tol_frac: float = 0.05,
max_trials: int = 350,
max_duration: float = 60*10,
use_internal_timer: bool = False,
duration_type: str = 'internal',
verbose: bool = True,
):
"""
Expand All @@ -271,11 +275,14 @@ def __init__(
self.num_trial = 0
self.converged = False
self.reason_converged = None
self.use_internal_timer = use_internal_timer
self.duration_type = duration_type
self.verbose = verbose

self.time_start = time.time()
assert self.duration_type in ['internal', 'study', 'trials'], f"duration_type '{self.duration_type}' not recognized"

if self.duration_type == 'internal':
self.time_start = time.time()

def check(
self,
study: object,
Expand All @@ -291,14 +298,18 @@ def check(
trial (optuna.trial.FrozenTrial):
Optuna trial object.
"""
if self.use_internal_timer:
if self.duration_type == 'internal':
duration = time.time() - self.time_start
else:
elif self.duration_type == 'study':
dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete
if (dur_first is not None) and (dur_last is not None):
duration = (dur_last - dur_first).total_seconds()
else:
duration = 0
elif self.duration_type == 'trials':
duration = sum([t.duration.total_seconds() for t in study.trials])
else:
raise ValueError(f"duration_type '{self.duration_type}' not recognized")

if trial.value is not None:
if trial.value < self.best:
Expand Down

0 comments on commit f24e34b

Please sign in to comment.