Commit 99a4e6c: Update tests
Scienfitz committed Jan 2, 2025
1 parent 5d8c2de
Showing 3 changed files with 71 additions and 75 deletions.
baybe/insights/shap.py (8 additions, 1 deletion)

@@ -28,7 +28,14 @@ def _get_explainer_maps() -> (
     Returns:
         The maps for SHAP and non-SHAP explainers.
     """
-    EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep"]
+    EXCLUDED_EXPLAINER_KEYWORDS = [
+        "Tree",
+        "GPU",
+        "Gradient",
+        "Sampling",
+        "Deep",
+        "Linear",
+    ]
 
     def _has_required_init_parameters(cls):
         """Check if non-shap initializer has required standard parameters."""
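Note on this hunk: adding "Linear" to EXCLUDED_EXPLAINER_KEYWORDS presumably drops shap's linear explainer from the automatically collected explainer maps. The loop below is a minimal sketch of how such a keyword filter over the shap package typically works; it is an illustration under that assumption, not the actual baybe code, which is collapsed in this diff:

    import shap

    EXCLUDED_EXPLAINER_KEYWORDS = ["Tree", "GPU", "Gradient", "Sampling", "Deep", "Linear"]

    def collect_shap_explainers() -> dict[str, type]:
        """Collect shap explainer classes whose names match no excluded keyword."""
        explainers = {}
        for name in dir(shap.explainers):
            obj = getattr(shap.explainers, name)
            # Keep only classes, and skip any class matching an excluded keyword.
            if isinstance(obj, type) and not any(
                kw in name for kw in EXCLUDED_EXPLAINER_KEYWORDS
            ):
                explainers[name] = obj
        return explainers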
tests/conftest.py (1 addition, 3 deletions)

@@ -167,7 +167,7 @@ def fixture_batch_size(request):
 @pytest.fixture(
     params=[5, pytest.param(8, marks=pytest.mark.slow)],
     name="n_grid_points",
-    ids=["grid5", "grid8"],
+    ids=["g5", "g8"],
 )
 def fixture_n_grid_points(request):
     """Number of grid points used in e.g. the mixture tests.

@@ -887,8 +887,6 @@ def fixture_default_onnx_surrogate(onnx_str) -> CustomONNXSurrogate:
 # Reusables
 
 
-# TODO consider turning this into a fixture returning a campaign after running some
-# fake iterations
 @retry(
     stop=stop_after_attempt(5),
     retry=retry_any(
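For reference, `pytest.param(8, marks=pytest.mark.slow)` attaches the slow mark to the second fixture parametrization only, so the 8-point grid runs only when slow tests are selected (e.g. it is deselected under `-m "not slow"`). A minimal standalone sketch of the pattern; the fixture and test names here are illustrative:

    import pytest

    @pytest.fixture(params=[5, pytest.param(8, marks=pytest.mark.slow)], ids=["g5", "g8"])
    def n_grid_points(request):
        """Grid size; the 8-point case is skipped when slow tests are deselected."""
        return request.param

    def test_grid_is_valid(n_grid_points):
        assert n_grid_points in (5, 8)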
tests/insights/test_shap.py (62 additions, 71 deletions)
@@ -4,20 +4,29 @@
 import pandas as pd
 import pytest
-from pytest import param
+from pytest import mark
 
-from baybe._optional.info import INSIGHTS_INSTALLED
-from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender
-from baybe.recommenders.pure.bayesian.base import BayesianRecommender
-from baybe.searchspace import SearchSpaceType
-from baybe.utils.basic import get_subclasses
+from baybe._optional.info import SHAP_INSTALLED
 from tests.conftest import run_iterations
 
-pytestmark = pytest.mark.skipif(
-    not INSIGHTS_INSTALLED, reason="Optional 'insights' dependency not installed."
-)
+# File-wide parameterization settings
+pytestmark = [
+    mark.skipif(not SHAP_INSTALLED, reason="Optional shap package not installed."),
+    mark.parametrize("n_grid_points", [5], ids=["g5"]),
+    mark.parametrize("n_iterations", [2], ids=["i2"]),
+    mark.parametrize("batch_size", [2], ids=["b2"]),
+    mark.parametrize(
+        "parameter_names",
+        [
+            ["Conti_finite1", "Conti_finite2"],
+            ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"],
+        ],
+        ids=["conti_params", "hybrid_params"],
+    ),
+]
 
 
-if INSIGHTS_INSTALLED:
+if SHAP_INSTALLED:
     from baybe import insights
     from baybe._optional.insights import shap
     from baybe.insights.shap import (
@@ -27,13 +36,11 @@
         SUPPORTED_SHAP_PLOTS,
         SHAPInsight,
     )
-
-
-valid_hybrid_bayesian_recommenders = [
-    param(TwoPhaseMetaRecommender(recommender=cls()), id=f"{cls.__name__}")
-    for cls in get_subclasses(BayesianRecommender)
-    if cls.compatibility == SearchSpaceType.HYBRID
-]
+else:
+    ALL_EXPLAINERS = []
+    NON_SHAP_EXPLAINERS = []
+    SHAP_EXPLAINERS = []
+    SUPPORTED_SHAP_PLOTS = []
 
 
 def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
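The file-wide `pytestmark` list introduced above applies every mark to each test in the module, replacing the per-test decorators removed further down. A minimal self-contained illustration of the mechanism; the test body and values are hypothetical:

    from pytest import mark

    # Each test in this module runs once per combination of the parametrized
    # values, i.e. 1 (grid) x 2 (parameter sets) = 2 variants here.
    pytestmark = [
        mark.parametrize("n_grid_points", [5], ids=["g5"]),
        mark.parametrize("parameter_names", [["a", "b"], ["c"]], ids=["ab", "c"]),
    ]

    def test_example(n_grid_points, parameter_names):
        assert n_grid_points == 5
        assert parameter_names in (["a", "b"], ["c"])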
@@ -56,14 +63,6 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
         # Sanity check explanation
         shap_explanation = shap_insight.explanation
         assert isinstance(shap_explanation, shap.Explanation)
-
-        df = pd.DataFrame({"Num_disc_1": [0, 2]})
-        with pytest.raises(
-            ValueError,
-            match="The provided data does not have the same "
-            "amount of parameters as the shap explainer background.",
-        ):
-            shap_insight._init_explanation(df)
     except TypeError as e:
         if "The selected explainer class" in str(e):
             pytest.xfail("Unsupported model/explainer combination")
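The surrounding try/except, partially collapsed here, converts known-unsupported model/explainer combinations into expected failures rather than errors. The pattern in isolation, with an illustrative failing call; only the checked message substring is taken from the diff:

    import pytest

    def _build_explainer():
        # Stand-in for an explainer constructor that rejects the model type.
        raise TypeError("The selected explainer class does not support the model.")

    def test_xfail_on_unsupported_combination():
        try:
            _build_explainer()
        except TypeError as e:
            if "The selected explainer class" in str(e):
                pytest.xfail("Unsupported model/explainer combination")
            raise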
@@ -80,52 +79,47 @@ def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
         raise e
 
 
-@pytest.mark.slow
-@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders)
-@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
-@pytest.mark.parametrize("explainer_cls", SHAP_EXPLAINERS)
-@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"])
-@pytest.mark.parametrize(
-    "parameter_names",
-    [
-        ["Conti_finite1", "Conti_finite2"],
-        ["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"],
-    ],
-    ids=["continuous_params", "hybrid_params"],
-)
-def test_shapley_with_measurements(ongoing_campaign, explainer_cls, use_comp_rep):
+@mark.slow
+@mark.parametrize("explainer_cls", SHAP_EXPLAINERS)
+@mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"])
+def test_shap_explainers(ongoing_campaign, explainer_cls, use_comp_rep):
     """Test the explain functionalities with measurements."""
     _test_shap_insight(ongoing_campaign, explainer_cls, use_comp_rep, is_shap=True)
 
 
-@pytest.mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS)
-@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
-@pytest.mark.parametrize(
-    "parameter_names",
-    [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]],
-    ids=["hybrid_params"],
-)
-def test_non_shapley_explainers(ongoing_campaign, explainer_cls):
-    """Test the explain functionalities with the non-SHAP explainer MAPLE."""
+@mark.parametrize("explainer_cls", NON_SHAP_EXPLAINERS)
+def test_non_shap_explainers(ongoing_campaign, explainer_cls):
+    """Test the non-SHAP explainer in computational representation."""
     _test_shap_insight(
         ongoing_campaign, explainer_cls, use_comp_rep=True, is_shap=False
     )
 
 
-@pytest.mark.slow
-@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders)
-@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
-@pytest.mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"])
-@pytest.mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS)
-@pytest.mark.parametrize(
-    "parameter_names",
-    [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1"]],
-    ids=["hybrid_params"],
-)
-def test_shap_insight_plots(ongoing_campaign, use_comp_rep, plot_type):
+@mark.slow
+@mark.parametrize("explainer_cls", ["KernelExplainer"], ids=["KernelExplainer"])
+@mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"])
+def test_invalid_explained_data(ongoing_campaign, explainer_cls, use_comp_rep):
+    """Test invalid explained data."""
+    shap_insight = SHAPInsight.from_campaign(
+        ongoing_campaign,
+        explainer_cls=explainer_cls,
+        use_comp_rep=use_comp_rep,
+    )
+    df = pd.DataFrame({"Num_disc_1": [0, 2]})
+    with pytest.raises(
+        ValueError,
+        match="The provided data does not have the same amount of parameters as the "
+        "shap explainer background.",
+    ):
+        shap_insight._init_explanation(df)
+
+
+@mark.slow
+@mark.parametrize("use_comp_rep", [False, True], ids=["exp", "comp"])
+@mark.parametrize("plot_type", SUPPORTED_SHAP_PLOTS)
+def test_plots(ongoing_campaign, use_comp_rep, plot_type):
     """Test the default SHAP plots."""
-    # run_iterations(onngoing_campaign, n_iterations=2, batch_size=1)
     shap_insight = SHAPInsight.from_campaign(
         ongoing_campaign,
         use_comp_rep=use_comp_rep,
@@ -134,30 +128,27 @@ def test_shap_insight_plots(ongoing_campaign, use_comp_rep, plot_type):
     shap_insight.plot(plot_type)
 
 
-@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders)
-@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
-def test_updated_campaign_explanations(campaign):
+def test_updated_campaign_explanations(campaign, n_iterations, batch_size):
     """Test explanations for campaigns with updated measurements."""
     with pytest.raises(
         ValueError,
         match="The campaign does not contain any measurements.",
     ):
         SHAPInsight.from_campaign(campaign)
 
-    run_iterations(campaign, n_iterations=2, batch_size=1)
+    run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size)
     shap_insight = SHAPInsight.from_campaign(campaign)
-    explanation_two_iter = shap_insight.explanation
-    run_iterations(campaign, n_iterations=2, batch_size=1)
+    explanation_1 = shap_insight.explanation
+
+    run_iterations(campaign, n_iterations=n_iterations, batch_size=batch_size)
     shap_insight = SHAPInsight.from_campaign(campaign)
-    explanation_four_iter = shap_insight.explanation
-    assert explanation_two_iter != explanation_four_iter
+    explanation_2 = shap_insight.explanation
+
+    assert explanation_1 != explanation_2, "SHAP explanations should not be identical."
 
 
-@pytest.mark.parametrize("recommender", valid_hybrid_bayesian_recommenders)
-@pytest.mark.parametrize("n_grid_points", [5], ids=["grid5"])
-def test_shap_insight_from_recommender(ongoing_campaign):
+def test_creation_from_recommender(ongoing_campaign):
     """Test the creation of SHAP insights from a recommender."""
-    # run_iterations(campaign, n_iterations=2, batch_size=1)
     recommender = ongoing_campaign.recommender.recommender
     shap_insight = SHAPInsight.from_recommender(
         recommender,
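Taken together, these tests exercise a small public surface. The sketch below summarizes the usage they imply; only the calls and arguments visible in this diff are assumed, while the wrapper function and the plot-type argument are illustrative (valid names live in SUPPORTED_SHAP_PLOTS, whose contents are not shown here):

    from baybe.insights.shap import SHAPInsight

    def summarize_campaign(campaign, plot_type):
        """Build a SHAPInsight from a campaign with measurements and plot it."""
        # Raises ValueError if the campaign has no measurements (see test above).
        insight = SHAPInsight.from_campaign(campaign, use_comp_rep=False)
        explanation = insight.explanation  # a shap.Explanation, per the tests
        insight.plot(plot_type)  # plot_type must be one of SUPPORTED_SHAP_PLOTS
        return explanation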
