Skip to content

Commit

Permalink
handling n_trials > mutation space use case
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian James Phillips committed Apr 9, 2024
1 parent 619d1bc commit 78de478
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions nomelt/thermo_estimation/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def __init__(self, wt: str, variant: str, estimator: nomelt.thermo_estimation.es
if args.optuna_overwrite and os.path.exists(self.params.optuna_storage):
logger.info(f"Overwriting optuna storage file: {self.params.optuna_storage}")
os.remove(self.params.optuna_storage)
##C.P. Edits
if self.params.n_trials > 2 ** len(self.mutation_set):
logger.info(f"mutation set < n_trials")
self.params.n_trials = 2 ** len(self.mutation_set)
logger.info(f"n_trials is now {self.params.n_trials}")

def _init_estimator_call(self, study):
logger.info("Running initial estimator call")
Expand Down Expand Up @@ -304,6 +309,7 @@ def _get_storage(self):
return storage

def run(self, n_jobs: int=1, client: Client=None):

"""Run the optimization.
Uses parallel workers if n_jobs > 1 and a dask.distributed.Client object is provided.
Expand Down Expand Up @@ -336,6 +342,7 @@ def run(self, n_jobs: int=1, client: Client=None):
pass

args = (OptunaObjective(self), self.params.n_trials, self.params.sampler, self.name, storage)

if n_jobs == 1:
_worker_optimize(*args)
else:
Expand Down Expand Up @@ -381,7 +388,6 @@ def prune(self, study, trial):
# type: (Study, FrozenTrial) -> bool

trials = study.get_trials(deepcopy=False)

numbers=np.array([t.number for t in trials])
bool_params= np.array([trial.params==t.params for t in trials]).astype(bool)
bool_in_play = np.array(
Expand Down Expand Up @@ -424,6 +430,7 @@ def _worker_optimize(objective, n_trials, sampler, study_name, storage):
worker = get_worker()
worker_id = worker.id


# Generating a unique log file name
log_file = f'./workers/worker_{worker_id}_{int(time.time())}.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
Expand Down Expand Up @@ -640,4 +647,4 @@ def make_optimization_movie(self):





0 comments on commit 78de478

Please sign in to comment.