From a176b425aad2ab99a37261a84e62635ac41e082c Mon Sep 17 00:00:00 2001 From: George Aidinis Date: Mon, 20 May 2024 18:29:21 -0400 Subject: [PATCH] Upgraded torch to v2, fixed FutureWarning for SVM training --- setup.py | 7 ++++++- spare_scores/svm.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 583a2ba..5a00b7f 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,12 @@ packages=find_packages(), package_data={'spare_scores':['mdl/*.pkl.gz','data/*.csv']}, include_package_data=True, - install_requires=['numpy', 'pandas', 'scikit-learn', 'torch==1.11', 'matplotlib', 'optuna'], + install_requires=['numpy', + 'pandas', + 'scikit-learn', + 'torch<2.1', + 'matplotlib', + 'optuna'], entry_points={ 'console_scripts': ['spare_score=spare_scores.cli:main'] }, diff --git a/spare_scores/svm.py b/spare_scores/svm.py index e72a571..5c7bedd 100644 --- a/spare_scores/svm.py +++ b/spare_scores/svm.py @@ -189,11 +189,11 @@ def train_initialize(self, df, to_predict): if self.task == 'Classification': self.type, self.scoring, metrics = 'SVC', 'roc_auc', ['AUC', 'Accuracy', 'Sensitivity', 'Specificity', 'Precision', 'Recall', 'F1'] self.to_predict, self.classify = to_predict, list(df[to_predict].unique()) - self.mdl = ([LinearSVC(max_iter=100000)] if self.kernel == 'linear' else [SVC(max_iter=100000, kernel=self.kernel)]) * len(self.folds) + self.mdl = ([LinearSVC(max_iter=100000, dual='auto')] if self.kernel == 'linear' else [SVC(max_iter=100000, kernel=self.kernel)]) * len(self.folds) elif self.task == 'Regression': self.type, self.scoring, metrics = 'SVR', 'neg_mean_absolute_error', ['MAE', 'RMSE', 'R2'] self.to_predict, self.classify = to_predict, None - self.mdl = [LinearSVR(max_iter=100000)] * len(self.folds) + self.mdl = [LinearSVR(max_iter=100000, dual='auto')] * len(self.folds) self.bias_correct = {'slope':np.zeros((len(self.folds),)), 'int':np.zeros((len(self.folds),))} self.stats = {metric: [] for metric in metrics} logging.info(f'Training a SPARE model ({self.type}) with {len(df.index)} participants')