diff --git a/anai/__init__.py b/anai/__init__.py index 0ae3b04..991244e 100644 --- a/anai/__init__.py +++ b/anai/__init__.py @@ -6,7 +6,13 @@ import os import shutil import warnings - +from distributed import Client, LocalCluster +try: + clust = LocalCluster(name='ANAI-Cluster', n_workers=2, threads_per_worker=2, processes=True, + host='0.0.0.0', protocol='tcp://', scheduler_port=0, dashboard_address=0) + client = Client(clust) +except Exception as e: + pass import modin.pandas as pd from colorama import Fore from optuna.samplers._tpe.sampler import TPESampler @@ -16,7 +22,6 @@ from anai.utils.connectors import load_data_from_config from anai.utils.connectors.data_handler import __df_loader_single, df_loader -os.environ["MODIN_ENGINE"] = "dask" if os.path.exists(os.getcwd() + "/dask-worker-space"): diff --git a/anai/supervised/__init__.py b/anai/supervised/__init__.py index f56e44f..04cdae2 100644 --- a/anai/supervised/__init__.py +++ b/anai/supervised/__init__.py @@ -1614,7 +1614,7 @@ def explain(self, method, show_graph=True): self.y_val, self.cv_folds, self.fit_params, - self.show_graph, + show_graph, ) if self.pred_mode == "all": regressor = copy.deepcopy(self.best_regressor.model) diff --git a/anai/utils/explainable_anai/explain_core.py b/anai/utils/explainable_anai/explain_core.py index 8a601ae..357a9ae 100644 --- a/anai/utils/explainable_anai/explain_core.py +++ b/anai/utils/explainable_anai/explain_core.py @@ -72,15 +72,16 @@ def permutation(self, model): def shap(self, model): try: - res = shap_feature_importance(self.features.columns, self.X_train, model, self.isReg, self.show_graph) + res = shap_feature_importance(self.features.columns, self.X_train, model, self.show_graph) return res except Exception as e: print(Fore.YELLOW + "Automatically switching to Surrogate mode\n") try: res = shap_feature_importance( - self.features.columns, + self.features.columns, self.X_train, - surrogate_decision_tree(model, self.X_train, isReg=self.isReg, show_graph=self.show_graph), + surrogate_decision_tree(model, self.X_train), + show_graph=self.show_graph ) return res except Exception as e: diff --git a/setup.py b/setup.py index 91352a4..95b9528 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ ], include=["anai.*", "anai"], ), - version="0.1.6", + version="0.1.7", license="Apache License 2.0", description="Automated ML", url="https://github.com/Revca-ANAI/ANAI",