Skip to content

Commit

Permalink
Update to plotting functions to store data info in filename
Browse files Browse the repository at this point in the history
  • Loading branch information
jeipollack committed Nov 9, 2023
1 parent 90000c9 commit d2d1a32
Showing 1 changed file with 40 additions and 12 deletions.
52 changes: 40 additions & 12 deletions src/wf_psf/plotting/plots_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,12 @@ def plot(self):
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
plot_dataset + "_" + self.metric_name + "_RMSE.png",
plot_dataset
+ "_"
+ self.metric_name
+ "_nstars_"
+ "_".join(str(nstar) for nstar in self.list_of_stars)
+ "_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)
Expand All @@ -254,17 +259,27 @@ class MonochromaticMetricsPlotHandler:
Dictionary containing the metric configurations as RecursiveNamespace objects for each run
metrics: list
Dictionary containing list of metrics
list_of_stars: list
List containing the number of stars used for each training data set
plots_dir: str
Output directory for metrics plots
"""

ids = ("mono_metrics",)

def __init__(self, plotting_params, metrics_confs, metrics, plots_dir):
def __init__(
self,
plotting_params,
metrics_confs,
metrics,
list_of_stars,
plots_dir,
):
self.plotting_params = plotting_params
self.metrics_confs = metrics_confs
self.metrics = metrics
self.list_of_stars = list_of_stars
self.plots_dir = plots_dir

def plot(self):
Expand All @@ -289,16 +304,16 @@ def plot(self):
metrics_id.append(run_id + "-" + k)
y_axis.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset]["mono_metric"][
"rmse_lda"
]
(k + "-" + run_id): metrics_data[run_id][0][
plot_dataset
]["mono_metric"]["rmse_lda"]
}
)
y_axis_err.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset]["mono_metric"][
"std_rmse_lda"
]
(k + "-" + run_id): metrics_data[run_id][0][
plot_dataset
]["mono_metric"]["std_rmse_lda"]
}
)

Expand All @@ -315,7 +330,12 @@ def plot(self):
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
(plot_dataset + "_monochrom_pixel_RMSE.png"),
(
plot_dataset
+ "_nstars_"
+ "_".join(str(nstar) for nstar in self.list_of_stars)
+ "_monochrom_pixel_RMSE.png"
),
),
plot_show=self.plotting_params.plot_show,
)
Expand All @@ -335,6 +355,8 @@ class ShapeMetricsPlotHandler:
Recursive Namespace Object containing plotting parameters
metrics: list
Dictionary containing list of metrics
list_of_stars: list
List containing the number of stars used for each training data set
plots_dir: str
Output directory for metrics plots
Expand Down Expand Up @@ -417,7 +439,6 @@ def plot(self):

std_rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_R2_meanR2"]
Expand All @@ -435,7 +456,10 @@ def plot(self):
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
plot_dataset + "_Shape_RMSE.png",
plot_dataset
+ "_nstars_"
+ "_".join(str(nstar) for nstar in self.list_of_stars)
+ "_Shape_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)
Expand All @@ -447,6 +471,10 @@ def get_number_of_stars(metrics):
A function to get the number of stars used
in training the model.
Parameters
----------
metrics: dict
A dictionary containig the metrics results per run
Returns
-------
list_of_stars: list
Expand Down Expand Up @@ -511,7 +539,7 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
metrics_plot.plot()

monochrom_metrics_plot = MonochromaticMetricsPlotHandler(
plotting_params, metrics_confs, list_of_metrics, plot_saving_path
plotting_params, metrics_confs, list_of_metrics, list_of_stars, plot_saving_path
)
monochrom_metrics_plot.plot()

Expand Down

0 comments on commit d2d1a32

Please sign in to comment.