diff --git a/src/wf_psf/plotting/plots_interface.py b/src/wf_psf/plotting/plots_interface.py
index 30e22674..fcd6992c 100644
--- a/src/wf_psf/plotting/plots_interface.py
+++ b/src/wf_psf/plotting/plots_interface.py
@@ -49,6 +49,7 @@ def make_plot(
     x_axis,
     y_axis,
     y_axis_err,
+    y2_axis,
     label,
     plot_title,
     x_axis_label,
@@ -69,6 +70,8 @@ def make_plot(
         y-axis values
     y_axis_err: list
         Error values for y-axis points
+    y2_axis: list
+        y2-axis values for right axis
     label: str
         Label for the points
     plot_title: str
@@ -114,7 +117,7 @@ def make_plot(
                 kwargs = dict(
                     linewidth=2, linestyle="dashed", markersize=4, marker="^", alpha=0.5
                 )
-                ax2.plot(x_axis[it], y_axis[it][k], **kwargs)
+                ax2.plot(x_axis[it], y2_axis[it][k], **kwargs)
 
     plt.savefig(filename)
 
@@ -158,6 +161,7 @@ def __init__(
         metric_name,
         rmse,
         std_rmse,
+        rel_rmse,
         plot_title,
         plots_dir,
     ):
@@ -166,6 +170,7 @@
         self.metric_name = metric_name
         self.rmse = rmse
         self.std_rmse = std_rmse
+        self.rel_rmse = rel_rmse
         self.plot_title = plot_title
         self.plots_dir = plots_dir
         self.list_of_stars = list_of_stars
@@ -189,6 +194,7 @@ def get_metrics(self, dataset):
         """
         rmse = []
         std_rmse = []
+        rel_rmse = []
         metrics_id = []
         for k, v in self.metrics.items():
             for metrics_data in v:
@@ -210,7 +216,15 @@
                     }
                 )
 
-        return metrics_id, rmse, std_rmse
+                rel_rmse.append(
+                    {
+                        (k + "-" + run_id): metrics_data[run_id][0][dataset][
+                            self.metric_name
+                        ][self.rel_rmse]
+                    }
+                )
+
+        return metrics_id, rmse, std_rmse, rel_rmse
 
     def plot(self):
         """Plot.
@@ -220,11 +234,12 @@
 
         """
         for plot_dataset in ["test_metrics", "train_metrics"]:
-            metrics_id, rmse, std_rmse = self.get_metrics(plot_dataset)
+            metrics_id, rmse, std_rmse, rel_rmse = self.get_metrics(plot_dataset)
             make_plot(
                 x_axis=self.list_of_stars,
                 y_axis=rmse,
                 y_axis_err=std_rmse,
+                y2_axis=rel_rmse,
                 label=metrics_id,
                 plot_title="Stars " + plot_dataset + self.plot_title,
                 x_axis_label="Number of stars",
@@ -295,6 +310,7 @@ def plot(self):
         for plot_dataset in ["test_metrics", "train_metrics"]:
             y_axis = []
             y_axis_err = []
+            y2_axis = []
             metrics_id = []
 
             for k, v in self.metrics.items():
@@ -316,11 +332,19 @@
                         ]["mono_metric"]["std_rmse_lda"]
                     }
                 )
+                y2_axis.append(
+                    {
+                        (k + "-" + run_id): metrics_data[run_id][0][
+                            plot_dataset
+                        ]["mono_metric"]["rel_rmse_lda"]
+                    }
+                )
 
             make_plot(
                 x_axis=[lambda_list for _ in range(len(y_axis))],
                 y_axis=y_axis,
                 y_axis_err=y_axis_err,
+                y2_axis=y2_axis,
                 label=metrics_id,
                 plot_title="Stars " + plot_dataset  # type: ignore
@@ -343,10 +367,8 @@
 
 
 class ShapeMetricsPlotHandler:
     """ShapeMetricsPlotHandler class.
-
     A class to handle plot parameters shape metrics results.
-
     Parameters
     ----------
     id: str
@@ -359,7 +381,6 @@ class ShapeMetricsPlotHandler:
         List containing the number of stars used for each training data set
     plots_dir: str
         Output directory for metrics plots
-
     """
 
     id = "shape_metrics"
@@ -373,96 +394,126 @@ def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
 
     def plot(self):
         """Plot.
 
-        A function to generate plots for the train and test
+        A generic function to generate plots for the train and test
         metrics.
""" - # Define common data - # Common data e1_req_euclid = 2e-04 e2_req_euclid = 2e-04 R2_req_euclid = 1e-03 + for plot_dataset in ["test_metrics", "train_metrics"]: - e1_rmse = [] - e1_std_rmse = [] - e2_rmse = [] - e2_std_rmse = [] - rmse_R2_meanR2 = [] - std_rmse_R2_meanR2 = [] - metrics_id = [] + metrics_data = self.prepare_metrics_data( + plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid + ) - for k, v in self.metrics.items(): - for metrics_data in v: - run_id = list(metrics_data.keys())[0] - metrics_id.append(run_id + "-" + k) - - e1_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["rmse_e1"] - / e1_req_euclid - } - ) - e1_std_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["std_rmse_e1"] - } + # Plot for e1 + for k, v in metrics_data.items(): + self.make_shape_metrics_plot( + metrics_data[k]["rmse"], + metrics_data[k]["std_rmse"], + metrics_data[k]["rel_rmse"], + plot_dataset, + k, + ) + + def prepare_metrics_data( + self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid + ): + """Prepare Metrics Data. + + A function to prepare the metrics data for plotting. + + Parameters + ---------- + plot_dataset: str + A string representing the dataset, i.e. training or test metrics. + e1_req_euclid: float + A float denoting the Euclid requirement for the `e1` shape metric. + e2_req_euclid: float + A float denoting the Euclid requirement for the `e2` shape metric. + R2_req_euclid: float + A float denoting the Euclid requirement for the `R2` shape metric. + + Returns + ------- + shape_metrics_data: dict + A dictionary containing the shape metrics data from a set of runs. + + """ + shape_metrics_data = { + "e1": {"rmse": [], "std_rmse": [], "rel_rmse": []}, + "e2": {"rmse": [], "std_rmse": [], "rel_rmse": []}, + "R2_meanR2": {"rmse": [], "std_rmse": [], "rel_rmse": []}, + } + + for k, v in self.metrics.items(): + for metrics_data in v: + run_id = list(metrics_data.keys())[0] + + for metric in ["e1", "e2", "R2_meanR2"]: + metric_rmse = metrics_data[run_id][0][plot_dataset][ + "shape_results_dict" + ][f"rmse_{metric}"] + metric_std_rmse = metrics_data[run_id][0][plot_dataset][ + "shape_results_dict" + ][f"std_rmse_{metric}"] + + relative_metric_rmse = metric_rmse / ( + e1_req_euclid + if metric == "e1" + else (e2_req_euclid if metric == "e2" else R2_req_euclid) ) - e2_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["rmse_e2"] - / e2_req_euclid - } + shape_metrics_data[metric]["rmse"].append( + {f"{k}-{run_id}": metric_rmse} ) - e2_std_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["std_rmse_e2"] - } + shape_metrics_data[metric]["std_rmse"].append( + {f"{k}-{run_id}": metric_std_rmse} ) - - rmse_R2_meanR2.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["rmse_R2_meanR2"] - / R2_req_euclid - } + shape_metrics_data[metric]["rel_rmse"].append( + {f"{k}-{run_id}": relative_metric_rmse} ) - std_rmse_R2_meanR2.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["std_rmse_R2_meanR2"] - } - ) + return shape_metrics_data - make_plot( - x_axis=self.list_of_stars, - y_axis=e1_rmse, - y_axis_err=e1_std_rmse, - label=metrics_id, - plot_title="Stars " + plot_dataset + ".\nShape RMSE", - x_axis_label="Number of stars", - y_left_axis_label="Absolute error", - 
-                y_right_axis_label="Relative error [%]",
-                filename=os.path.join(
-                    self.plots_dir,
-                    plot_dataset
-                    + "_nstars_"
-                    + "_".join(str(nstar) for nstar in self.list_of_stars)
-                    + "_Shape_RMSE.png",
-                ),
-                plot_show=self.plotting_params.plot_show,
-            )
+    def make_shape_metrics_plot(
+        self, rmse_data, std_rmse_data, rel_rmse_data, plot_dataset, metric
+    ):
+        """Make Shape Metrics Plot.
+
+        A function to produce plots for the shape metrics.
+
+        Parameters
+        ----------
+        rmse_data: list
+            A list of dictionaries where each dictionary stores the run as the key and the Root Mean Square Error (rmse) as the value.
+        std_rmse_data: list
+            A list of dictionaries where each dictionary stores the run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value.
+        rel_rmse_data: list
+            A list of dictionaries where each dictionary stores the run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
+        plot_dataset: str
+            A string denoting whether metrics are for the train or test datasets.
+        metric: str
+            A string representing the type of shape metric, i.e., e1, e2, or R2.
+
+        """
+        make_plot(
+            x_axis=self.list_of_stars,
+            y_axis=rmse_data,
+            y_axis_err=std_rmse_data,
+            y2_axis=rel_rmse_data,
+            label=[key for item in rmse_data for key in item],
+            plot_title=f"Stars {plot_dataset}. Shape {metric.upper()} RMSE",
+            x_axis_label="Number of stars",
+            y_left_axis_label="Absolute error",
+            y_right_axis_label="Relative error [%]",
+            filename=os.path.join(
+                self.plots_dir,
+                f"{plot_dataset}_nstars_{'_'.join(str(nstar) for nstar in self.list_of_stars)}_Shape_{metric.upper()}_RMSE.png",
+            ),
+            plot_show=self.plotting_params.plot_show,
+        )
 
 
 def get_number_of_stars(metrics):
@@ -509,16 +560,19 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
         "poly_metric": {
             "rmse": "rmse",
             "std_rmse": "std_rmse",
+            "rel_rmse": "rel_rmse",
             "plot_title": ".\nPolychromatic pixel RMSE @ Euclid resolution",
         },
         "opd_metric": {
             "rmse": "rmse_opd",
             "std_rmse": "rmse_std_opd",
+            "rel_rmse": "rel_rmse_opd",
             "plot_title": ".\nOPD RMSE",
         },
         "shape_results_dict": {
             "rmse": "pix_rmse",
             "std_rmse": "pix_rmse_std",
+            "rel_rmse": "rel_pix_rmse",
             "plot_title": "\nPixel RMSE @ 3x Euclid resolution",
         },
     }
@@ -533,6 +587,7 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
             k,
             v["rmse"],
             v["std_rmse"],
+            v["rel_rmse"],
             v["plot_title"],
             plot_saving_path,
         )
diff --git a/src/wf_psf/tests/metrics_test.py b/src/wf_psf/tests/metrics_test.py
index b81c9a96..821fe2d0 100644
--- a/src/wf_psf/tests/metrics_test.py
+++ b/src/wf_psf/tests/metrics_test.py
@@ -108,7 +108,7 @@ def main_metrics(training_params):
     return np.load(os.path.join(main_dir, metrics_filename), allow_pickle=True)[()]
 
 
-@pytest.mark.skip(reason="Requires gpu")
+@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
 def test_eval_metrics_polychromatic_lowres(
     training_params,
     weights_path_basename,
@@ -156,7 +156,7 @@
     assert ratio_rel_std_rmse < tol
 
 
-@pytest.mark.skip(reason="Requires gpu")
+@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
 def test_evaluate_metrics_opd(
     training_params,
     weights_path_basename,
@@ -206,7 +206,7 @@
     assert ratio_rel_rmse_std_opd < tol
 
 
-@pytest.mark.skip(reason="Requires gpu")
+@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
 def test_eval_metrics_mono_rmse(
     training_params,
     weights_path_basename,
@@ -271,7 +271,7 @@
     assert ratio_rel_rmse_std_mono < tol
 
 
-@pytest.mark.skip(reason="Requires gpu")
+@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
 def test_evaluate_metrics_shape(
     training_params,
     weights_path_basename,
diff --git a/src/wf_psf/tests/train_test.py b/src/wf_psf/tests/train_test.py
index 2a317031..0435d725 100644
--- a/src/wf_psf/tests/train_test.py
+++ b/src/wf_psf/tests/train_test.py
@@ -66,7 +66,7 @@ def psf_model_dir():
     )
 
 
-@pytest.mark.skip(reason="Requires gpu")
+@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
 def test_train(
     training_params,
     training_data,
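
For context, a minimal sketch of how the extended `make_plot` signature is exercised: `y2_axis` mirrors `y_axis` (one single-entry dict per run, keyed `"<config>-<run_id>"`) and feeds the dashed right-hand axis. The run IDs and values below are hypothetical placeholders, not results from a real run.

```python
from wf_psf.plotting.plots_interface import make_plot

# Hypothetical data for two runs; the data shapes mirror what
# MetricsPlotHandler.plot builds from the saved metrics files.
make_plot(
    x_axis=[200, 500],
    y_axis=[{"poly-run_1": 1.2e-3}, {"poly-run_2": 8.5e-4}],      # absolute RMSE (left axis)
    y_axis_err=[{"poly-run_1": 2.0e-4}, {"poly-run_2": 1.5e-4}],  # error bars on the left axis
    y2_axis=[{"poly-run_1": 3.1}, {"poly-run_2": 2.2}],           # relative error [%] (right axis)
    label=["poly-run_1", "poly-run_2"],
    plot_title="Stars test_metrics.\nPolychromatic pixel RMSE @ Euclid resolution",
    x_axis_label="Number of stars",
    y_left_axis_label="Absolute error",
    y_right_axis_label="Relative error [%]",
    filename="test_metrics_nstars_200_500_Poly_RMSE.png",
    plot_show=False,
)
```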