diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index 39b45a5b..5b67fee7 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -121,15 +121,15 @@ def plot_counterplots(self, dice_model): """ counterplots_out = [] for cf_examples in self.cf_examples_list: - features_names = list(cf_examples.test_instance_df.columns)[:-1] - features_dtypes = list(cf_examples.test_instance_df.dtypes)[:-1] + self.features_names = list(cf_examples.test_instance_df.columns)[:-1] + self.features_dtypes = list(cf_examples.test_instance_df.dtypes)[:-1] factual_instance = cf_examples.test_instance_df.to_numpy()[0][:-1] def convert_data(x): - df_x = pd.DataFrame(data=x, columns=features_names) + df_x = pd.DataFrame(data=x, columns=self.features_names) # Transform each dtype according to features_dtypes - for feature_name, f_dtype in zip(features_names, features_dtypes): - df_x[feature_name] = df_x[feature_name].astype(f_dtype) + for feature_name, f_dtype in zip(self.features_names, self.features_dtypes): + df_x[feature_name] = pd.to_numeric(df_x[feature_name], errors='ignore').astype(f_dtype) return df_x @@ -158,7 +158,7 @@ def model_pred(x): factual=factual_instance, cf=cf_instance[:-1], model_pred=model_pred, - feature_names=features_names, + feature_names=self.features_names, )) return counterplots_out