diff --git a/autoemulate/plotting.py b/autoemulate/plotting.py index 1aa38c07..4842b63e 100644 --- a/autoemulate/plotting.py +++ b/autoemulate/plotting.py @@ -29,7 +29,7 @@ def _validate_inputs(cv_results, model_name): ) -def check_multioutput(y, output_index): +def _check_multioutput(y, output_index): """Checks if y is multi-output and if the output_index is valid.""" if y.ndim > 1: if (output_index > y.shape[1] - 1) | (output_index < 0): @@ -37,7 +37,7 @@ def check_multioutput(y, output_index): f"Output index {output_index} is out of range. The index should be between 0 and {y.shape[1] - 1}." ) print( - f"""Multiple outputs detected. Plotting the output variable with index {output_index}. + f"""Plotting the output variable with index {output_index}. To plot other outputs, set `output_index` argument to the desired index.""" ) @@ -148,6 +148,8 @@ def _plot_single_fold( y_test_std, ax, title=f"{model_name} - {title_suffix}", + input_index=input_index, + output_index=output_index, ) else: display = PredictionErrorDisplay.from_predictions( @@ -334,7 +336,7 @@ def _plot_cv( """ _validate_inputs(cv_results, model_name) - check_multioutput(y, output_index) + _check_multioutput(y, output_index) if model_name: figure = _plot_model_folds( @@ -449,7 +451,9 @@ def _plot_model( y_pred[:, out_idx], y_std[:, out_idx] if y_std is not None else None, ax=axs[plot_index], - title=f"X{in_idx} vs. y{out_idx}", + title=f"$X_{in_idx}$ vs. $y_{out_idx}$", + input_index=in_idx, + output_index=out_idx, ) plot_index += 1 else: @@ -479,7 +483,9 @@ def _plot_model( return fig -def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"): +def _plot_Xy( + X, y, y_pred, y_std=None, ax=None, title="Xy", input_index=0, output_index=0 +): """ Plots observed and predicted values vs. features, including 2σ error bands where available. """ @@ -533,9 +539,9 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"): label="pred.", ) - ax.set_xlabel("X") - ax.set_ylabel("y") - ax.set_title(title) + ax.set_xlabel(f"$X_{input_index}$", fontsize=13) + ax.set_ylabel(f"$y_{output_index}$", fontsize=13) + ax.set_title(title, fontsize=13) ax.grid(True, alpha=0.3) # Get the handles and labels for the scatter plots diff --git a/tests/test_plotting.py b/tests/test_plotting.py index e36613ae..3155c272 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -7,12 +7,12 @@ from autoemulate.compare import AutoEmulate from autoemulate.emulators import RadialBasisFunctions +from autoemulate.plotting import _check_multioutput from autoemulate.plotting import _plot_cv from autoemulate.plotting import _plot_model from autoemulate.plotting import _plot_single_fold from autoemulate.plotting import _predict_with_optional_std from autoemulate.plotting import _validate_inputs -from autoemulate.plotting import check_multioutput @pytest.fixture(scope="module") @@ -72,7 +72,7 @@ def test_check_multioutput_with_single_output(): y = np.array([1, 2, 3, 4, 5]) output_index = 0 try: - check_multioutput(y, output_index) + _check_multioutput(y, output_index) except ValueError as e: assert False, f"Unexpected ValueError: {str(e)}" @@ -81,7 +81,7 @@ def test_check_multioutput_with_multioutput(): y = np.array([[1, 2, 3], [4, 5, 6]]) output_index = 1 try: - check_multioutput(y, output_index) + _check_multioutput(y, output_index) except ValueError as e: assert False, f"Unexpected ValueError: {str(e)}" @@ -90,7 +90,7 @@ def test_check_multioutput_with_invalid_output_index(): y = np.array([[1, 2, 3], [4, 5, 6]]) output_index = 3 try: - check_multioutput(y, output_index) + _check_multioutput(y, output_index) assert False, "Expected ValueError to be raised" except ValueError as e: assert ( @@ -354,7 +354,7 @@ def test__plot_model_int(ae_single_output): output_index=0, ) assert isinstance(fig, plt.Figure) - assert fig.axes[0].get_title() == "X0 vs. y0" + assert all(term in fig.axes[0].get_title() for term in ["X", "y", "vs."]) def test__plot_model_list(ae_single_output): @@ -367,7 +367,7 @@ def test__plot_model_list(ae_single_output): output_index=[0], ) assert isinstance(fig, plt.Figure) - assert fig.axes[1].get_title() == "X1 vs. y0" + assert all(term in fig.axes[1].get_title() for term in ["X", "y", "vs."]) def test__plot_model_int_out_of_range(ae_single_output):