Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update labels and label sizes in plots #273

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions autoemulate/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def _validate_inputs(cv_results, model_name):
)


def check_multioutput(y, output_index):
def _check_multioutput(y, output_index):
"""Checks if y is multi-output and if the output_index is valid."""
if y.ndim > 1:
if (output_index > y.shape[1] - 1) | (output_index < 0):
raise ValueError(
f"Output index {output_index} is out of range. The index should be between 0 and {y.shape[1] - 1}."
)
print(
f"""Multiple outputs detected. Plotting the output variable with index {output_index}.
f"""Plotting the output variable with index {output_index}.
To plot other outputs, set `output_index` argument to the desired index."""
)

Expand Down Expand Up @@ -148,6 +148,8 @@ def _plot_single_fold(
y_test_std,
ax,
title=f"{model_name} - {title_suffix}",
input_index=input_index,
output_index=output_index,
)
else:
display = PredictionErrorDisplay.from_predictions(
Expand Down Expand Up @@ -334,7 +336,7 @@ def _plot_cv(
"""

_validate_inputs(cv_results, model_name)
check_multioutput(y, output_index)
_check_multioutput(y, output_index)

if model_name:
figure = _plot_model_folds(
Expand Down Expand Up @@ -449,7 +451,9 @@ def _plot_model(
y_pred[:, out_idx],
y_std[:, out_idx] if y_std is not None else None,
ax=axs[plot_index],
title=f"X{in_idx} vs. y{out_idx}",
title=f"$X_{in_idx}$ vs. $y_{out_idx}$",
input_index=in_idx,
output_index=out_idx,
)
plot_index += 1
else:
Expand Down Expand Up @@ -479,7 +483,9 @@ def _plot_model(
return fig


def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
def _plot_Xy(
X, y, y_pred, y_std=None, ax=None, title="Xy", input_index=0, output_index=0
):
"""
Plots observed and predicted values vs. features, including 2σ error bands where available.
"""
Expand Down Expand Up @@ -533,9 +539,9 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
label="pred.",
)

ax.set_xlabel("X")
ax.set_ylabel("y")
ax.set_title(title)
ax.set_xlabel(f"$X_{input_index}$", fontsize=13)
ax.set_ylabel(f"$y_{output_index}$", fontsize=13)
ax.set_title(title, fontsize=13)
ax.grid(True, alpha=0.3)

# Get the handles and labels for the scatter plots
Expand Down
12 changes: 6 additions & 6 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from autoemulate.compare import AutoEmulate
from autoemulate.emulators import RadialBasisFunctions
from autoemulate.plotting import _check_multioutput
from autoemulate.plotting import _plot_cv
from autoemulate.plotting import _plot_model
from autoemulate.plotting import _plot_single_fold
from autoemulate.plotting import _predict_with_optional_std
from autoemulate.plotting import _validate_inputs
from autoemulate.plotting import check_multioutput


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_check_multioutput_with_single_output():
y = np.array([1, 2, 3, 4, 5])
output_index = 0
try:
check_multioutput(y, output_index)
_check_multioutput(y, output_index)
except ValueError as e:
assert False, f"Unexpected ValueError: {str(e)}"

Expand All @@ -81,7 +81,7 @@ def test_check_multioutput_with_multioutput():
y = np.array([[1, 2, 3], [4, 5, 6]])
output_index = 1
try:
check_multioutput(y, output_index)
_check_multioutput(y, output_index)
except ValueError as e:
assert False, f"Unexpected ValueError: {str(e)}"

Expand All @@ -90,7 +90,7 @@ def test_check_multioutput_with_invalid_output_index():
y = np.array([[1, 2, 3], [4, 5, 6]])
output_index = 3
try:
check_multioutput(y, output_index)
_check_multioutput(y, output_index)
assert False, "Expected ValueError to be raised"
except ValueError as e:
assert (
Expand Down Expand Up @@ -354,7 +354,7 @@ def test__plot_model_int(ae_single_output):
output_index=0,
)
assert isinstance(fig, plt.Figure)
assert fig.axes[0].get_title() == "X0 vs. y0"
assert all(term in fig.axes[0].get_title() for term in ["X", "y", "vs."])


def test__plot_model_list(ae_single_output):
Expand All @@ -367,7 +367,7 @@ def test__plot_model_list(ae_single_output):
output_index=[0],
)
assert isinstance(fig, plt.Figure)
assert fig.axes[1].get_title() == "X1 vs. y0"
assert all(term in fig.axes[1].get_title() for term in ["X", "y", "vs."])


def test__plot_model_int_out_of_range(ae_single_output):
Expand Down
Loading