diff --git a/spare_scores/svm.py b/spare_scores/svm.py index 88c653c..61a3aee 100644 --- a/spare_scores/svm.py +++ b/spare_scores/svm.py @@ -222,7 +222,7 @@ def run_CV(self, df: pd.DataFrame, **kwargs) -> None: self.mdl['bias_correct'] = self.bias_correct - def prepare_sample(self, df: pd.DataFrame, fold, scaler, classify=None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def prepare_sample(self, df: pd.DataFrame, fold, scaler, classify=None): X_train, X_test = scaler.fit_transform(df.loc[fold[0], self.predictors]), scaler.transform(df.loc[fold[1], self.predictors]) y_train, y_test = df.loc[fold[0], self.to_predict], df.loc[fold[1], self.to_predict] if classify is not None: