diff --git a/mala/network/hyper_opt_optuna.py b/mala/network/hyper_opt_optuna.py index 5024864d1..173ed4cec 100644 --- a/mala/network/hyper_opt_optuna.py +++ b/mala/network/hyper_opt_optuna.py @@ -176,9 +176,18 @@ def requeue_zombie_trials(study_name, rdb_storage): cleaned_trials = [] for trial in trials: if trial.state == optuna.trial.TrialState.RUNNING: - study_to_clean._storage.set_trial_state( - trial._trial_id, optuna.trial.TrialState.WAITING + kwds = dict( + trial_id=trial._trial_id, + state=optuna.trial.TrialState.WAITING, ) + if hasattr(study_to_clean._storage, "set_trial_state"): + # Optuna 2.x + study_to_clean._storage.set_trial_state(**kwds) + else: + # Optuna 3.x + study_to_clean._storage.set_trial_state_values( + values=None, **kwds + ) cleaned_trials.append(trial.number) printout("Cleaned trials: ", cleaned_trials, min_verbosity=0)