Skip to content

Commit

Permalink
Complay with flake8 rules
Browse files Browse the repository at this point in the history
Signed-off-by: rmazzine <[email protected]>
  • Loading branch information
rmazzine committed Sep 12, 2023
1 parent 770121c commit a0f37a9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions dice_ml/counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ def convert_data(x):
return df_x

if dice_model.backend == BackEndTypes.Sklearn:
factual_class_idx = np.argmax(
self.factual_class_idx = np.argmax(
dice_model.model.predict_proba(convert_data([factual_instance])))

def model_pred(x):
# Use one against all strategy
pred_prob = dice_model.model.predict_proba(convert_data(x))
class_f_proba = pred_prob[:, factual_class_idx]
class_f_proba = pred_prob[:, self.factual_class_idx]

# Probability for all other classes (excluding class 0)
not_class_f_proba = 1 - class_f_proba
Expand Down
10 changes: 5 additions & 5 deletions tests/test_counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_unsupported_versions_to_json(self, unsupported_version):
assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)


class TestCounterfactualExplanations(unittest.TestCase):
class TestCounterfactualExplanationsPlot(unittest.TestCase):

@patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
def test_plot_counterplots_sklearn(self, mock_create_plot):
Expand Down Expand Up @@ -357,10 +357,10 @@ def test_plot_counterplots_sklearn(self, mock_create_plot):
result = counterfact.plot_counterplots(dummy_model)

# Assert the CreatePlot was called twice (as there are 2 counterfactual instances)
self.assertEqual(mock_create_plot.call_count, 2)
assert mock_create_plot.call_count == 2

# Assert that the result is as expected
self.assertEqual(result, ["dummy_plot", "dummy_plot"])
assert result == ["dummy_plot", "dummy_plot"]

@patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
def test_plot_counterplots_non_sklearn(self, mock_create_plot):
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_plot_counterplots_non_sklearn(self, mock_create_plot):
result = counterfact.plot_counterplots(dummy_model)

# Assert the CreatePlot was called twice (as there are 2 counterfactual instances)
self.assertEqual(mock_create_plot.call_count, 2)
assert mock_create_plot.call_count == 2

# Assert that the result is as expected
self.assertEqual(result, ["dummy_plot", "dummy_plot"])
assert result == ["dummy_plot", "dummy_plot"]

0 comments on commit a0f37a9

Please sign in to comment.