Skip to content

Commit

Permalink
Implement Counterplots
Browse files Browse the repository at this point in the history
Signed-off-by: rmazzine <[email protected]>
  • Loading branch information
rmazzine committed Sep 11, 2023
1 parent 838d97f commit 4c53de2
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 1 deletion.
53 changes: 52 additions & 1 deletion dice_ml/counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import os

import jsonschema
import numpy as np
import pandas as pd
from counterplots import CreatePlot
from raiutils.exceptions import UserConfigValidationException

from dice_ml.constants import _SchemaVersions
from dice_ml.constants import _SchemaVersions, BackEndTypes
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
_DiverseCFV2SchemaConstants)

Expand Down Expand Up @@ -111,6 +114,54 @@ def visualize_as_list(self, display_sparse_df=True,
display_sparse_df=display_sparse_df,
show_only_changes=show_only_changes)

def plot_counterplots(self, dice_model):
"""Plot counterfactual with CounterPlots package.
:param dice_model: DiCE's model object.
"""
counterplots_out = []
for cf_examples in self.cf_examples_list:
features_names = list(cf_examples.test_instance_df.columns)[:-1]
features_dtypes = list(cf_examples.test_instance_df.dtypes)[:-1]
factual_instance = cf_examples.test_instance_df.to_numpy()[0][:-1]

def convert_data(x):
df_x = pd.DataFrame(data=x, columns=features_names)
# Transform each dtype according to features_dtypes
for feature_name, f_dtype in zip(features_names, features_dtypes):
df_x[feature_name] = df_x[feature_name].astype(f_dtype)

return df_x

if dice_model.backend == BackEndTypes.Sklearn:
factual_class_idx = np.argmax(
dice_model.model.predict_proba(convert_data([factual_instance])))
def model_pred(x):
# Use one against all strategy
pred_prob = dice_model.model.predict_proba(convert_data(x))
class_f_proba = pred_prob[:, factual_class_idx]

# Probability for all other classes (excluding class 0)
not_class_f_proba = 1 - class_f_proba

# Normalize to sum to 1
class_f_proba = class_f_proba / (class_f_proba + not_class_f_proba)

return class_f_proba
else:
def model_pred(x):
return dice_model.model.predict(dice_model.transformer.transform(convert_data(x)))

for cf_instance in cf_examples.final_cfs_df.to_numpy():
counterplots_out.append(
CreatePlot(
factual=factual_instance,
cf=cf_instance[:-1],
model_pred=model_pred,
feature_names=features_names,
))
return counterplots_out

@staticmethod
def _check_cf_exp_output_against_json_schema(
cf_dict, version):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pandas<2.0.0
scikit-learn
tqdm
raiutils>=0.4.0
counterplots>=0.0.7
79 changes: 79 additions & 0 deletions tests/test_counterfactual_explanations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import json

import pytest
import unittest
from unittest.mock import patch, Mock
from raiutils.exceptions import UserConfigValidationException

import pandas as pd
import numpy as np
from dice_ml.counterfactual_explanations import CounterfactualExplanations


Expand Down Expand Up @@ -319,3 +323,78 @@ def test_unsupported_versions_to_json(self, unsupported_version):
counterfactual_explanations.to_json()

assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)


class TestCounterfactualExplanations(unittest.TestCase):

@patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
def test_plot_counterplots_sklearn(self, mock_create_plot):
# Dummy DiCE's model object with a Sklearn backend
dummy_model = Mock()
dummy_model.backend = "sklearn"
dummy_model.model.predict_proba = Mock(return_value=np.array([[0.4, 0.6], [0.2, 0.8]]))

# Sample cf_examples to test with
cf_examples_mock = Mock()
cf_examples_mock.test_instance_df = pd.DataFrame({
'feature1': [1],
'feature2': [2],
'target': [0]
})
cf_examples_mock.final_cfs_df = pd.DataFrame({
'feature1': [1.1, 1.2],
'feature2': [2.1, 2.2],
'target': [1, 1]
})

counterfact = CounterfactualExplanations(
cf_examples_list=[cf_examples_mock],
local_importance=None,
summary_importance=None,
version=None)

# Call function
result = counterfact.plot_counterplots(dummy_model)

# Assert the CreatePlot was called twice (as there are 2 counterfactual instances)
self.assertEqual(mock_create_plot.call_count, 2)

# Assert that the result is as expected
self.assertEqual(result, ["dummy_plot", "dummy_plot"])

@patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
def test_plot_counterplots_non_sklearn(self, mock_create_plot):
# Sample Non-Sklearn backend
dummy_model = Mock()
dummy_model.backend = "NonSklearn"
dummy_model.model.predict = Mock(return_value=np.array([0, 1]))
dummy_model.transformer = Mock()
dummy_model.transformer.transform = Mock(return_value=np.array([[1, 2], [1.1, 2.1]]))

# Sample cf_examples to test with
cf_examples_mock = Mock()
cf_examples_mock.test_instance_df = pd.DataFrame({
'feature1': [1],
'feature2': [2],
'target': [0]
})
cf_examples_mock.final_cfs_df = pd.DataFrame({
'feature1': [1.1, 1.2],
'feature2': [2.1, 2.2],
'target': [1, 1]
})

counterfact = CounterfactualExplanations(
cf_examples_list=[cf_examples_mock],
local_importance=None,
summary_importance=None,
version=None)

# Call function
result = counterfact.plot_counterplots(dummy_model)

# Assert the CreatePlot was called twice (as there are 2 counterfactual instances)
self.assertEqual(mock_create_plot.call_count, 2)

# Assert that the result is as expected
self.assertEqual(result, ["dummy_plot", "dummy_plot"])

0 comments on commit 4c53de2

Please sign in to comment.