From cd1b660cdf8fe3913e8771cf98a68943eda9b8f6 Mon Sep 17 00:00:00 2001 From: AndreFCruz Date: Sun, 17 Nov 2024 16:38:06 +0100 Subject: [PATCH] improved error handling when no target column is passed --- folktexts/acs/acs_tasks.py | 9 +-------- folktexts/dataset.py | 28 ++++++++++------------------ folktexts/task.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 26 deletions(-) diff --git a/folktexts/acs/acs_tasks.py b/folktexts/acs/acs_tasks.py index 23064ba..2805810 100755 --- a/folktexts/acs/acs_tasks.py +++ b/folktexts/acs/acs_tasks.py @@ -51,14 +51,7 @@ def make_task( description: str = None, folktables_obj: BasicProblem = None, ) -> ACSTaskMetadata: - # Validate columns mappings exist - if not all(col in acs_columns_map for col in (features + [target])): - missing_cols = {col for col in (features + [target]) if col not in acs_columns_map} - raise ValueError( - f"Not all columns have mappings to textual descriptions. " - f"Missing columns: {missing_cols}." - ) - + """Create an ACS task object from the given parameters.""" # Resolve target column name target_col_name = ( target_threshold.apply_to_column_name(target) diff --git a/folktexts/dataset.py b/folktexts/dataset.py index 5d1710c..081faa1 100755 --- a/folktexts/dataset.py +++ b/folktexts/dataset.py @@ -62,7 +62,9 @@ def __init__( f"Expected `TaskMetadata`.") # Validate data for this task - self._check_task_columns_are_in_df(task, data) + task.check_task_columns_are_available( + available_cols=data.columns.to_list() + ) self._test_size = test_size self._val_size = val_size or 0 @@ -83,27 +85,16 @@ def __init__( if subsampling is not None: self.subsample(subsampling) - @staticmethod - def _check_task_columns_are_in_df(task: TaskMetadata, df: pd.DataFrame, raise_: bool = True) -> bool: - available_cols = df.columns - required_cols = task.features + ([task.get_target()] if task.target else []) - - if raise_ and not all(col in available_cols for col in required_cols): - raise ValueError( - f"The following required task columns were not found in the dataset: " - f"{list(set(required_cols) - set(available_cols))};" - ) - - return all(col in available_cols for col in required_cols) - @property def data(self) -> pd.DataFrame: return self._data @data.setter - def data(self, new_data) -> pd.DataFrame: + def data(self, new_data: pd.DataFrame) -> pd.DataFrame: # Check if task columns are in the data - self._check_task_columns_are_in_df(self.task, new_data) + self.task.check_task_columns_are_available( + new_data.columns.to_list() + ) # Update data self._data = new_data @@ -127,8 +118,9 @@ def task(self) -> TaskMetadata: @task.setter def task(self, new_task: TaskMetadata): # Check if task columns are in the data - self._check_task_columns_are_in_df(new_task, self.data) - + new_task.check_task_columns_are_available( + self.data.columns.to_list() + ) self._task = new_task @property diff --git a/folktexts/task.py b/folktexts/task.py index e0b4098..d38b4f0 100755 --- a/folktexts/task.py +++ b/folktexts/task.py @@ -60,6 +60,9 @@ def __post_init__(self): # Add this task to the class-level dictionary TaskMetadata._tasks[self.name] = self + # Check all required columns are provided by the `cols_to_text` map + self.check_task_columns_are_available(self.cols_to_text.keys()) + # Check target is provided if self.target is None: logging.warning( @@ -92,6 +95,37 @@ def __hash__(self) -> int: hashable_params["question_hash"] = hash(self.question) return int(hash_dict(hashable_params), 16) + def check_task_columns_are_available( + self, + available_cols: list[str], + raise_: bool = True, + ) -> bool: + """Checks if all columns required by this task are available. + + Parameters + ---------- + available_cols : list[str] + The list of column names available in the dataset. + raise_ : bool, optional + Whether to raise an error if some columns are missing, by default True. + + Returns + ------- + all_available : bool + True if all required columns are present in the given list of + available columns, False otherwise. + """ + required_cols = self.features + ([self.get_target()] if self.target else []) + missing_cols = set(required_cols) - set(available_cols) + + if raise_ and len(missing_cols) > 0: + raise ValueError( + f"The following required task columns were not found in the dataset: " + f"{list(missing_cols)};" + ) + + return len(missing_cols) == 0 # Return True if all columns are present + def get_target(self) -> str: """Resolves the name of the target column depending on `self.target_threshold`.""" if self.target is None: