Skip to content

Commit

Permalink
version 1.0.1 of shapiq
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed Jun 5, 2024
1 parent 08bb0db commit 153c73b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

### v1.0.1 (2024-06-05)

- add `max_order=1` to `TabularExplainer`
-
- add `max_order=1` to `TabularExplainer` and `TreeExplainer`
- fix `TreeExplainer.explain_X(..., njobs=2, random_state=0)`

### v1.0.0 (2024-06-04)

Expand Down
2 changes: 1 addition & 1 deletion shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
the well established Shapley value and its generalization to interaction.
"""

__version__ = "1.0.0.9000"
__version__ = "1.0.1"

# approximator classes
from .approximator import (
Expand Down
9 changes: 6 additions & 3 deletions shapiq/explainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ def explain_X(
"""
assert len(X.shape) == 2
if random_state is not None:
self._imputer._rng = np.random.default_rng(random_state)
self._approximator._rng = np.random.default_rng(random_state)
self._approximator._sampler._rng = np.random.default_rng(random_state)
if hasattr(self, "_imputer"):
self._imputer._rng = np.random.default_rng(random_state)
if hasattr(self, "_approximator"):
self._approximator._rng = np.random.default_rng(random_state)
if hasattr(self._approximator, "_sampler"):
self._approximator._sampler._rng = np.random.default_rng(random_state)
if n_jobs:
import joblib

Expand Down
5 changes: 4 additions & 1 deletion shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TreeExplainer(Explainer):
interaction values up to that order. Defaults to ``2``.
min_order: The minimum interaction order to be computed. Defaults to ``1``.
index: The type of interaction to be computed. It can be one of
``["k-SII", "SII", "STII", "FSII", "BII"]``. All indices apart from ``"BII"`` will
``["k-SII", "SII", "STII", "FSII", "BII", "SV"]``. All indices apart from ``"BII"`` will
reduce to the ``"SV"`` (Shapley value) for order 1. Defaults to ``"k-SII"``.
class_label: The class label of the model to explain.
"""
Expand All @@ -52,6 +52,9 @@ def __init__(
if index == "SV" and max_order > 1:
warnings.warn("For index='SV' the max_order is set to 1.")
max_order = 1
elif max_order == 1 and index != "SV":
warnings.warn("For max_order=1 the index is set to 'SV'.")
index = "SV"

# validate and parse model
validated_model = validate_tree_model(model, class_label=class_label)
Expand Down

0 comments on commit 153c73b

Please sign in to comment.