From f24e34bfb2a52379e6b39945c1dced5041bed38f Mon Sep 17 00:00:00 2001
From: RichieHakim <RichHakim@gmail.com>
Date: Fri, 22 Mar 2024 16:29:58 -0400
Subject: [PATCH] NEW duration_type arg. REMOVED use_internal_timer arg

---
 bnpm/optimization.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/bnpm/optimization.py b/bnpm/optimization.py
index cd170f1..2730c1e 100644
--- a/bnpm/optimization.py
+++ b/bnpm/optimization.py
@@ -220,9 +220,13 @@ class Convergence_checker_optuna:
         max_duration (float): 
             Maximum number of seconds to run before stopping. 
             (Default is *600*)
-        use_internal_timer (bool):
-            If ``True``, uses the internal timer. \n
-            If ``False``, uses the optuna study timer.
+        duration_type (str):
+            Type of timer to use for duration:\n
+                * 'internal': Use the time difference between the initialization
+                  time and when the checker is called.
+                * 'study': Use the time difference between the first trial and when
+                   the checker is called.
+                * 'trials': Use the sum of the durations of the trials.
         verbose (bool): 
             If ``True``, print messages. 
             (Default is ``True``)
@@ -256,7 +260,7 @@ def __init__(
         tol_frac: float = 0.05, 
         max_trials: int = 350, 
         max_duration: float = 60*10, 
-        use_internal_timer: bool = False,
+        duration_type: str = 'internal',
         verbose: bool = True,
     ):
         """
@@ -271,11 +275,14 @@ def __init__(
         self.num_trial = 0
         self.converged = False
         self.reason_converged = None
-        self.use_internal_timer = use_internal_timer
+        self.duration_type = duration_type
         self.verbose = verbose
 
-        self.time_start = time.time()
+        assert self.duration_type in ['internal', 'study', 'trials'], f"duration_type '{self.duration_type}' not recognized"
         
+        if self.duration_type == 'internal':
+            self.time_start = time.time()
+
     def check(
         self, 
         study: object, 
@@ -291,14 +298,18 @@ def check(
             trial (optuna.trial.FrozenTrial): 
                 Optuna trial object.
         """
-        if self.use_internal_timer:
+        if self.duration_type == 'internal':
             duration = time.time() - self.time_start
-        else:
+        elif self.duration_type == 'study':
             dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete
             if (dur_first is not None) and (dur_last is not None):
                 duration = (dur_last - dur_first).total_seconds()
             else:
                 duration = 0
+        elif self.duration_type == 'trials':
+            duration = sum([t.duration.total_seconds() for t in study.trials])
+        else:
+            raise ValueError(f"duration_type '{self.duration_type}' not recognized")
         
         if trial.value is not None:
             if trial.value < self.best: