Skip to content

Commit

Permalink
Merge pull request #186 from gganapavarapu/master
Browse files Browse the repository at this point in the history
Model and data support fixes in NNContrastive and TS* docstrings.
  • Loading branch information
vijay-arya authored Jul 31, 2023
2 parents 298ce21 + ad7a1db commit d06585b
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 7 deletions.
1 change: 1 addition & 0 deletions aix360/algorithms/nncontrastive/nncontrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def set_exemplars(self, x: Union[pd.DataFrame, np.ndarray]):
if not self.is_fitted:
raise RuntimeError(f"Error: exemplar can only be set post model fitting!")

x = np.asarray(x)
if self.model is not None: # identify class tags for exemplars using model.
classes = self.model(x)
classes = np.array(classes, dtype=int).reshape(-1)
Expand Down
2 changes: 0 additions & 2 deletions aix360/algorithms/tsice/tsice.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ def __init__(
if perturbers is None:
perturbers = [
dict(type="block-bootstrap"),
dict(type="moving_average"),
dict(type="frequency"),
]

self._parameters = {
Expand Down
7 changes: 6 additions & 1 deletion aix360/algorithms/tslime/tslime.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def __init__(
Args:
model (Callable): Callable object produces a prediction as numpy array
for a given input as numpy array.
for a given input as numpy array. It can be a model prediction (predict/
predict_proba) function that results a real value like probability or regressed value.
This function must accept numpy array of shape (input_length x len(feature_names)) as
input and result in numpy array of shape (1, -1). Currently, TSLime supports sinlge output
models only. For multi-output models, you can aggregate the output using a custom
model_wrapper. Use model wrapper classes from aix360.algorithms.tsutils.model_wrappers.
input_length (int): Input (history) length used for input model.
n_perturbations (int): Number of perturbed instance for TSExplanation. Defaults to 25.
relevant_history (int): Interested window size for explanations. The explanation is
Expand Down
11 changes: 7 additions & 4 deletions aix360/algorithms/tssaliency/tssaliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ def __init__(
"""Initializer for TSSaliencyExplainer
Args:
model (Callable): model prediction (predict/predict_proba) function that
results a real value like probability or regressed value. This function
must accept numpy array of shape (input_length x len(feature_names)) as
input and result in numpy array of shape (1, -1).
model (Callable): Callable object produces a prediction as numpy array
for a given input as numpy array. It can be a model prediction (predict/
predict_proba) function that results a real value like probability or regressed value.
This function must accept numpy array of shape (input_length x len(feature_names)) as
input and result in numpy array of shape (1, -1). Currently, TSSaliency supports sinlge output
models only. For multi-output models, you can aggregate the output using a custom
model_wrapper. Use model wrapper classes from aix360.algorithms.tsutils.model_wrappers.
input_length (int): length of history window used in model training.
feature_names (List[str]): list of feature names in the input data.
base_value (List[float]): base value to be used in saliency computation. The
Expand Down
39 changes: 39 additions & 0 deletions examples/gce/gce_demo.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "ac1735e0-8b6e-4c70-ad8e-41439952dbbe",
"metadata": {},
Expand All @@ -10,6 +11,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "085ad98d-cc56-447c-a6b4-33c6d10eef0c",
"metadata": {},
Expand All @@ -26,6 +28,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bbc9d496-b41d-4efc-99f7-bb1050633acf",
"metadata": {},
Expand All @@ -34,6 +37,16 @@
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c2ef8dd",
"metadata": {},
"outputs": [],
"source": [
"!pip install plotly"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand All @@ -51,6 +64,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2bcc686f",
"metadata": {},
Expand Down Expand Up @@ -109,6 +123,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0c9b88d7",
"metadata": {},
Expand Down Expand Up @@ -156,6 +171,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f6b6c392",
"metadata": {},
Expand Down Expand Up @@ -195,6 +211,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "39afef88-3259-4fc8-b365-420c588ab83a",
"metadata": {},
Expand All @@ -203,6 +220,15 @@
"### Plot ICE Explanation"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5e577c3a",
"metadata": {},
"source": [
"`plots.plot_ice_explanation` has helper code to plot the ICE explanation. For a different dataset or variation of plot, you can update the code `plots.py`."
]
},
{
"cell_type": "code",
"execution_count": 7,
Expand All @@ -228,6 +254,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "db330e41-cedf-46e6-bc92-f34b582711f8",
"metadata": {},
Expand All @@ -236,6 +263,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "fb601ba1",
"metadata": {},
Expand Down Expand Up @@ -317,6 +345,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "be57b208-84ac-4a96-b613-c10781402634",
"metadata": {},
Expand All @@ -325,6 +354,15 @@
"### Plot GCE Explanation"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f1b0f097",
"metadata": {},
"source": [
"`plots.plot_gce_explanation` has helper code to plot the GroupedCE (GCE) explanation. For a different dataset or variation of plot, you can update the code `plots.py`."
]
},
{
"cell_type": "code",
"execution_count": 11,
Expand Down Expand Up @@ -357,6 +395,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a2930171-a22b-44b5-bdb7-8755e29cdce7",
"metadata": {},
Expand Down
46 changes: 46 additions & 0 deletions examples/tsice/tsice_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40a2f00e",
"metadata": {},
"outputs": [],
"source": [
"!pip install huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -436,6 +446,15 @@
"### Plot Explanation With Range"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b1c6c682",
"metadata": {},
"source": [
"`plots.plot_tsice_explanation` has helper code to plot the TSICE explanation. For a different dataset or variation of plot, you can update the code `plots.py`."
]
},
{
"cell_type": "code",
"execution_count": 10,
Expand Down Expand Up @@ -475,6 +494,15 @@
"Signed impact (Δ forecast) = mean of the forecasts over perturbations - mean of the original forecast."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c4eebf67",
"metadata": {},
"source": [
"`plots.plot_tsice_with_observed_features` has helper code to plot the TSICE explanation with extracted features. For a different dataset or variation of plot, you can update the code `plots.py`."
]
},
{
"cell_type": "code",
"execution_count": 11,
Expand Down Expand Up @@ -590,6 +618,15 @@
"### Plot Explanation With Latest"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "552b423f",
"metadata": {},
"source": [
"`plots.plot_tsice_explanation` has helper code to plot the TSICE explanation. For a different dataset or variation of plot, you can update the code `plots.py`."
]
},
{
"cell_type": "code",
"execution_count": 14,
Expand Down Expand Up @@ -625,6 +662,15 @@
"The above plot clearly shows that the sunspots from 1982 and 1983 have larger impact on the sunpots forecasts of 1984 compared to sunspots from 1975 and 1976."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c881ff5c",
"metadata": {},
"source": [
"`plots.plot_tsice_with_observed_features` has helper code to plot the TSICE explanation with extracted features. For a different dataset or variation of plot, you can update the code `plots.py`."
]
},
{
"cell_type": "code",
"execution_count": 15,
Expand Down
22 changes: 22 additions & 0 deletions examples/tslime/tslime_multivariate_demo.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "48f88759-48eb-4f4f-a896-3b60aa00e003",
"metadata": {},
Expand All @@ -13,6 +14,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "45c39595-265b-4480-b6c1-20339113ed33",
"metadata": {},
Expand All @@ -28,6 +30,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "21e3e84e-ece2-43bb-a11c-9d7dff21bc13",
"metadata": {},
Expand All @@ -38,6 +41,16 @@
"The example model is a pre-trained keras model and hosted on huggingface hub. So, this notebook requires to install tensorflow 2.4+ and huggingface_hub packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0918762f",
"metadata": {},
"outputs": [],
"source": [
"!pip install huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -67,6 +80,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6a33c82c-5947-4ae2-9a80-87a27032db1d",
"metadata": {},
Expand Down Expand Up @@ -154,6 +168,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1e50af8b-2ca8-45e4-8a6f-48492ce7a084",
"metadata": {},
Expand Down Expand Up @@ -228,6 +243,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f13daba6-29c7-4143-be7f-333b6d792201",
"metadata": {},
Expand All @@ -248,6 +264,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b7304d90-1085-4aa5-9c85-cb86776e3d03",
"metadata": {},
Expand Down Expand Up @@ -280,6 +297,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "64b6fb0f-c9dd-4243-bd5b-c6c44d8160cd",
"metadata": {},
Expand Down Expand Up @@ -319,6 +337,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d6144697-4b3d-4945-85dc-56e087ccd65c",
"metadata": {},
Expand Down Expand Up @@ -527,6 +546,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d41f16ed-758e-484a-b5e8-c1487e3ce565",
"metadata": {},
Expand All @@ -547,6 +567,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b531066b-b407-4704-9bbd-6a25ccfba8f2",
"metadata": {},
Expand Down Expand Up @@ -618,6 +639,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "fa09fee2-4451-4098-9cd4-ce9a3b5bdd2e",
"metadata": {},
Expand Down
Loading

0 comments on commit d06585b

Please sign in to comment.