Add test_hyperopt_optuna_requeue_zombie_trials
elcorto committed May 30, 2024
1 parent e15c8e9 commit acff385
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions test/hyperopt_test.py
@@ -1,5 +1,8 @@
import os
import importlib
import sqlite3

import optuna

import mala
import numpy as np
@@ -375,3 +378,126 @@ def __optimize_hyperparameters(hyper_optimizer):
test_trainer.train_network()
test_parameters.show()
return test_trainer.final_test_loss

def test_hyperopt_optuna_requeue_zombie_trials(self, tmp_path):
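"""Test requeue_zombie_trials().

Simulate zombie trials (trials stuck in the RUNNING state, e.g. because a
worker died) by editing the sqlite database directly, then check that
requeue_zombie_trials() resets them to WAITING and that a subsequent
perform_study() run completes them again.
"""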

##tmp_path = os.environ["HOME"]
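# (The commented-out line above is presumably a local debugging override
# that writes the database to $HOME instead of the pytest tmp_path.)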

db_filename = f"{tmp_path}/test_ho.db"

# Set up parameters.
test_parameters = mala.Parameters()
test_parameters.data.data_splitting_type = "by_snapshot"
test_parameters.data.input_rescaling_type = "feature-wise-standard"
test_parameters.data.output_rescaling_type = "normal"
test_parameters.running.max_number_epochs = 2
test_parameters.running.mini_batch_size = 40
test_parameters.running.learning_rate = 0.00001
test_parameters.running.trainingtype = "Adam"
test_parameters.hyperparameters.n_trials = 2
test_parameters.hyperparameters.hyper_opt_method = "optuna"
test_parameters.hyperparameters.study_name = "test_ho"
test_parameters.hyperparameters.rdb_storage = (
f"sqlite:///{db_filename}"
)

# Load data.
data_handler = mala.DataHandler(test_parameters)
data_handler.add_snapshot(
"Be_snapshot0.in.npy",
data_path,
"Be_snapshot0.out.npy",
data_path,
"tr",
)
data_handler.add_snapshot(
"Be_snapshot1.in.npy",
data_path,
"Be_snapshot1.out.npy",
data_path,
"va",
)
data_handler.add_snapshot(
"Be_snapshot2.in.npy",
data_path,
"Be_snapshot2.out.npy",
data_path,
"te",
)
data_handler.prepare_data()

# Perform the hyperparameter optimization.
test_hp_optimizer = mala.HyperOpt(test_parameters, data_handler)
test_hp_optimizer.add_hyperparameter(
"float", "learning_rate", 0.0000001, 0.01
)
test_hp_optimizer.add_hyperparameter(
"int", "ff_neurons_layer_00", 10, 100
)
test_hp_optimizer.add_hyperparameter(
"int", "ff_neurons_layer_01", 10, 100
)
test_hp_optimizer.add_hyperparameter(
"categorical", "layer_activation_00", choices=["ReLU", "Sigmoid"]
)
test_hp_optimizer.add_hyperparameter(
"categorical", "layer_activation_01", choices=["ReLU", "Sigmoid"]
)
test_hp_optimizer.add_hyperparameter(
"categorical", "layer_activation_02", choices=["ReLU", "Sigmoid"]
)

def load_study():
return optuna.load_study(
study_name=test_parameters.hyperparameters.study_name,
storage=test_parameters.hyperparameters.rdb_storage,
)

# First run, create database.
test_hp_optimizer.perform_study()

assert (
test_hp_optimizer.study.trials_dataframe()["state"].to_list()
== ["COMPLETE"] * 2
)

# This is basically the same code as in requeue_zombie_trials(), but it
# doesn't work here. The trials are FrozenTrial objects (in
# requeue_zombie_trials() as well!) and we get
#   RuntimeError: Trial#0 has already finished and can not be updated.
# However, the very same code *does* work inside requeue_zombie_trials().
# Why?
#
##study = load_study()
####study = test_hp_optimizer.study
##for trial in study.get_trials():
## study._storage.set_trial_state_values(
## trial_id=trial._trial_id, state=optuna.trial.TrialState.RUNNING
## )
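# A plausible explanation (a guess, not verified here): at this point the
# trials are still COMPLETE, and Optuna's storage refuses to update
# finished trials, hence the RuntimeError above. By the time
# requeue_zombie_trials() is called below, the trials have already been
# flipped to RUNNING in the database, so the same call is allowed.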

# Hack the db directly.
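# Setting the state column to RUNNING emulates zombie trials, i.e. trials
# which a crashed or killed worker left behind in the RUNNING state.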
con = sqlite3.connect(db_filename)
cur = con.cursor()
cur.execute("update trials set state='RUNNING'")
con.commit()
con.close()

assert (
load_study().trials_dataframe()["state"].to_list()
== ["RUNNING"] * 2
)

test_hp_optimizer.requeue_zombie_trials(
study_name=test_parameters.hyperparameters.study_name,
rdb_storage=test_parameters.hyperparameters.rdb_storage,
)
assert (
load_study().trials_dataframe()["state"].to_list()
== ["WAITING"] * 2
)

# Second run adds one more trial.
test_hp_optimizer.perform_study()
assert (
test_hp_optimizer.study.trials_dataframe()["state"].to_list()
== ["COMPLETE"] * 3
)
