Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move validation code to a new validators.py module #169

Merged
merged 4 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 6 additions & 122 deletions tableone/tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from statsmodels.stats import multitest
from tabulate import tabulate

from tableone.validators import DataValidator, InputValidator, InputError
from tableone.modality import hartigan_diptest

# display deprecation warnings
Expand Down Expand Up @@ -53,13 +54,6 @@ def dec(obj):
return dec


class InputError(Exception):
"""
Exception raised for errors in the input.
"""
pass


class TableOne:
"""

Expand Down Expand Up @@ -225,14 +219,15 @@ def __init__(self, data: pd.DataFrame,

self._handle_deprecations(labels, rename, isnull, pval_test_name, remarks)

# Default assignment for columns if not provided
if not columns:
columns = data.columns.values # type: ignore

self._validate_data(data, columns)
self.data_validator = DataValidator()
self.data_validator.validate(data, columns) # type: ignore

(groupby, nonnormal, min_max, pval_adjust, order) = self._validate_arguments(
groupby, nonnormal, min_max, pval_adjust, order, pval, columns, categorical, continuous)
self.input_validator = InputValidator()
self.input_validator.validate(groupby, nonnormal, min_max, pval_adjust, order, # type: ignore
pval, columns, categorical, continuous) # type: ignore

# if categorical not specified, try to identify categorical
if not categorical and type(categorical) != list:
Expand Down Expand Up @@ -398,117 +393,6 @@ def _handle_deprecations(self, labels, rename, isnull, pval_test_name, remarks):
"by name instead (e.g. diptest = True)",
DeprecationWarning, stacklevel=2)

def _validate_arguments(self, groupby, nonnormal, min_max, pval_adjust, order, pval, columns,
categorical, continuous):
"""
Run validation checks on the arguments.
"""
# Set defaults if None
if categorical is None:
categorical = []
if continuous is None:
continuous = []

# validate 'groupby' argument
if groupby:
if isinstance(groupby, list):
raise ValueError(f"Invalid 'groupby' type: expected a string, received a list. Use '{groupby[0]}' if it's the intended group.")
elif not isinstance(groupby, str):
raise TypeError(f"Invalid 'groupby' type: expected a string, received {type(groupby).__name__}.")
else:
# If 'groupby' is not provided or is explicitly None, treat it as an empty string.
groupby = ''

# Validate 'nonnormal' argument
if nonnormal is None:
nonnormal = []
elif isinstance(nonnormal, str):
nonnormal = [nonnormal]
elif not isinstance(nonnormal, list):
raise TypeError(f"Invalid 'nonnormal' type: expected a list or a string, received {type(nonnormal).__name__}.")
else:
# Ensure all elements in the list are strings
if not all(isinstance(item, str) for item in nonnormal):
raise ValueError("All items in 'nonnormal' list must be strings.")

# Validate 'min_max' argument
if min_max is None:
min_max = []
elif isinstance(min_max, list):
# Optionally, further validate that the list contains only strings (if needed)
if not all(isinstance(item, str) for item in min_max):
raise ValueError("All items in 'min_max' list must be strings representing column names.")
else:
raise TypeError(f"Invalid 'min_max' type: expected a list, received {type(min_max).__name__}.")

# Validate 'pval_adjust' argument
if pval_adjust is not None:
valid_methods = {"bonferroni", "sidak", "holm-sidak", "simes-hochberg", "hommel", None}
if isinstance(pval_adjust, str):
if pval_adjust.lower() not in valid_methods:
raise ValueError(f"Invalid 'pval_adjust' value: '{pval_adjust}'. "
f"Expected one of {', '.join(valid_methods)} or None.")
else:
raise TypeError(f"Invalid type for 'pval_adjust': expected a string or None, "
f"received {type(pval_adjust).__name__}.")

# Validate 'order' argument
if order is not None:
if not isinstance(order, dict):
raise TypeError("The 'order' parameter must be a dictionary where keys are column names and values are lists of ordered categories.")

for key, values in order.items():
if not isinstance(values, list):
raise TypeError(f"The value for '{key}' in 'order' must be a list of categories.")

# Convert all items in the list to strings safely and efficiently
order[key] = [str(v) for v in values]

# Validate 'pval' argument
if pval and not groupby:
raise ValueError("The 'pval' parameter is set to True, but no 'groupby' parameter was specified. "
"Please provide a 'groupby' column name to perform p-value calculations.")

# Validate 'continuous' and 'categorical' arguments
# Check for mutual exclusivity
cat_set = set(categorical)
cont_set = set(continuous)
if cat_set & cont_set:
raise ValueError("Columns cannot be both categorical and continuous: "
f"{cat_set & cont_set}")

# Check that all specified columns exist in the DataFrame
all_specified = cat_set.union(cont_set)
if not all_specified.issubset(set(columns)):
missing = list(all_specified - set(columns))
raise ValueError("Specified categorical/continuous columns not found in the DataFrame: "
f"{missing}")

return groupby, nonnormal, min_max, pval_adjust, order

def _validate_data(self, data, columns):
"""
Run validation checks on the input dataframe.
"""
if data.empty:
raise ValueError("Input data is empty.")

if not data.index.is_unique:
raise InputError("Input data contains duplicate values in the "
"index. Reset the index and try again.")

if not set(columns).issubset(data.columns): # type: ignore
missing_cols = list(set(columns) - set(data.columns)) # type: ignore
raise InputError("""The following columns were not found in the
dataset: {}""".format(missing_cols))

# check for duplicate columns
dups = data[columns].columns[
data[columns].columns.duplicated()].unique()
if not dups.empty:
raise InputError("""Input data contains duplicate
columns: {}""".format(dups))

def __str__(self) -> str:
return self.tableone.to_string() + self._generate_remarks('\n')

Expand Down
156 changes: 156 additions & 0 deletions tableone/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Any, List, Optional, Union, Dict, Set

import pandas as pd


class InputError(Exception):
    """Raised when validation of user-supplied input or data fails."""
    pass


class DataValidator:
    """Runs sanity checks on the input dataframe before a table is built."""

    def validate(self, data: pd.DataFrame, columns: list) -> None:
        """
        Check the input dataset for obvious issues.

        Parameters:
            data (pd.DataFrame): The input dataframe for validation.
            columns (list): A list of columns expected to be in the dataframe.

        Raises:
            InputError: If the data is empty, the index is not unique,
                expected columns are missing, or columns are duplicated.
        """
        self.check_empty_data(data)
        self.check_unique_index(data)
        self.check_columns_exist(data, columns)
        self.check_duplicate_columns(data, columns)

    def check_empty_data(self, data: pd.DataFrame) -> None:
        """Ensure the dataframe is not empty."""
        if data.empty:
            raise InputError("Input data is empty.")

    def check_unique_index(self, data: pd.DataFrame) -> None:
        """Ensure the dataframe's index is unique."""
        if not data.index.is_unique:
            raise InputError("Input data contains duplicate values in the "
                             "index. Reset the index and try again.")

    def check_columns_exist(self, data: pd.DataFrame, columns: list) -> None:
        """Ensure all required columns are present in the dataframe."""
        # Compute the difference once; the previous code built both sets twice.
        missing_cols = set(columns) - set(data.columns)
        if missing_cols:
            # Single-line message: the old triple-quoted string leaked source
            # indentation and a newline into the user-facing error text.
            raise InputError("The following columns were not found in the "
                             "dataset: {}".format(list(missing_cols)))

    def check_duplicate_columns(self, data: pd.DataFrame, columns: list) -> None:
        """Ensure no duplicate columns in the data."""
        # Select the columns once; data[columns] was previously evaluated twice.
        selected = data[columns].columns
        dups = selected[selected.duplicated()].unique()
        if not dups.empty:
            raise InputError("Input data contains duplicate "
                             "columns: {}".format(dups))


class InputValidator:
    """Runs sanity checks on the (non-dataframe) arguments supplied to TableOne."""

    def validate(self,
                 groupby: str,
                 nonnormal: Union[List[str], str],
                 min_max: Union[List[str], str],
                 pval_adjust: str,
                 order: Dict[str, List[Any]],
                 pval: bool,
                 columns: List[str],
                 categorical: List[str],
                 continuous: List[str]) -> None:
        """
        Check the input arguments for obvious issues.

        Parameters:
            groupby (str): Optional column name to group results by.
            nonnormal (list/str): Column name(s) of non-normal variables.
            min_max (list): Column names for which min/max should be shown.
            pval_adjust (str): Multiple-testing correction method, or None.
            order (dict): Maps a column name to its ordered category labels.
                Values are normalized to strings in place (see check_order).
            pval (bool): Whether p-values are to be computed.
            columns (list): All columns to include in the table.
            categorical (list): Columns to treat as categorical.
            continuous (list): Columns to treat as continuous.

        Raises:
            TypeError, ValueError: If any argument is malformed.
        """
        self.check_groupby(groupby, pval)
        # Both lists hold column names, so their items must be strings.
        self.check_list(nonnormal, 'nonnormal', expected_type=str)
        self.check_list(min_max, 'min_max', expected_type=str)
        self.check_pval_adjust(pval_adjust)
        self.check_order(order)
        self.check_exclusivity(categorical, continuous)
        self.check_columns_exist(columns, categorical, continuous)

    def check_groupby(self, groupby: str, pval: bool) -> None:
        """Ensure 'groupby' is provided as a str, and present whenever pval is set."""
        if groupby:
            if isinstance(groupby, list):
                raise ValueError(f"Invalid 'groupby' type: expected a string, received a list. Use '{groupby[0]}' if it's the intended group.")
            elif not isinstance(groupby, str):
                raise TypeError(f"Invalid 'groupby' type: expected a string, received {type(groupby).__name__}.")
        elif pval:
            raise ValueError("The 'pval' parameter is set to True, but no 'groupby' parameter was specified.")

    def check_list(self,
                   parameter: Optional[Union[List[Any], str]],
                   parameter_name: str,
                   expected_type: Optional[type] = None) -> None:
        """Ensure list arguments are properly formatted.

        NOTE: when `parameter` is a bare string, the item check iterates its
        characters (all str), so `expected_type=str` is effectively a no-op
        for strings — the string is accepted as a single column name.
        """
        if parameter:
            if not isinstance(parameter, (list, str)):
                raise TypeError(f"Invalid '{parameter_name}' type: expected a list or a string, received {type(parameter).__name__}.")
            if expected_type and any(not isinstance(item, expected_type) for item in parameter):
                raise ValueError(f"All items in '{parameter_name}' list must be of type {expected_type.__name__}.")

    def check_pval_adjust(self, pval_adjust: str):
        """Ensure 'pval_adjust' is a known correction method (or None)."""
        if pval_adjust is not None:
            # None is excluded from the set: joining a set containing None
            # with str.join raised TypeError, masking the intended ValueError.
            valid_methods = {"bonferroni", "sidak", "holm-sidak",
                             "simes-hochberg", "hommel"}
            if isinstance(pval_adjust, str):
                if pval_adjust.lower() not in valid_methods:
                    raise ValueError(f"Invalid 'pval_adjust' value: '{pval_adjust}'. "
                                     f"Expected one of {', '.join(sorted(valid_methods))} or None.")
            else:
                raise TypeError(f"Invalid type for 'pval_adjust': expected a string or None, "
                                f"received {type(pval_adjust).__name__}.")

    def check_order(self, order: dict):
        """Ensure the order argument is correctly specified.

        Also normalizes each category list to strings IN PLACE, so the
        caller's categories compare equal to the string labels used when
        the table is built (this mirrors the pre-refactor behavior of
        TableOne._validate_arguments).
        """
        if order is not None:
            if not isinstance(order, dict):
                raise TypeError("The 'order' parameter must be a dictionary where keys are column names and values are lists of ordered categories.")

            for key, values in order.items():
                if not isinstance(values, list):
                    raise TypeError(f"The value for '{key}' in 'order' must be a list of categories.")

                # Convert all items in the list to strings safely and efficiently
                order[key] = [str(v) for v in values]

    def check_exclusivity(self, categorical: list, continuous: list):
        """Ensure categorical and continuous are mutually exclusive."""
        # Treat None the same as an empty list.
        overlap = set(categorical or []) & set(continuous or [])
        if overlap:
            raise ValueError("Columns cannot be both categorical and continuous: "
                             f"{overlap}")

    def check_columns_exist(self, columns: list, categorical: list, continuous: list):
        """Ensure all specified columns exist in the DataFrame columns list."""
        cat_set = set(categorical) if categorical else set()
        cont_set = set(continuous) if continuous else set()

        all_specified = cat_set.union(cont_set)
        if not all_specified.issubset(set(columns)):
            missing = list(all_specified - set(columns))
            raise ValueError("Specified categorical/continuous columns not found in the DataFrame: "
                             f"{missing}")
5 changes: 3 additions & 2 deletions tests/unit/test_tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy import stats

from tableone import TableOne, load_dataset
from tableone.tableone import InputError
from tableone.validators import InputError
from tableone.modality import hartigan_diptest, generate_data

seed = 12345
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_examples_used_in_the_readme_run_without_raising_error_pn(
categorical=categorical, groupby=groupby,
nonnormal=nonnormal, pval=False)

def test_robust_to_duplicates_in_input_df_index(self):
def test_duplicate_index_values_raise_error(self):

d_control = pd.DataFrame(data={'group': [0, 0, 0, 0, 0, 0, 0],
'value': [3, 4, 4, 4, 4, 4, 5]})
Expand All @@ -163,6 +163,7 @@ def test_robust_to_duplicates_in_input_df_index(self):
with pytest.raises(InputError):
TableOne(d, ['value'], groupby='group', pval=True)

# Test with reset indices to ensure normal behavior
d_idx_reset = pd.concat([d_case, d_control], ignore_index=True)
t2 = TableOne(d_idx_reset, ['value'], groupby='group', pval=True)

Expand Down
Loading