Merge pull request #138 from MaxFBurg/include_seed_into_random_hypersearch

Adds option to include seed into random search
mohammadbashiri authored May 30, 2022
2 parents b1b388a + c2d81f1 commit f2e7c61
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions nnfabrik/utility/hypersearch.py
@@ -266,6 +266,8 @@ class Random:
trainer_fn (str): name of the trainer function
trainer_config (dict): dictionary of arguments for trainer function that are fixed
trainer_config_auto (dict): dictionary of arguments for trainer function that are to be randomly sampled
seed_config_auto (dict): dictionary of arguments for the seed, either fixing it (`dict(seed={"type": "fixed", "value": <VALUE>})`)
or sampling it randomly (`dict(seed={"type": "int"})`)
architect (str): Name of the contributor that added this entry
trained_model_table (str): name (importable) of the trained_model_table
total_trials (int, optional): Number of experiments (i.e. training) to run. Defaults to 5.
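
For reference, the two forms the new `seed_config_auto` argument can take, per the docstring above (a minimal sketch; the value 42 is just an illustration):

seed_config_auto = dict(seed={"type": "fixed", "value": 42})  # keep the seed fixed across all trials
seed_config_auto = dict(seed={"type": "int"})                 # draw a fresh random seed for every trial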
@@ -283,6 +285,7 @@ def __init__(
trainer_fn,
trainer_config,
trainer_config_auto,
seed_config_auto,
architect,
trained_model_table,
total_trials=5,
@@ -291,7 +294,9 @@ def __init__(

self.fns = dict(dataset=dataset_fn, model=model_fn, trainer=trainer_fn)
self.fixed_params = self.get_fixed_params(dataset_config, model_config, trainer_config)
self.auto_params = self.get_auto_params(dataset_config_auto, model_config_auto, trainer_config_auto)
self.auto_params = self.get_auto_params(
dataset_config_auto, model_config_auto, trainer_config_auto, seed_config_auto
)
self.architect = architect
self.total_trials = total_trials
self.comment = comment
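
A hedged usage sketch of the updated constructor; every function name, config value, and table name below is a placeholder for illustration, not part of this commit:

searcher = Random(
    dataset_fn="my_project.data.build_loaders",        # hypothetical
    dataset_config={"batch_size": 64},
    dataset_config_auto={},
    model_fn="my_project.models.build_model",          # hypothetical
    model_config={},
    model_config_auto={},
    trainer_fn="my_project.training.train",            # hypothetical
    trainer_config={"max_epochs": 100},
    trainer_config_auto={},
    seed_config_auto={"seed": {"type": "int"}},        # new in this PR: sample the seed per trial
    architect="jane_doe",
    trained_model_table="my_project.tables.TrainedModel",  # hypothetical
    total_trials=5,
)
# searcher.run() would then launch total_trials random trials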
@@ -317,7 +322,7 @@ def get_fixed_params(dataset_config, model_config, trainer_config):
return dict(dataset=dataset_config, model=model_config, trainer=trainer_config)

@staticmethod
def get_auto_params(dataset_config_auto, model_config_auto, trainer_config_auto):
def get_auto_params(dataset_config_auto, model_config_auto, trainer_config_auto, seed_config_auto):
"""
Returns the parameters, which are to be randomly sampled, in a list.
Here we followed the same convention as in the Bayesian class, to have the API as similar as possible.
@@ -348,7 +353,13 @@ def get_auto_params(dataset_config_auto, model_config_auto, trainer_config_auto)
dd.update(v)
trainer_params.append(dd)

return dataset_params + model_params + trainer_params
seed_params = []
for k, v in seed_config_auto.items():
dd = {"name": "seed.{}".format(k)}
dd.update(v)
seed_params.append(dd)

return dataset_params + model_params + trainer_params + seed_params

@staticmethod
def _combine_params(auto_params, fixed_params):
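
To illustrate the new seed handling, `get_auto_params` simply prefixes each seed entry with "seed." so it can be routed back later by `_split_config` (a sketch assuming the other *_config_auto dicts are empty):

Random.get_auto_params({}, {}, {}, seed_config_auto={"seed": {"type": "int"}})
# -> [{"name": "seed.seed", "type": "int"}]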
@@ -365,10 +376,10 @@ def _combine_params(auto_params, fixed_params):
dict: dictionary of parameters (fixed and to-be-sampled), i.e. A dictionary of dictionaries where keys are dataset,
model, and trainer and the values are the corresponding dictionary of arguments.
"""
keys = ["dataset", "model", "trainer"]
keys = ["dataset", "model", "trainer", "seed"]
params = {}
for key in keys:
params[key] = fixed_params[key]
params[key] = fixed_params[key] if key in fixed_params else {}
params[key].update(auto_params[key])

return {key: params[key] for key in keys}
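
The `if key in fixed_params else {}` guard matters because `get_fixed_params` only returns dataset, model, and trainer entries, so "seed" falls back to an empty dict; a small sketch (values hypothetical):

fixed = {"dataset": {"batch_size": 64}, "model": {}, "trainer": {}}         # no "seed" key
auto = {"dataset": {}, "model": {}, "trainer": {}, "seed": {"seed": 1234}}
Random._combine_params(auto, fixed)
# -> {"dataset": {"batch_size": 64}, "model": {}, "trainer": {}, "seed": {"seed": 1234}}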
@@ -386,7 +397,7 @@ def _split_config(params):
dict: A dictionary of dictionaries where keys are dataset, model, and trainer and the values are the corresponding
dictionary of to-be-sampled arguments.
"""
config = dict(dataset={}, model={}, trainer={}, others={})
config = dict(dataset={}, model={}, trainer={}, seed={}, others={})
for k, v in params.items():
config[k.split(".")[0]][k.split(".")[1]] = v
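
`_split_config` then maps the flat sampled names back onto the nested layout, so a trial's seed ends up under its own top-level key (a sketch with hypothetical values):

params = {"model.gamma": 0.5, "seed.seed": 1234}
# _split_config(params) ->
# {"dataset": {}, "model": {"gamma": 0.5}, "trainer": {}, "seed": {"seed": 1234}, "others": {}}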

@@ -404,6 +415,10 @@ def train_evaluate(self, auto_params):
config = self._combine_params(self._split_config(auto_params), self.fixed_params)

# insert the stuff into their corresponding tables
seed = config["seed"]["seed"]
if dict(seed=seed) not in self.trained_model_table().seed_table():
self.trained_model_table().seed_table().insert1(dict(seed=seed))

dataset_hash = make_hash(config["dataset"])
entry_exists = {
"dataset_fn": "{}".format(self.fns["dataset"])
@@ -446,6 +461,7 @@ def train_evaluate(self, auto_params):

# get the primary key values for all those entries
restriction = (
f'seed in ("{seed}")',
'dataset_fn in ("{}")'.format(self.fns["dataset"]),
'dataset_hash in ("{}")'.format(dataset_hash),
'model_fn in ("{}")'.format(self.fns["model"]),
@@ -476,14 +492,14 @@ def gen_params_value(self):
auto_params_val.update({param["name"]: loguniform.rvs(*param["bounds"])})
else:
auto_params_val.update({param["name"]: np.random.uniform(*param["bounds"])})
elif param["type"] == "int":
auto_params_val.update({param["name"]: np.random.randint(np.iinfo(np.int32).max)})

return auto_params_val

def run(self):
"""
Runs the random hyperparameter search, for as many trials as specified.
"""
n_trials = len(self.trained_model_table().seed_table()) * self.total_trials
init_len = len(self.trained_model_table())
while len(self.trained_model_table()) - init_len < n_trials:
for _ in range(self.total_trials):
self.train_evaluate(self.gen_params_value())
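
Putting it together: with `seed_config_auto={"seed": {"type": "int"}}`, every call to `gen_params_value` now draws a fresh 32-bit seed next to the other sampled hyperparameters, and `train_evaluate` registers that seed before training. A minimal sketch, reusing the hypothetical `searcher` from above (the sampled value is made up):

auto_params_val = searcher.gen_params_value()
# e.g. {"seed.seed": 1723480519}
searcher.train_evaluate(auto_params_val)  # inserts the seed into the seed table if new, then trains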
