Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release preparation v0.13.2 #115

Merged
merged 7 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions niceml/dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def run_dashboard(conf_instances):
exp_cache = conf_instances.get("exp_cache", None)
st.sidebar.title("Filter Experiments")

exp_manager = exp_manager_factory(id(storage))
exp_list: List[ExperimentInfo] = query_experiments(storage)
exp_manager = exp_manager_factory(handler_name)
exp_list: List[ExperimentInfo] = query_experiments(storage, handler_name)
exps_to_load = select_to_load_exps(exp_list, exp_manager)
experiments = load_experiments(
storage,
Expand Down
8 changes: 5 additions & 3 deletions niceml/dashboard/remotettrainutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@ def exp_manager_factory(*args): # pylint: disable=unused-argument
return ExperimentManager([])


def query_experiments(storage: StorageInterface) -> List[ExperimentInfo]:
def query_experiments(
storage: StorageInterface, storage_identifier: str
) -> List[ExperimentInfo]:
"""Query the experiments from the cloud storage"""

@st.cache_data(ttl=3600)
def _local_query_exps(*args): # pylint: disable=unused-argument
exp_info_list: List[ExperimentInfo] = storage.list_experiments()
return exp_info_list

return _local_query_exps(id(storage))
return _local_query_exps(storage_identifier)


def select_to_load_exps(
Expand All @@ -40,7 +42,7 @@ def select_to_load_exps(
That means which are not in the experiment manager"""
experiments_to_load = []
for exp_info in exp_info_list:
if exp_manager.is_exp_modified(exp_info.short_id, exp_info.last_modified):
if exp_manager.is_exp_modified(exp_info):
experiments_to_load.append(exp_info)
return experiments_to_load

Expand Down
2 changes: 1 addition & 1 deletion niceml/experiments/experimentinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def as_save_dict(self) -> dict:
LAST_MODIFIED_KEY: self.last_modified,
}

def is_modified(self, other) -> bool:
def is_modified(self, other: "ExperimentInfo") -> bool:
"""Checks if the other experiment info is modified"""
return self.last_modified != other.last_modified

Expand Down
58 changes: 31 additions & 27 deletions niceml/experiments/experimentmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,26 @@ class ExperimentManager(object):

def __init__(self, experiments: List[ExperimentData] = None):
"""Manages a list of experiments"""
self.experiments = [] if experiments is None else experiments
self.exp_dict = {exp.get_short_id(): exp for exp in self.experiments}
self.exp_dict.update({exp.get_run_id(): exp for exp in self.experiments})
self.exp_dict_short_id = {exp.get_short_id(): exp for exp in experiments}
self.exp_dict_run_id = {exp.get_run_id(): exp for exp in experiments}

def add_experiment(self, experiment: ExperimentData):
"""Adds an experiment to the manager"""
self.experiments.append(experiment)
self.exp_dict[experiment.get_short_id()] = experiment
self.exp_dict[experiment.get_run_id()] = experiment
self.exp_dict_short_id[experiment.get_short_id()] = experiment
self.exp_dict_run_id[experiment.get_run_id()] = experiment

def __contains__(self, exp_id: Union[str, ExperimentInfo]):
"""Checks if the experiment is in the manager"""
if type(exp_id) == ExperimentInfo:
exp_id = exp_id.short_id
for experiment in self.experiments:
for experiment in self.exp_dict_short_id.values():
if exp_id.endswith(experiment.get_short_id()):
return True
return False

def get_exp_count(self) -> int:
"""Returns the number of experiments"""
return len(self.experiments)
return len(self.exp_dict_short_id.keys())

def get_exp_prefix(self, exp_id) -> str:
"""extracts the prefix from the target exp data"""
Expand Down Expand Up @@ -74,7 +72,7 @@ def get_best_experiments(
"""
if mode not in ["max", "min"]:
raise Exception(f"mode is not max or min but : {mode}")
exp_list = self.experiments
exp_list = list(self.exp_dict_short_id.values())
number_of_exps = min(number_of_exps, len(exp_list))
value_exps = [
(exp, exp.get_best_metric_value(metric_name, mode))
Expand All @@ -93,24 +91,24 @@ def get_best_experiments(
def get_metrics(self, experiments: Optional[List[str]] = None) -> List[str]:
"""Returns a list of all metrics used in the experiments"""
metric_set: Set[str] = set()
for cur_exp in self.experiments:
for cur_exp in self.exp_dict_short_id.values():
if experiments is not None and cur_exp.get_short_id() not in experiments:
continue
metric_set.update(cur_exp.get_metrics())

return sorted(list(metric_set))

def is_exp_modified(self, exp_id: str, new_time_str: str) -> bool:
def is_exp_modified(self, exp_info: ExperimentInfo) -> bool:
"""Checks if the experiment has been modified"""
if exp_id not in self.exp_dict:
if exp_info.short_id not in self.exp_dict_short_id:
return True
exp = self.get_exp_by_id(exp_id)
return exp.exp_info.is_modified(new_time_str)
exp = self.get_exp_by_id(exp_info.short_id)
return exp.exp_info.is_modified(exp_info)

def get_datasets(self) -> List[str]:
"""Returns a list of all datasets used in the experiments"""
dataset_set: Set[str] = set()
for cur_exp in self.experiments:
for cur_exp in self.exp_dict_short_id.values():
dataset_set.add(cur_exp.get_experiment_path().split("/")[0])

return sorted(list(dataset_set))
Expand All @@ -122,14 +120,18 @@ def get_experiment_type(self, experiment: ExperimentData) -> str:
def get_experiment_types(self) -> List[str]:
"""Returns a list of all experiment types"""
experiment_type_set: Set[str] = set()
for cur_exp in self.experiments:
for cur_exp in self.exp_dict_short_id.values():
experiment_type_set.add(self.get_experiment_type(cur_exp))

return sorted(list(experiment_type_set))

def get_experiments(self) -> List[ExperimentData]:
"""Returns a sorted list of all experiments (newest first)"""
return sorted(self.experiments, reverse=True, key=lambda x: x.get_run_id())
return sorted(
list(self.exp_dict_short_id.values()),
reverse=True,
key=lambda x: x.get_run_id(),
)

def get_exp_by_id(self, exp_id: str) -> ExperimentData:
"""
Expand All @@ -152,19 +154,21 @@ def get_exp_by_id(self, exp_id: str) -> ExperimentData:
"""
if exp_id.lower() == "latest":
ret_exp = sorted(
self.experiments, reverse=True, key=lambda x: x.get_run_id()
list(self.exp_dict_short_id.values()),
reverse=True,
key=lambda x: x.get_run_id(),
)[0]
else:
ret_exp = self.exp_dict[exp_id]
ret_exp = self.exp_dict_short_id[exp_id]
return ret_exp

def get_empty_exps(
self, id_list: Optional[List[str]] = None
) -> List[ExperimentData]:
"""Finds all experiments which are empty"""
if id_list is None:
id_list = [x.get_short_id() for x in self.experiments]
exp_list = [self.exp_dict[x] for x in id_list]
id_list = [x.get_short_id() for x in list(self.exp_dict_short_id.values())]
exp_list = [self.exp_dict_short_id[x] for x in id_list]
empty_list = [x for x in exp_list if x.is_empty()]
return empty_list

Expand Down Expand Up @@ -244,7 +248,7 @@ def get_value_information_dict(
) -> Dict[Any, List[str]]:
"""Returns a dict with information about the values"""
value_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
try:
exp_info = exp.get_config_information(info_path)
if type(exp_info) is list:
Expand All @@ -258,14 +262,14 @@ def get_value_information_dict(
def get_epochs_information_dict(self) -> Dict[int, List[str]]:
"""Returns a dict with information about the trained epochs"""
epochs_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
epochs_information_dict[exp.get_trained_epochs()].append(exp.get_short_id())
return epochs_information_dict

def get_datasets_information_dict(self) -> Dict[str, List[str]]:
"""Returns a dict with information about the datasets"""
datasets_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
dataset = exp.get_experiment_path().split("/")[0]
datasets_information_dict[dataset].append(exp.get_short_id())
return datasets_information_dict
Expand All @@ -278,7 +282,7 @@ def get_dataset(self, exp: ExperimentData) -> str:
def get_date_information_dict(self) -> Dict[date, List[str]]:
"""Returns a dict with information about the dates"""
date_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
date_string = exp.exp_info.run_id.split("T")[0]
date = datetime.strptime(date_string, "%Y-%m-%d").date()
date_information_dict[date].append(exp.get_short_id())
Expand All @@ -287,15 +291,15 @@ def get_date_information_dict(self) -> Dict[date, List[str]]:
def get_experiment_type_information_dict(self) -> Dict[str, List[str]]:
"""Returns a dict with information about the experiment types"""
experiment_type_information_dict = defaultdict(list)
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
experiment_type = exp.get_experiment_path().split("/")[-1].split("-")[0]
experiment_type_information_dict[experiment_type].append(exp.get_short_id())
return experiment_type_information_dict

def get_max_trained_epochs(self) -> int:
"""Returns the max epochs of all trained experiments"""
max_epochs = 0
for exp in self.experiments:
for exp in self.exp_dict_short_id.values():
max_epochs = max(max_epochs, exp.get_trained_epochs())
return max_epochs

Expand Down
18 changes: 8 additions & 10 deletions niceml/utilities/chartutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Module for chart utilities."""
from typing import List, Optional

import altair


def generate_hover_charts( # QUEST: still in use?
def generate_hover_charts( # noqa: PLR0913
source,
x_name: str,
text_name: str,
Expand All @@ -17,9 +18,7 @@ def generate_hover_charts( # QUEST: still in use?
if additional_layers is None:
additional_layers = []
# Create a selection that chooses the nearest point & selects based on x-value
nearest = altair.selection(
type="single", nearest=True, on="mouseover", fields=[x_name], empty="none"
)
nearest = altair.selection_point(nearest=True, on="mouseover", fields=[x_name])

# Transparent selectors across the chart. This is what tells us
# the x-value of the cursor
Expand All @@ -30,7 +29,7 @@ def generate_hover_charts( # QUEST: still in use?
x=x_name,
opacity=altair.value(0),
)
.add_selection(nearest)
.add_params(nearest)
)

# Draw points on the line, and highlight based on selection
Expand All @@ -56,7 +55,8 @@ def generate_hover_charts( # QUEST: still in use?
).properties(width=width, height=height)


def generate_chart(source, metric): # TODO: rename function and add docstrings
def generate_chart(source, metric):
"""Generates Altair chart"""
line = (
altair.Chart(source)
.mark_line()
Expand All @@ -68,9 +68,7 @@ def generate_chart(source, metric): # TODO: rename function and add docstrings
)

# Create a selection that chooses the nearest point & selects based on x-value
nearest = altair.selection(
type="single", nearest=True, on="mouseover", fields=["epoch"], empty="none"
)
nearest = altair.selection_point(nearest=True, on="mouseover", fields=["epoch"])

# Transparent selectors across the chart. This is what tells us
# the x-value of the cursor
Expand All @@ -81,7 +79,7 @@ def generate_chart(source, metric): # TODO: rename function and add docstrings
x="epoch",
opacity=altair.value(0),
)
.add_selection(nearest)
.add_params(nearest)
)

# Draw points on the line, and highlight based on selection
Expand Down
Loading
Loading