From 241d681136e959d71b6bebb477eae6c161109b63 Mon Sep 17 00:00:00 2001 From: Kevin Klein <7267523+kklein@users.noreply.github.com> Date: Thu, 29 Aug 2024 17:32:17 +0200 Subject: [PATCH] Add tests for index_matrix and index_vector. (#89) --- metalearners/_utils.py | 1 + tests/test__utils.py | 59 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/metalearners/_utils.py b/metalearners/_utils.py index a5c02f37..1297ad99 100644 --- a/metalearners/_utils.py +++ b/metalearners/_utils.py @@ -26,6 +26,7 @@ def safe_len(X: Matrix) -> int: + """Determine the length of a Matrix.""" if scipy.sparse.issparse(X): return X.shape[0] return len(X) diff --git a/tests/test__utils.py b/tests/test__utils.py index 756f756a..e169c020 100644 --- a/tests/test__utils.py +++ b/tests/test__utils.py @@ -8,6 +8,7 @@ import pytest from glum import GeneralizedLinearRegressor, GeneralizedLinearRegressorCV from lightgbm import LGBMClassifier, LGBMRegressor +from scipy.sparse import csr_matrix from sklearn.ensemble import HistGradientBoostingClassifier from sklearn.linear_model import LinearRegression from xgboost import XGBClassifier, XGBRegressor @@ -20,6 +21,8 @@ convert_treatment, function_has_argument, get_linear_dimension, + index_matrix, + index_vector, supports_categoricals, validate_all_vectors_same_index, validate_model_and_predict_method, @@ -345,3 +348,59 @@ def test_validate_valid_treatment_variant_not_control( else: with pytest.raises(ValueError, match="variant"): validate_valid_treatment_variant_not_control(treatment_variant, n_variants) + + +@pytest.mark.parametrize("matrix_backend", [np.ndarray, pd.DataFrame, csr_matrix]) +@pytest.mark.parametrize("rows_backend", [np.array, pd.Series]) +def test_index_matrix(matrix_backend, rows_backend): + n_samples = 10 + if matrix_backend == np.ndarray: + matrix = np.array(list(range(n_samples))).reshape((-1, 1)) + elif matrix_backend == pd.DataFrame: + # We make sure that the index is not equal to the row number. + matrix = pd.DataFrame( + list(range(n_samples)), index=list(range(20, 20 + n_samples)) + ) + elif matrix_backend == csr_matrix: + matrix = csr_matrix(np.array(list(range(n_samples))).reshape((-1, 1))) + else: + raise ValueError() + rows = rows_backend([1, 4, 5]) + result = index_matrix(matrix=matrix, rows=rows) + + assert isinstance(result, matrix_backend) + assert result.shape[1] == matrix.shape[1] + + if isinstance(result, pd.DataFrame): + processed_result = result.values[:, 0] + else: + processed_result = result[:, 0] + + expected = np.array([1, 4, 5]) + assert (processed_result == expected).sum() == len(expected) + + +@pytest.mark.parametrize("vector_backend", [np.ndarray, pd.Series]) +@pytest.mark.parametrize("rows_backend", [np.array, pd.Series]) +def test_index_vector(vector_backend, rows_backend): + n_samples = 10 + if vector_backend == np.ndarray: + vector = np.array(list(range(n_samples))) + elif vector_backend == pd.Series: + # We make sure that the index is not equal to the row number. + vector = pd.Series( + list(range(n_samples)), index=list(range(20, 20 + n_samples)) + ) + else: + raise ValueError() + + rows = rows_backend([1, 4, 5]) + + result = index_vector(vector=vector, rows=rows) + assert isinstance(result, vector_backend) + + if isinstance(result, pd.Series): + result = result.values + + expected = np.array([1, 4, 5]) + assert (result == expected).all()