Commit

test

Scienfitz committed Jan 2, 2025
1 parent 69cecd8 commit 5d8c2de
Showing 3 changed files with 33 additions and 17 deletions.
6 changes: 4 additions & 2 deletions baybe/insights/shap.py
@@ -1,5 +1,7 @@
"""SHAP insights."""

from __future__ import annotations

import inspect
import numbers
import warnings
@@ -115,7 +117,7 @@ def from_campaign(
explained_data: pd.DataFrame | None = None,
explainer_cls: type[shap.Explainer] | str = "KernelExplainer",
use_comp_rep: bool = False,
):
) -> SHAPInsight:
"""Create a SHAP insight from a campaign.
Args:
@@ -161,7 +163,7 @@ def from_recommender(
explained_data: pd.DataFrame | None = None,
explainer_cls: type[shap.Explainer] | str = "KernelExplainer",
use_comp_rep: bool = False,
):
) -> SHAPInsight:
"""Create a SHAP insight from a recommender.
Args:
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -591,6 +591,13 @@ def fixture_campaign(parameters, constraints, recommender, objective):
)


@pytest.fixture(name="ongoing_campaign")
def fixture_ongoing_campaign(campaign, n_iterations, batch_size):
"""Returns a campaign that already ran for several iterations."""
run_iterations(campaign, n_iterations, batch_size)
return campaign


@pytest.fixture(name="searchspace")
def fixture_searchspace(parameters, constraints):
"""Returns a searchspace."""
37 changes: 22 additions & 15 deletions tests/insights/test_shap.py
@@ -38,8 +38,9 @@

def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
"""Helper function for general SHAP explainer tests."""
run_iterations(campaign, n_iterations=2, batch_size=1)
# run_iterations(campaign, n_iterations=2, batch_size=5)
try:
# Sanity check explainer
shap_insight = SHAPInsight.from_campaign(
campaign,
explainer_cls=explainer_cls,
@@ -51,8 +52,11 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
ALL_EXPLAINERS[explainer_cls],
)
assert shap_insight.uses_shap_explainer == is_shap

# Sanity check explanation
shap_explanation = shap_insight.explanation
assert isinstance(shap_explanation, shap.Explanation)

df = pd.DataFrame({"Num_disc_1": [0, 2]})
with pytest.raises(
ValueError,
@@ -89,9 +93,9 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
],
ids=["continuous_params", "hybrid_params"],
)
def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep):
def test_shapley_with_measurements(ongoing_campaign, explainer_cls, use_comp_rep):
"""Test the explain functionalities with measurements."""
_test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap=True)
_test_shap_insight(ongoing_campaign, explainer_cls, use_comp_rep, is_shap=True)


@pytest.mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS)
@@ -101,10 +105,12 @@ def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep):
[["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]],
ids=["hybrid_params"],
)
def test_non_shapley_explainers(campaign, explainer_cls):
def test_non_shapley_explainers(ongoing_campaign, explainer_cls):
"""Test the explain functionalities with the non-SHAP explainer MAPLE."""
"""Test the non-SHAP explainer in computational representation."""
_test_shap_insight(campaign, explainer_cls, use_comp_rep=True, is_shap=False)
_test_shap_insight(
ongoing_campaign, explainer_cls, use_comp_rep=True, is_shap=False
)


@pytest.mark.slow
@@ -117,11 +123,11 @@ def test_non_shapley_explainers(campaign, explainer_cls):
[["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]],
ids=["hybrid_params"],
)
def test_shap_insight_plots(campaign, use_comp_rep, plot_type):
def test_shap_insight_plots(ongoing_campaign, use_comp_rep, plot_type):
"""Test the default SHAP plots."""
run_iterations(campaign, n_iterations=2, batch_size=1)
# run_iterations(ongoing_campaign, n_iterations=2, batch_size=1)
shap_insight = SHAPInsight.from_campaign(
campaign,
ongoing_campaign,
use_comp_rep=use_comp_rep,
)
with mock.patch("matplotlib.pyplot.show"):
@@ -136,7 +142,8 @@ def test_updated_campaign_explanations(campaign):
ValueError,
match="The campaign does not contain any measurements.",
):
shap_insight = SHAPInsight.from_campaign(campaign)
SHAPInsight.from_campaign(campaign)

run_iterations(campaign, n_iterations=2, batch_size=1)
shap_insight = SHAPInsight.from_campaign(campaign)
explanation_two_iter = shap_insight.explanation
@@ -148,14 +155,14 @@

@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders)
@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
def test_shap_insight_from_recommender(campaign):
def test_shap_insight_from_recommender(ongoing_campaign):
"""Test the creation of SHAP insights from a recommender."""
run_iterations(campaign, n_iterations=2, batch_size=1)
recommender = campaign.recommender.recommender
# run_iterations(campaign, n_iterations=2, batch_size=1)
recommender = ongoing_campaign.recommender.recommender
shap_insight = SHAPInsight.from_recommender(
recommender,
campaign.searchspace,
campaign.objective,
campaign.measurements,
ongoing_campaign.searchspace,
ongoing_campaign.objective,
ongoing_campaign.measurements,
)
assert isinstance(shap_insight, insights.SHAPInsight)
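Usage note (illustration only, not part of the commit): a minimal sketch of how a test would consume the new ongoing_campaign fixture from tests/conftest.py, so it no longer needs to call run_iterations itself. The emptiness check on measurements is an assumption based on how the diff above passes ongoing_campaign.measurements around.

def test_with_prepared_campaign(ongoing_campaign):
    """The fixture has already run several iterations, so measurements should exist."""
    # `ongoing_campaign` is the regular `campaign` fixture after `run_iterations`
    # was applied by `fixture_ongoing_campaign` in tests/conftest.py.
    assert not ongoing_campaign.measurements.empty  # assumes a pandas DataFrame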
