Skip to content

Commit

Permalink
fix: simplify ExperimentManager using exp_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
ankeko committed Feb 8, 2024
1 parent 1e94add commit 9a54dc9
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions niceml/experiments/experimentmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,26 @@ class ExperimentManager(object):

def __init__(self, experiments: Optional[List[ExperimentData]] = None):
    """Manages a collection of experiments, indexed by short id and by run id.

    Args:
        experiments: Experiments to register initially. ``None`` (the
            default) creates an empty manager.
    """
    # Normalise before building the indexes: iterating the ``None`` default
    # directly raises a TypeError. Copying also allows any iterable to be
    # consumed safely by both comprehensions below.
    experiments = [] if experiments is None else list(experiments)
    self.exp_dict_short_id = {exp.get_short_id(): exp for exp in experiments}
    self.exp_dict_run_id = {exp.get_run_id(): exp for exp in experiments}

def add_experiment(self, experiment: ExperimentData):
    """Adds an experiment to the manager.

    The experiment is stored under its short id and under its run id. An
    entry that already exists for either id is overwritten.

    Args:
        experiment: The experiment to register.
    """
    self.exp_dict_short_id[experiment.get_short_id()] = experiment
    self.exp_dict_run_id[experiment.get_run_id()] = experiment

def __contains__(self, exp_id: Union[str, ExperimentInfo]) -> bool:
    """Checks if the experiment is in the manager.

    Args:
        exp_id: Either an id string or an ``ExperimentInfo``; for the latter
            its ``short_id`` attribute is used.

    Returns:
        True if ``exp_id`` ends with the short id of any registered
        experiment, otherwise False.
    """
    if isinstance(exp_id, ExperimentInfo):
        exp_id = exp_id.short_id
    # Suffix match rather than equality, so a longer identifier that ends
    # with a registered short id is also recognised.
    return any(exp_id.endswith(short_id) for short_id in self.exp_dict_short_id)

def get_exp_count(self) -> int:
    """Returns the number of experiments registered in the manager."""
    return len(self.exp_dict_short_id)

def get_exp_prefix(self, exp_id) -> str:
"""extracts the prefix from the target exp data"""
Expand Down Expand Up @@ -74,7 +72,7 @@ def get_best_experiments(
"""
if mode not in ["max", "min"]:
raise Exception(f"mode is not max or min but : {mode}")
exp_list = self.experiments
exp_list = list(self.exp_dict_short_id.values())
number_of_exps = min(number_of_exps, len(exp_list))
value_exps = [
(exp, exp.get_best_metric_value(metric_name, mode))
Expand All @@ -93,7 +91,7 @@ def get_best_experiments(
def get_metrics(self, experiments: Optional[List[str]] = None) -> List[str]:
"""Returns a list of all metrics used in the experiments"""
metric_set: Set[str] = set()
for cur_exp in self.experiments:
for cur_exp in self.exp_dict_short_id.values():
if experiments is not None and cur_exp.get_short_id() not in experiments:
continue
metric_set.update(cur_exp.get_metrics())
Expand All @@ -102,15 +100,15 @@ def get_metrics(self, experiments: Optional[List[str]] = None) -> List[str]:

def is_exp_modified(self, exp_info: ExperimentInfo) -> bool:
    """Checks if the experiment has been modified.

    Args:
        exp_info: Info object of the experiment to compare against the
            registered one with the same ``short_id``.

    Returns:
        True if no experiment with that short id is registered; otherwise
        the result of ``is_modified`` on the stored experiment's info.
    """
    known_exp = self.exp_dict_short_id.get(exp_info.short_id)
    if known_exp is None:
        # An unknown experiment is reported as modified.
        return True
    return known_exp.exp_info.is_modified(exp_info)

def get_datasets(self) -> List[str]:
    """Returns a sorted list of all datasets used in the experiments.

    The dataset name is the first "/"-separated component of each
    experiment path; duplicates are removed.
    """
    dataset_set: Set[str] = {
        cur_exp.get_experiment_path().split("/")[0]
        for cur_exp in self.exp_dict_short_id.values()
    }
    return sorted(dataset_set)
Expand All @@ -122,14 +120,18 @@ def get_experiment_type(self, experiment: ExperimentData) -> str:
def get_experiment_types(self) -> List[str]:
    """Returns a sorted list of all distinct experiment types.

    The type of each registered experiment is determined by
    ``get_experiment_type``.
    """
    experiment_type_set: Set[str] = {
        self.get_experiment_type(cur_exp)
        for cur_exp in self.exp_dict_short_id.values()
    }
    return sorted(experiment_type_set)

def get_experiments(self) -> List[ExperimentData]:
    """Returns all experiments sorted by run id in descending order.

    This yields the newest experiment first, assuming run ids sort
    chronologically.
    """
    return sorted(
        self.exp_dict_short_id.values(),
        reverse=True,
        key=lambda exp: exp.get_run_id(),
    )

def get_exp_by_id(self, exp_id: str) -> ExperimentData:
    """
    Returns the experiment identified by ``exp_id``.

    Args:
        exp_id: A short id, a run id, or the keyword ``"latest"``
            (case-insensitive), which selects the experiment with the
            greatest run id.

    Returns:
        The matching experiment.

    Raises:
        KeyError: If ``exp_id`` is neither a known short id nor a known
            run id.
        IndexError: If ``"latest"`` is requested but no experiment is
            registered.
    """
    if exp_id.lower() == "latest":
        all_exps = list(self.exp_dict_short_id.values())
        if not all_exps:
            raise IndexError("no experiments available to resolve 'latest'")
        return max(all_exps, key=lambda exp: exp.get_run_id())
    if exp_id in self.exp_dict_short_id:
        return self.exp_dict_short_id[exp_id]
    # Fall back to the run-id index so lookups by run id keep working;
    # unknown ids still raise KeyError here.
    return self.exp_dict_run_id[exp_id]

def get_empty_exps(
    self, id_list: Optional[List[str]] = None
) -> List[ExperimentData]:
    """Finds all experiments which are empty.

    Args:
        id_list: Ids (short ids or run ids) of the experiments to check.
            ``None`` checks every registered experiment.

    Returns:
        The checked experiments for which ``is_empty()`` is true.

    Raises:
        KeyError: If an entry of ``id_list`` is not a known id.
    """
    if id_list is None:
        candidates = list(self.exp_dict_short_id.values())
    else:
        # Try the short-id index first and fall back to the run-id index,
        # so that both kinds of id are accepted.
        candidates = [
            self.exp_dict_short_id[cur_id]
            if cur_id in self.exp_dict_short_id
            else self.exp_dict_run_id[cur_id]
            for cur_id in id_list
        ]
    return [exp for exp in candidates if exp.is_empty()]

Expand Down Expand Up @@ -244,7 +248,7 @@ def get_value_information_dict(
) -> Dict[Any, List[str]]:
"""Returns a dict with information about the values"""
value_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
try:
exp_info = exp.get_config_information(info_path)
if type(exp_info) is list:
Expand All @@ -258,14 +262,14 @@ def get_value_information_dict(
def get_epochs_information_dict(self) -> Dict[int, List[str]]:
    """Returns a dict mapping trained-epoch counts to experiment short ids.

    Each key is a value returned by ``get_trained_epochs()``; the value is
    the list of short ids of all experiments reporting that count.
    """
    epochs_information_dict = defaultdict(list)
    for exp in self.exp_dict_short_id.values():
        epochs_information_dict[exp.get_trained_epochs()].append(exp.get_short_id())
    return epochs_information_dict

def get_datasets_information_dict(self) -> Dict[str, List[str]]:
    """Returns a dict mapping dataset names to experiment short ids.

    The dataset name is the first "/"-separated component of the
    experiment path.
    """
    datasets_information_dict = defaultdict(list)
    for exp in self.exp_dict_short_id.values():
        dataset = exp.get_experiment_path().split("/")[0]
        datasets_information_dict[dataset].append(exp.get_short_id())
    return datasets_information_dict
Expand All @@ -278,7 +282,7 @@ def get_dataset(self, exp: ExperimentData) -> str:
def get_date_information_dict(self) -> Dict[date, List[str]]:
"""Returns a dict with information about the dates"""
date_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
date_string = exp.exp_info.run_id.split("T")[0]
date = datetime.strptime(date_string, "%Y-%m-%d").date()
date_information_dict[date].append(exp.get_short_id())
Expand All @@ -287,15 +291,15 @@ def get_date_information_dict(self) -> Dict[date, List[str]]:
def get_experiment_type_information_dict(self) -> Dict[str, List[str]]:
    """Returns a dict mapping experiment types to experiment short ids.

    The experiment type is the text before the first "-" in the last
    "/"-separated component of the experiment path.
    """
    experiment_type_information_dict = defaultdict(list)
    for exp in self.exp_dict_short_id.values():
        folder_name = exp.get_experiment_path().split("/")[-1]
        experiment_type = folder_name.split("-")[0]
        experiment_type_information_dict[experiment_type].append(exp.get_short_id())
    return experiment_type_information_dict

def get_max_trained_epochs(self) -> int:
    """Returns the max epochs of all trained experiments.

    Returns 0 when no experiment is registered; the result is never
    smaller than 0.
    """
    epoch_counts = (
        exp.get_trained_epochs() for exp in self.exp_dict_short_id.values()
    )
    # Clamp at 0 so the result has the same lower bound as a running
    # maximum that starts at 0.
    return max(0, max(epoch_counts, default=0))

Expand Down

0 comments on commit 9a54dc9

Please sign in to comment.