diff --git a/niceml/experiments/experimentmanager.py b/niceml/experiments/experimentmanager.py index 2e99f38..3d0a8c9 100644 --- a/niceml/experiments/experimentmanager.py +++ b/niceml/experiments/experimentmanager.py @@ -25,28 +25,26 @@ class ExperimentManager(object): def __init__(self, experiments: List[ExperimentData] = None): """Manages a list of experiments""" - self.experiments = [] if experiments is None else experiments - self.exp_dict = {exp.get_short_id(): exp for exp in self.experiments} - self.exp_dict.update({exp.get_run_id(): exp for exp in self.experiments}) + self.exp_dict_short_id = {exp.get_short_id(): exp for exp in experiments} + self.exp_dict_run_id = {exp.get_run_id(): exp for exp in experiments} def add_experiment(self, experiment: ExperimentData): """Adds an experiment to the manager""" - self.experiments.append(experiment) - self.exp_dict[experiment.get_short_id()] = experiment - self.exp_dict[experiment.get_run_id()] = experiment + self.exp_dict_short_id[experiment.get_short_id()] = experiment + self.exp_dict_run_id[experiment.get_run_id()] = experiment def __contains__(self, exp_id: Union[str, ExperimentInfo]): """Checks if the experiment is in the manager""" if type(exp_id) == ExperimentInfo: exp_id = exp_id.short_id - for experiment in self.experiments: + for experiment in self.exp_dict_short_id.values(): if exp_id.endswith(experiment.get_short_id()): return True return False def get_exp_count(self) -> int: """Returns the number of experiments""" - return len(self.experiments) + return len(self.exp_dict_short_id.keys()) def get_exp_prefix(self, exp_id) -> str: """extracts the prefix from the target exp data""" @@ -74,7 +72,7 @@ def get_best_experiments( """ if mode not in ["max", "min"]: raise Exception(f"mode is not max or min but : {mode}") - exp_list = self.experiments + exp_list = list(self.exp_dict_short_id.values()) number_of_exps = min(number_of_exps, len(exp_list)) value_exps = [ (exp, exp.get_best_metric_value(metric_name, mode)) @@ -93,7 +91,7 @@ def get_best_experiments( def get_metrics(self, experiments: Optional[List[str]] = None) -> List[str]: """Returns a list of all metrics used in the experiments""" metric_set: Set[str] = set() - for cur_exp in self.experiments: + for cur_exp in self.exp_dict_short_id.values(): if experiments is not None and cur_exp.get_short_id() not in experiments: continue metric_set.update(cur_exp.get_metrics()) @@ -102,7 +100,7 @@ def get_metrics(self, experiments: Optional[List[str]] = None) -> List[str]: def is_exp_modified(self, exp_info: ExperimentInfo) -> bool: """Checks if the experiment has been modified""" - if exp_info.short_id not in self.exp_dict: + if exp_info.short_id not in self.exp_dict_short_id: return True exp = self.get_exp_by_id(exp_info.short_id) return exp.exp_info.is_modified(exp_info) @@ -110,7 +108,7 @@ def is_exp_modified(self, exp_info: ExperimentInfo) -> bool: def get_datasets(self) -> List[str]: """Returns a list of all datasets used in the experiments""" dataset_set: Set[str] = set() - for cur_exp in self.experiments: + for cur_exp in self.exp_dict_short_id.values(): dataset_set.add(cur_exp.get_experiment_path().split("/")[0]) return sorted(list(dataset_set)) @@ -122,14 +120,18 @@ def get_experiment_type(self, experiment: ExperimentData) -> str: def get_experiment_types(self) -> List[str]: """Returns a list of all experiment types""" experiment_type_set: Set[str] = set() - for cur_exp in self.experiments: + for cur_exp in self.exp_dict_short_id.values(): experiment_type_set.add(self.get_experiment_type(cur_exp)) return sorted(list(experiment_type_set)) def get_experiments(self) -> List[ExperimentData]: """Returns a sorted list of all experiments (newest first)""" - return sorted(self.experiments, reverse=True, key=lambda x: x.get_run_id()) + return sorted( + list(self.exp_dict_short_id.values()), + reverse=True, + key=lambda x: x.get_run_id(), + ) def get_exp_by_id(self, exp_id: str) -> ExperimentData: """ @@ -152,10 +154,12 @@ def get_exp_by_id(self, exp_id: str) -> ExperimentData: """ if exp_id.lower() == "latest": ret_exp = sorted( - self.experiments, reverse=True, key=lambda x: x.get_run_id() + list(self.exp_dict_short_id.values()), + reverse=True, + key=lambda x: x.get_run_id(), )[0] else: - ret_exp = self.exp_dict[exp_id] + ret_exp = self.exp_dict_short_id[exp_id] return ret_exp def get_empty_exps( @@ -163,8 +167,8 @@ def get_empty_exps( ) -> List[ExperimentData]: """Finds all experiments which are empty""" if id_list is None: - id_list = [x.get_short_id() for x in self.experiments] - exp_list = [self.exp_dict[x] for x in id_list] + id_list = [x.get_short_id() for x in list(self.exp_dict_short_id.values())] + exp_list = [self.exp_dict_short_id[x] for x in id_list] empty_list = [x for x in exp_list if x.is_empty()] return empty_list @@ -244,7 +248,7 @@ def get_value_information_dict( ) -> Dict[Any, List[str]]: """Returns a dict with information about the values""" value_information_dict = defaultdict(list) - for exp in self.experiments: + for exp in self.exp_dict_short_id.values(): try: exp_info = exp.get_config_information(info_path) if type(exp_info) is list: @@ -258,14 +262,14 @@ def get_value_information_dict( def get_epochs_information_dict(self) -> Dict[int, List[str]]: """Returns a dict with information about the trained epochs""" epochs_information_dict = defaultdict(list) - for exp in self.experiments: + for exp in self.exp_dict_short_id.values(): epochs_information_dict[exp.get_trained_epochs()].append(exp.get_short_id()) return epochs_information_dict def get_datasets_information_dict(self) -> Dict[str, List[str]]: """Returns a dict with information about the datasets""" datasets_information_dict = defaultdict(list) - for exp in self.experiments: + for exp in self.exp_dict_short_id.values(): dataset = exp.get_experiment_path().split("/")[0] datasets_information_dict[dataset].append(exp.get_short_id()) return datasets_information_dict @@ -278,7 +282,7 @@ def get_dataset(self, exp: ExperimentData) -> str: def get_date_information_dict(self) -> Dict[date, List[str]]: """Returns a dict with information about the dates""" date_information_dict = defaultdict(list) - for exp in self.experiments: + for exp in self.exp_dict_short_id.values(): date_string = exp.exp_info.run_id.split("T")[0] date = datetime.strptime(date_string, "%Y-%m-%d").date() date_information_dict[date].append(exp.get_short_id()) @@ -287,7 +291,7 @@ def get_date_information_dict(self) -> Dict[date, List[str]]: def get_experiment_type_information_dict(self) -> Dict[str, List[str]]: """Returns a dict with information about the experiment types""" experiment_type_information_dict = defaultdict(list) - for exp in self.experiments: + for exp in self.exp_dict_short_id.values(): experiment_type = exp.get_experiment_path().split("/")[-1].split("-")[0] experiment_type_information_dict[experiment_type].append(exp.get_short_id()) return experiment_type_information_dict @@ -295,7 +299,7 @@ def get_experiment_type_information_dict(self) -> Dict[str, List[str]]: def get_max_trained_epochs(self) -> int: """Returns the max epochs of all trained experiments""" max_epochs = 0 - for exp in self.experiments: + for exp in self.exp_dict_short_id.values(): max_epochs = max(max_epochs, exp.get_trained_epochs()) return max_epochs