diff --git a/baybe/insights/shap.py b/baybe/insights/shap.py
index 50516a407..3f7140214 100644
--- a/baybe/insights/shap.py
+++ b/baybe/insights/shap.py
@@ -1,5 +1,7 @@
 """SHAP insights."""
 
+from __future__ import annotations
+
 import inspect
 import numbers
 import warnings
@@ -115,7 +117,7 @@ def from_campaign(
         explained_data: pd.DataFrame | None = None,
         explainer_cls: type[shap.Explainer] | str = "KernelExplainer",
         use_comp_rep: bool = False,
-    ):
+    ) -> SHAPInsight:
         """Create a SHAP insight from a campaign.
 
         Args:
@@ -161,7 +163,7 @@ def from_recommender(
         explained_data: pd.DataFrame | None = None,
         explainer_cls: type[shap.Explainer] | str = "KernelExplainer",
         use_comp_rep: bool = False,
-    ):
+    ) -> SHAPInsight:
         """Create a SHAP insight from a recommender.
 
         Args:
diff --git a/tests/conftest.py b/tests/conftest.py
index c82630cd4..1dc620741 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -591,6 +591,13 @@ def fixture_campaign(parameters, constraints, recommender, objective):
     )
 
 
+@pytest.fixture(name="ongoing_campaign")
+def fixture_ongoing_campaign(campaign, n_iterations, batch_size):
+    """Returns a campaign that has already run for several iterations."""
+    run_iterations(campaign, n_iterations, batch_size)
+    return campaign
+
+
 @pytest.fixture(name="searchspace")
 def fixture_searchspace(parameters, constraints):
     """Returns a searchspace."""
diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py
index 0b2122989..19781bf84 100644
--- a/tests/insights/test_shap.py
+++ b/tests/insights/test_shap.py
@@ -38,8 +38,8 @@
 
 def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
     """Helper function for general SHAP explainer tests."""
-    run_iterations(campaign, n_iterations=2, batch_size=1)
     try:
+        # Sanity check explainer
         shap_insight = SHAPInsight.from_campaign(
             campaign,
             explainer_cls=explainer_cls,
@@ -51,8 +51,11 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
             ALL_EXPLAINERS[explainer_cls],
         )
         assert shap_insight.uses_shap_explainer == is_shap
+
+        # Sanity check explanation
         shap_explanation = shap_insight.explanation
         assert isinstance(shap_explanation, shap.Explanation)
+
         df = pd.DataFrame({"Num_disc_1": [0, 2]})
         with pytest.raises(
             ValueError,
@@ -89,9 +92,9 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
     ],
     ids=["continuous_params", "hybrid_params"],
 )
-def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep):
+def test_shapley_with_measurements(ongoing_campaign, explainer_cls, use_comp_rep):
     """Test the explain functionalities with measurements."""
-    _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap=True)
+    _test_shap_insight(ongoing_campaign, explainer_cls, use_comp_rep, is_shap=True)
 
 
 @pytest.mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS)
@@ -101,10 +104,12 @@ def test_shapley_with_measurements(campaign, explainer_cls, use_comp_rep):
     [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]],
     ids=["hybrid_params"],
 )
-def test_non_shapley_explainers(campaign, explainer_cls):
+def test_non_shapley_explainers(ongoing_campaign, explainer_cls):
     """Test the explain functionalities with the non-SHAP explainer MAPLE."""
     """Test the non-SHAP explainer in computational representation."""
-    _test_shap_insight(campaign, explainer_cls, use_comp_rep=True, is_shap=False)
+    _test_shap_insight(
+        ongoing_campaign, explainer_cls, use_comp_rep=True, is_shap=False
+    )
 
 
 @pytest.mark.slow
@@ -117,11 +122,10 @@ def test_non_shapley_explainers(campaign, explainer_cls):
     [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]],
     ids=["hybrid_params"],
 )
-def test_shap_insight_plots(campaign, use_comp_rep, plot_type):
+def test_shap_insight_plots(ongoing_campaign, use_comp_rep, plot_type):
     """Test the default SHAP plots."""
-    run_iterations(campaign, n_iterations=2, batch_size=1)
     shap_insight = SHAPInsight.from_campaign(
-        campaign,
+        ongoing_campaign,
         use_comp_rep=use_comp_rep,
     )
     with mock.patch("matplotlib.pyplot.show"):
@@ -136,7 +140,8 @@ def test_updated_campaign_explanations(campaign):
         ValueError,
         match="The campaign does not contain any measurements.",
     ):
-        shap_insight = SHAPInsight.from_campaign(campaign)
+        SHAPInsight.from_campaign(campaign)
+
     run_iterations(campaign, n_iterations=2, batch_size=1)
     shap_insight = SHAPInsight.from_campaign(campaign)
     explanation_two_iter = shap_insight.explanation
@@ -148,14 +153,13 @@
 
 @pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders)
 @pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
-def test_shap_insight_from_recommender(campaign):
+def test_shap_insight_from_recommender(ongoing_campaign):
     """Test the creation of SHAP insights from a recommender."""
-    run_iterations(campaign, n_iterations=2, batch_size=1)
-    recommender = campaign.recommender.recommender
+    recommender = ongoing_campaign.recommender.recommender
     shap_insight = SHAPInsight.from_recommender(
         recommender,
-        campaign.searchspace,
-        campaign.objective,
-        campaign.measurements,
+        ongoing_campaign.searchspace,
+        ongoing_campaign.objective,
+        ongoing_campaign.measurements,
     )
     assert isinstance(shap_insight, insights.SHAPInsight)