diff --git a/src/wf_psf/plotting/plots_interface.py b/src/wf_psf/plotting/plots_interface.py index 262628ad..30e22674 100644 --- a/src/wf_psf/plotting/plots_interface.py +++ b/src/wf_psf/plotting/plots_interface.py @@ -232,7 +232,12 @@ def plot(self): y_right_axis_label="Relative error [%]", filename=os.path.join( self.plots_dir, - plot_dataset + "_" + self.metric_name + "_RMSE.png", + plot_dataset + + "_" + + self.metric_name + + "_nstars_" + + "_".join(str(nstar) for nstar in self.list_of_stars) + + "_RMSE.png", ), plot_show=self.plotting_params.plot_show, ) @@ -254,6 +259,8 @@ class MonochromaticMetricsPlotHandler: Dictionary containing the metric configurations as RecursiveNamespace objects for each run metrics: list Dictionary containing list of metrics + list_of_stars: list + List containing the number of stars used for each training data set plots_dir: str Output directory for metrics plots @@ -261,10 +268,18 @@ class MonochromaticMetricsPlotHandler: ids = ("mono_metrics",) - def __init__(self, plotting_params, metrics_confs, metrics, plots_dir): + def __init__( + self, + plotting_params, + metrics_confs, + metrics, + list_of_stars, + plots_dir, + ): self.plotting_params = plotting_params self.metrics_confs = metrics_confs self.metrics = metrics + self.list_of_stars = list_of_stars self.plots_dir = plots_dir def plot(self): @@ -289,16 +304,16 @@ def plot(self): metrics_id.append(run_id + "-" + k) y_axis.append( { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset]["mono_metric"][ - "rmse_lda" - ] + (k + "-" + run_id): metrics_data[run_id][0][ + plot_dataset + ]["mono_metric"]["rmse_lda"] } ) y_axis_err.append( { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset]["mono_metric"][ - "std_rmse_lda" - ] + (k + "-" + run_id): metrics_data[run_id][0][ + plot_dataset + ]["mono_metric"]["std_rmse_lda"] } ) @@ -315,7 +330,12 @@ def plot(self): y_right_axis_label="Relative error [%]", filename=os.path.join( self.plots_dir, - (plot_dataset + "_monochrom_pixel_RMSE.png"), + ( + plot_dataset + + "_nstars_" + + "_".join(str(nstar) for nstar in self.list_of_stars) + + "_monochrom_pixel_RMSE.png" + ), ), plot_show=self.plotting_params.plot_show, ) @@ -335,6 +355,8 @@ class ShapeMetricsPlotHandler: Recursive Namespace Object containing plotting parameters metrics: list Dictionary containing list of metrics + list_of_stars: list + List containing the number of stars used for each training data set plots_dir: str Output directory for metrics plots @@ -417,7 +439,6 @@ def plot(self): std_rmse_R2_meanR2.append( { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ "shape_results_dict" ]["std_rmse_R2_meanR2"] @@ -435,7 +456,10 @@ def plot(self): y_right_axis_label="Relative error [%]", filename=os.path.join( self.plots_dir, - plot_dataset + "_Shape_RMSE.png", + plot_dataset + + "_nstars_" + + "_".join(str(nstar) for nstar in self.list_of_stars) + + "_Shape_RMSE.png", ), plot_show=self.plotting_params.plot_show, ) @@ -447,6 +471,10 @@ def get_number_of_stars(metrics): A function to get the number of stars used in training the model. + Parameters + ---------- + metrics: dict + A dictionary containig the metrics results per run Returns ------- list_of_stars: list @@ -511,7 +539,7 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa metrics_plot.plot() monochrom_metrics_plot = MonochromaticMetricsPlotHandler( - plotting_params, metrics_confs, list_of_metrics, plot_saving_path + plotting_params, metrics_confs, list_of_metrics, list_of_stars, plot_saving_path ) monochrom_metrics_plot.plot()