From e89e2ec0e88bdc05afa1eb8b1197c228de10ee21 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Mon, 29 Jul 2024 13:10:49 -0400
Subject: [PATCH 1/6] support global_run for local compute context, solve tcp
 buffer error

---
 dacapo/blockwise/__init__.py                  |   1 +
 dacapo/blockwise/global_vars.py               |   2 +
 dacapo/blockwise/predict_worker.py            | 181 ++++++++++--------
 .../datasplits/datasets/arrays/zarr_array.py  |  16 +-
 dacapo/predict.py                             |  12 +-
 dacapo/validate.py                            |   8 +-
 6 files changed, 126 insertions(+), 94 deletions(-)
 create mode 100644 dacapo/blockwise/global_vars.py

diff --git a/dacapo/blockwise/__init__.py b/dacapo/blockwise/__init__.py
index 6027a9115..aa198e0d0 100644
--- a/dacapo/blockwise/__init__.py
+++ b/dacapo/blockwise/__init__.py
@@ -1,2 +1,3 @@
 from .blockwise_task import DaCapoBlockwiseTask
 from .scheduler import run_blockwise, segment_blockwise
+from . import global_vars
diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py
new file mode 100644
index 000000000..4d3771721
--- /dev/null
+++ b/dacapo/blockwise/global_vars.py
@@ -0,0 +1,2 @@
+current_run = None
+
diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index 9b5cdbf33..787787034 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -17,6 +17,7 @@
 import numpy as np
 import click
 
+from dacapo.blockwise import global_vars
 import logging
 
 
@@ -27,6 +28,14 @@
 path = __file__
 
 
+def is_global_run_set(run_name) -> bool:
+    found = global_vars.current_run is not None
+    if found:
+        found = global_vars.current_run.name == run_name
+        if not found:
+            logger.error(f"Found global run {global_vars.current_run.name} but looking for {run_name}")
+    return found
+
 @click.group()
 @click.option(
     "--log-level",
@@ -70,7 +79,7 @@ def cli(log_level):
 )
 @click.option("-od", "--output_dataset", required=True, type=str)
 def start_worker(
-    run_name: str | Run,
+    run_name: str,
     iteration: int | None,
     input_container: Path | str,
     input_dataset: str,
@@ -90,7 +99,7 @@ def start_worker(
 
 
 def start_worker_fn(
-    run_name: str | Run,
+    run_name: str,
     iteration: int | None,
     input_container: Path | str,
     input_dataset: str,
@@ -109,93 +118,95 @@ def start_worker_fn(
         output_container (Path | str): The output container.
         output_dataset (str): The output dataset.
""" - compute_context = create_compute_context() - device = compute_context.device + def io_loop(): + daisy_client = daisy.Client() - # retrieving run - if isinstance(run_name, Run): - run = run_name - run_name = run.name - else: - config_store = create_config_store() - run_config = config_store.retrieve_run_config(run_name) - run = Run(run_config) + compute_context = create_compute_context() + device = compute_context.device + + if is_global_run_set(run_name): + logger.warning("Using global run variable") + run = global_vars.current_run + else: + logger.warning("initiating local run in predict_worker") + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + if iteration is not None and compute_context.distribute_workers: + # create weights store + weights_store = create_weights_store() + + # load weights + run.model.load_state_dict( + weights_store.retrieve_weights(run_name, iteration).model + ) - if iteration is not None and compute_context.distribute_workers: - # create weights store - weights_store = create_weights_store() + # get arrays + input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(input_array_identifier) - # load weights - run.model.load_state_dict( - weights_store.retrieve_weights(run_name, iteration).model + output_array_identifier = LocalArrayIdentifier( + Path(output_container), output_dataset + ) + output_array = ZarrArray.open_from_array_identifier(output_array_identifier) + + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + + # get the model's input and output size + model = run.model.eval() + # .to(device) + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + print(f"Predicting with input size {input_size}, output size {output_size}") + + # create gunpowder keys + + raw = gp.ArrayKey("RAW") + prediction = gp.ArrayKey("PREDICTION") + + # assemble prediction pipeline + + # prepare data source + pipeline = DaCapoArraySource(raw_array, raw) + # raw: (c, d, h, w) + pipeline += gp.Pad(raw, None) + # raw: (c, d, h, w) + pipeline += gp.Unsqueeze([raw]) + # raw: (1, c, d, h, w) + + pipeline += gp.Normalize(raw) + + # predict + # model.eval() + pipeline += gp_torch.Predict( + model=model, + inputs={"x": raw}, + outputs={0: prediction}, + array_specs={ + prediction: gp.ArraySpec( + voxel_size=output_voxel_size, + dtype=np.float32, # assumes network output is float32 + ) + }, + spawn_subprocess=False, + device=str(device), ) - # get arrays - input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) - raw_array = ZarrArray.open_from_array_identifier(input_array_identifier) - - output_array_identifier = LocalArrayIdentifier( - Path(output_container), output_dataset - ) - output_array = ZarrArray.open_from_array_identifier(output_array_identifier) - - # set benchmark flag to True for performance - torch.backends.cudnn.benchmark = True - - # get the model's input and output size - model = run.model.eval().to(device) - input_voxel_size = Coordinate(raw_array.voxel_size) - output_voxel_size = model.scale(input_voxel_size) - input_shape = Coordinate(model.eval_input_shape) - input_size = input_voxel_size * input_shape - output_size = 
output_voxel_size * model.compute_output_shape(input_shape)[1] - - print(f"Predicting with input size {input_size}, output size {output_size}") - - # create gunpowder keys - - raw = gp.ArrayKey("RAW") - prediction = gp.ArrayKey("PREDICTION") - - # assemble prediction pipeline - - # prepare data source - pipeline = DaCapoArraySource(raw_array, raw) - # raw: (c, d, h, w) - pipeline += gp.Pad(raw, None) - # raw: (c, d, h, w) - pipeline += gp.Unsqueeze([raw]) - # raw: (1, c, d, h, w) - - pipeline += gp.Normalize(raw) - - # predict - # model.eval() - pipeline += gp_torch.Predict( - model=model, - inputs={"x": raw}, - outputs={0: prediction}, - array_specs={ - prediction: gp.ArraySpec( - voxel_size=output_voxel_size, - dtype=np.float32, # assumes network output is float32 - ) - }, - spawn_subprocess=False, - device=str(device), - ) - - # make reference batch request - request = gp.BatchRequest() - request.add(raw, input_size, voxel_size=input_voxel_size) - request.add( - prediction, - output_size, - voxel_size=output_voxel_size, - ) + # make reference batch request + request = gp.BatchRequest() + request.add(raw, input_size, voxel_size=input_voxel_size) + request.add( + prediction, + output_size, + voxel_size=output_voxel_size, + ) - def io_loop(): - daisy_client = daisy.Client() while True: with daisy_client.acquire_block() as block: @@ -231,7 +242,7 @@ def io_loop(): def spawn_worker( - run_name: str | Run, + run_name: str, iteration: int | None, input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", @@ -248,6 +259,8 @@ def spawn_worker( Callable: The function to run the worker. """ compute_context = create_compute_context() + + if not compute_context.distribute_workers: return start_worker_fn( run_name=run_name, diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index f61bf0cd4..30c6ac693 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -369,11 +369,17 @@ def data(self) -> Any: """ file_name = str(self.file_name) # Zarr library does not detect the store for N5 datasets - if file_name.endswith(".n5"): - zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode) - else: - zarr_container = zarr.open(str(file_name), mode=self.mode) - return zarr_container[self.dataset] + try: + if file_name.endswith(".n5"): + zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode) + else: + zarr_container = zarr.open(str(file_name), mode=self.mode) + return zarr_container[self.dataset] + except Exception as e: + logger.error( + f"Could not open dataset {self.dataset} in file {file_name} in mode {self.mode}" + ) + raise e def __getitem__(self, roi: Roi) -> np.ndarray: """ diff --git a/dacapo/predict.py b/dacapo/predict.py index 09ac848cc..f79f26a1e 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,5 +1,5 @@ from upath import UPath as Path - +from dacapo.blockwise import global_vars from dacapo.blockwise import run_blockwise import dacapo.blockwise from dacapo.experiments import Run @@ -24,7 +24,7 @@ def predict( input_dataset: str, output_path: LocalArrayIdentifier | Path | str, output_roi: Optional[Roi | str] = None, - num_workers: int = 12, + num_workers: int = 1, output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): @@ -136,10 +136,13 @@ def predict( write_size=output_size, ) + global_vars.current_run = run + + # run blockwise 
prediction worker_file = str(Path(Path(dacapo.blockwise.__file__).parent, "predict_worker.py")) print("Running blockwise prediction with worker_file: ", worker_file) - run_blockwise( + success = run_blockwise( worker_file=worker_file, total_roi=_input_roi, read_roi=Roi((0, 0, 0), input_size), @@ -148,9 +151,10 @@ def predict( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - run_name=run, + run_name=run.name, iteration=iteration, input_array_identifier=input_array_identifier, output_array_identifier=output_array_identifier, ) print("Done predicting.") + return success diff --git a/dacapo/validate.py b/dacapo/validate.py index 398308df2..0da9dfa30 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -186,7 +186,7 @@ def validate( run.name, iteration, validation_dataset.name ) compute_context = create_compute_context() - predict( + sucess = predict( run, iteration if compute_context.distribute_workers else None, input_container=input_raw_array_identifier.container, @@ -198,6 +198,12 @@ def validate( overwrite=overwrite, ) + if not sucess: + logger.error( + f"Could not predict run {run.name} on dataset {validation_dataset.name}." + ) + continue + print(f"Predicted on dataset {validation_dataset.name}") post_processor.set_prediction(prediction_array_identifier) From 9ea893c2bf896c5794bff48af466b01b4f00f666 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 29 Jul 2024 13:17:58 -0400 Subject: [PATCH 2/6] black format --- dacapo/blockwise/global_vars.py | 1 - dacapo/blockwise/predict_worker.py | 12 ++++++++---- dacapo/experiments/datasplits/datasplit_generator.py | 2 +- dacapo/predict.py | 1 - 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dacapo/blockwise/global_vars.py b/dacapo/blockwise/global_vars.py index 4d3771721..0c804e3ff 100644 --- a/dacapo/blockwise/global_vars.py +++ b/dacapo/blockwise/global_vars.py @@ -1,2 +1 @@ current_run = None - diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 787787034..867c9554b 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -33,9 +33,12 @@ def is_global_run_set(run_name) -> bool: if found: found = global_vars.current_run.name == run_name if not found: - logger.error(f"Found global run {global_vars.current_run.name} but looking for {run_name}") + logger.error( + f"Found global run {global_vars.current_run.name} but looking for {run_name}" + ) return found + @click.group() @click.option( "--log-level", @@ -118,6 +121,7 @@ def start_worker_fn( output_container (Path | str): The output container. output_dataset (str): The output dataset. 
""" + def io_loop(): daisy_client = daisy.Client() @@ -143,7 +147,9 @@ def io_loop(): ) # get arrays - input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + input_array_identifier = LocalArrayIdentifier( + Path(input_container), input_dataset + ) raw_array = ZarrArray.open_from_array_identifier(input_array_identifier) output_array_identifier = LocalArrayIdentifier( @@ -207,7 +213,6 @@ def io_loop(): voxel_size=output_voxel_size, ) - while True: with daisy_client.acquire_block() as block: if block is None: @@ -260,7 +265,6 @@ def spawn_worker( """ compute_context = create_compute_context() - if not compute_context.distribute_workers: return start_worker_fn( run_name=run_name, diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index ce229deee..d3a6cb7d6 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -757,7 +757,7 @@ def __generate_semantic_seg_datasplit(self): mask_config=mask_config, ) ) - + return TrainValidateDataSplitConfig( name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm", train_configs=train_dataset_configs, diff --git a/dacapo/predict.py b/dacapo/predict.py index f79f26a1e..f28e97663 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -138,7 +138,6 @@ def predict( global_vars.current_run = run - # run blockwise prediction worker_file = str(Path(Path(dacapo.blockwise.__file__).parent, "predict_worker.py")) print("Running blockwise prediction with worker_file: ", worker_file) From 0cf9e2557a6831c7b2dd23580bf998b9e35cd787 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 30 Jul 2024 11:37:02 -0400 Subject: [PATCH 3/6] fix plot hack --- .../threshold_post_processor.py | 7 ++- dacapo/plot.py | 49 +++++++++---------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index c0e10418c..f99c64d3a 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -68,7 +68,7 @@ def process( self, parameters: "ThresholdPostProcessorParameters", # type: ignore[override] output_array_identifier: "LocalArrayIdentifier", - num_workers: int = 16, + num_workers: int = 12, block_size: Coordinate = Coordinate((256, 256, 256)), ) -> ZarrArray: """ @@ -122,7 +122,7 @@ def process( read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :]) # run blockwise post-processing - run_blockwise( + sucess = run_blockwise( worker_file=str( Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") ), @@ -138,4 +138,7 @@ def process( threshold=parameters.threshold, ) + if not sucess: + raise RuntimeError("Blockwise post-processing failed.") + return output_array diff --git a/dacapo/plot.py b/dacapo/plot.py index e86f697b3..3b12a52d8 100644 --- a/dacapo/plot.py +++ b/dacapo/plot.py @@ -3,9 +3,8 @@ from dacapo.store.create_store import create_config_store, create_stats_store from dacapo.experiments.run import Run -from bokeh.palettes import Category20 as palette -import bokeh.layouts -import bokeh.plotting +from dacapo.plotting.plot_handler import PlotHandler, RunInfo +from bokeh.plotting.matplot_plot_handler import MatplotPlotHandler import numpy as np from collections import namedtuple @@ -104,7 +103,7 @@ def get_runs_info( 
             run_config.trainer_config.name,
             run_config.datasplit_config.name,
             (
-                stats_store.retrieve_training_stats(run_config_name, subsample=True)
+                stats_store.retrieve_training_stats(run_config_name)
                 if plot_loss
                 else None
             ),
@@ -159,7 +158,7 @@ def plot_runs(
         tools="pan, wheel_zoom, reset, save, hover",
         x_axis_label="iterations",
         tooltips=loss_tooltips,
-        plot_width=2048,
+        # plot_width=2048,
     )
     loss_figure.background_fill_color = "#efefef"
 
@@ -202,7 +201,7 @@ def plot_runs(
             tools="pan, wheel_zoom, reset, save, hover",
             x_axis_label="iterations",
             tooltips=validation_tooltips,
-            plot_width=2048,
+            # plot_width=2048,
         )
         validation_figure.background_fill_color = "#efefef"
         validation_figures[dataset.name] = validation_figure
@@ -226,7 +225,7 @@ def plot_runs(
         x_axis_label="model size",
         y_axis_label="best validation",
         tooltips=summary_tooltips,
-        plot_width=2048,
+        # plot_width=2048,
     )
     summary_figure.background_fill_color = "#efefef"
 
@@ -297,24 +296,24 @@ def plot_runs(
                     "run": [run.name] * len(x),
                 }
                 # TODO: get_best: higher_is_better is not true for all scores
-                best_parameters, best_scores = run.validation_scores.get_best(
-                    dataset_data, dim="parameters"
-                )
-
-                source_dict.update(
-                    {
-                        name: np.array(
-                            [
-                                getattr(best_parameter, name)
-                                for best_parameter in best_parameters.values
-                            ]
-                        )
-                        for name in run.validation_scores.parameter_names
-                    }
-                )
-                source_dict.update(
-                    {run.validation_score_name: np.array(best_scores.values)}
-                )
+                # best_parameters, best_scores = run.validation_scores.get_best(
+                #     dataset_data, dim="parameters"
+                # )
+
+                # source_dict.update(
+                #     {
+                #         name: np.array(
+                #             [
+                #                 getattr(best_parameter, name)
+                #                 for best_parameter in best_parameters.values
+                #             ]
+                #         )
+                #         for name in run.validation_scores.parameter_names
+                #     }
+                # )
+                # source_dict.update(
+                #     {run.validation_score_name: np.array(best_scores.values)}
+                # )
 
                 source = bokeh.plotting.ColumnDataSource(source_dict)
                 validation_figures[dataset.name].line(

From 276d48138f60878bc6eac4c0be61ead204ce9d6d Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Tue, 30 Jul 2024 11:39:37 -0400
Subject: [PATCH 4/6] fix import

---
 dacapo/plot.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/dacapo/plot.py b/dacapo/plot.py
index 3b12a52d8..9829dfd60 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -3,8 +3,9 @@
 from dacapo.store.create_store import create_config_store, create_stats_store
 from dacapo.experiments.run import Run
 
-from dacapo.plotting.plot_handler import PlotHandler, RunInfo
-from bokeh.plotting.matplot_plot_handler import MatplotPlotHandler
+from bokeh.palettes import Category20 as palette
+import bokeh.layouts
+import bokeh.plotting
 import numpy as np
 
 from collections import namedtuple

From 24c963aec25b80eb9704a7a07e9dd4f8fa145dbe Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Wed, 31 Jul 2024 16:41:57 -0400
Subject: [PATCH 5/6] matplotlib plot

---
 dacapo/plot.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 89 insertions(+), 1 deletion(-)

diff --git a/dacapo/plot.py b/dacapo/plot.py
index 9829dfd60..8d9d3c2dd 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -7,11 +7,18 @@
 import bokeh.layouts
 import bokeh.plotting
 import numpy as np
+from tqdm import tqdm
 
 from collections import namedtuple
 import itertools
 from typing import List
 
+import matplotlib.pyplot as plt
+
+import os
+
+
+
 RunInfo = namedtuple(
     "RunInfo",
     [
@@ -117,7 +124,7 @@ def get_runs_info(
     return runs
 
 
-def plot_runs(
+def bokeh_plot_runs(
     run_config_base_names,
     smooth=100,
     validation_scores=None,
@@ -384,3 +391,84 @@ def plot_runs(
     else:
         bokeh.plotting.output_file("performance_plots.html")
         bokeh.plotting.save(plot)
+
+
+def plot_runs(
+    run_config_base_names,
+    smooth=100,
+    validation_scores=None,
+    higher_is_betters=None,
+    plot_losses=None,
+):
+    """
+    Plot runs.
+    Args:
+        run_config_base_names: Names of run configs to plot
+        smooth: Smoothing factor
+        validation_scores: Validation scores to plot
+        higher_is_betters: Whether higher is better
+        plot_losses: Whether to plot losses
+    Returns:
+        None
+    """
+    print("PLOTTING RUNS")
+    runs = get_runs_info(run_config_base_names, validation_scores, plot_losses)
+    print("GOT RUNS INFO")
+
+    colors = itertools.cycle(plt.cm.tab20.colors)
+    include_validation_figure = False
+    include_loss_figure = False
+
+    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))
+    loss_ax = axes[0]
+    validation_ax = axes[1]
+
+    for run, color in zip(runs, colors):
+        name = run.name
+
+        if run.plot_loss:
+            iterations = [stat.iteration for stat in run.training_stats.iteration_stats]
+            losses = [stat.loss for stat in run.training_stats.iteration_stats]
+
+            print(f"Run {run.name} has {len(losses)} iterations")
+
+            if run.plot_loss:
+                include_loss_figure = True
+                smooth = int(np.maximum(len(iterations) / 2500, 1))
+                print(f"smoothing: {smooth}")
+                x, _ = smooth_values(iterations, smooth, stride=smooth)
+                y, s = smooth_values(losses, smooth, stride=smooth)
+                print(x, y)
+                print(f"plotting {(len(x), len(y))} points")
+                loss_ax.plot(x, y, label=name, color=color)
+                print("LOSS PLOTTED")
+
+        if run.validation_score_name and run.validation_scores.validated_until() > 0:
+            validation_score_data = run.validation_scores.to_xarray().sel(
+                criteria=run.validation_score_name
+            )
+            colors_val = itertools.cycle(plt.cm.tab20.colors)
+            for dataset,color_v in zip(run.validation_scores.datasets,colors_val):
+                dataset_data = validation_score_data.sel(datasets=dataset)
+                include_validation_figure = True
+                x = [score.iteration for score in run.validation_scores.scores]
+                cc = next(colors_val)
+                for i in range(dataset_data.data.shape[1]):
+                    current_name = f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
+                    validation_ax.plot(x, dataset_data.data[:,i] , label=current_name, color=cc, alpha=0.5+0.2*i)
+            print("VALIDATION PLOTTED")
+
+    if include_loss_figure:
+        loss_ax.set_title("Training")
+        loss_ax.set_xlabel("Iterations")
+        loss_ax.set_ylabel("Loss")
+        loss_ax.legend()
+
+    if include_validation_figure:
+        validation_ax.set_title("Validation")
+        validation_ax.set_xlabel("Iterations")
+        validation_ax.set_ylabel("Validation Score")
+        validation_ax.legend()
+
+    plt.tight_layout()
+    plt.show()
\ No newline at end of file

From 7f534cfe22debd002185b97acbff49260ffd0931 Mon Sep 17 00:00:00 2001
From: mzouink
Date: Wed, 31 Jul 2024 20:42:31 +0000
Subject: [PATCH 6/6] :art: Format Python code with psf/black

---
 dacapo/plot.py | 19 +++++++++++-------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/dacapo/plot.py b/dacapo/plot.py
index 8d9d3c2dd..d5bfe1d28 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -17,8 +17,7 @@
 
 import os
 
-
-
+
 RunInfo = namedtuple(
     "RunInfo",
     [
@@ -448,14 +447,22 @@ def plot_runs(
                criteria=run.validation_score_name
            )
            colors_val = itertools.cycle(plt.cm.tab20.colors)
-            for dataset,color_v in zip(run.validation_scores.datasets,colors_val):
+            for dataset, color_v in zip(run.validation_scores.datasets, colors_val):
                dataset_data = validation_score_data.sel(datasets=dataset)
                include_validation_figure = True
                x = [score.iteration for score in run.validation_scores.scores]
                cc = next(colors_val)
                for i in range(dataset_data.data.shape[1]):
-                    current_name = f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
-                    validation_ax.plot(x, dataset_data.data[:,i] , label=current_name, color=cc, alpha=0.5+0.2*i)
+                    current_name = (
+                        f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
+                    )
+                    validation_ax.plot(
+                        x,
+                        dataset_data.data[:, i],
+                        label=current_name,
+                        color=cc,
+                        alpha=0.5 + 0.2 * i,
+                    )
            print("VALIDATION PLOTTED")
 
    if include_loss_figure:
@@ -471,4 +478,4 @@ def plot_runs(
        validation_ax.legend()
 
    plt.tight_layout()
-    plt.show()
\ No newline at end of file
+    plt.show()