Skip to content

Commit

Permalink
improved error handling when no target column is passed
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Nov 17, 2024
1 parent 546c791 commit cd1b660
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 26 deletions.
9 changes: 1 addition & 8 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,7 @@ def make_task(
description: str = None,
folktables_obj: BasicProblem = None,
) -> ACSTaskMetadata:
# Validate columns mappings exist
if not all(col in acs_columns_map for col in (features + [target])):
missing_cols = {col for col in (features + [target]) if col not in acs_columns_map}
raise ValueError(
f"Not all columns have mappings to textual descriptions. "
f"Missing columns: {missing_cols}."
)

"""Create an ACS task object from the given parameters."""
# Resolve target column name
target_col_name = (
target_threshold.apply_to_column_name(target)
Expand Down
28 changes: 10 additions & 18 deletions folktexts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def __init__(
f"Expected `TaskMetadata`.")

# Validate data for this task
self._check_task_columns_are_in_df(task, data)
task.check_task_columns_are_available(
available_cols=data.columns.to_list()
)

self._test_size = test_size
self._val_size = val_size or 0
Expand All @@ -83,27 +85,16 @@ def __init__(
if subsampling is not None:
self.subsample(subsampling)

@staticmethod
def _check_task_columns_are_in_df(task: TaskMetadata, df: pd.DataFrame, raise_: bool = True) -> bool:
available_cols = df.columns
required_cols = task.features + ([task.get_target()] if task.target else [])

if raise_ and not all(col in available_cols for col in required_cols):
raise ValueError(
f"The following required task columns were not found in the dataset: "
f"{list(set(required_cols) - set(available_cols))};"
)

return all(col in available_cols for col in required_cols)

@property
def data(self) -> pd.DataFrame:
return self._data

@data.setter
def data(self, new_data) -> pd.DataFrame:
def data(self, new_data: pd.DataFrame) -> pd.DataFrame:
# Check if task columns are in the data
self._check_task_columns_are_in_df(self.task, new_data)
self.task.check_task_columns_are_available(
new_data.columns.to_list()
)

# Update data
self._data = new_data
Expand All @@ -127,8 +118,9 @@ def task(self) -> TaskMetadata:
@task.setter
def task(self, new_task: TaskMetadata):
# Check if task columns are in the data
self._check_task_columns_are_in_df(new_task, self.data)

new_task.check_task_columns_are_available(
self.data.columns.to_list()
)
self._task = new_task

@property
Expand Down
34 changes: 34 additions & 0 deletions folktexts/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def __post_init__(self):
# Add this task to the class-level dictionary
TaskMetadata._tasks[self.name] = self

# Check all required columns are provided by the `cols_to_text` map
self.check_task_columns_are_available(self.cols_to_text.keys())

# Check target is provided
if self.target is None:
logging.warning(
Expand Down Expand Up @@ -92,6 +95,37 @@ def __hash__(self) -> int:
hashable_params["question_hash"] = hash(self.question)
return int(hash_dict(hashable_params), 16)

def check_task_columns_are_available(
self,
available_cols: list[str],
raise_: bool = True,
) -> bool:
"""Checks if all columns required by this task are available.
Parameters
----------
available_cols : list[str]
The list of column names available in the dataset.
raise_ : bool, optional
Whether to raise an error if some columns are missing, by default True.
Returns
-------
all_available : bool
True if all required columns are present in the given list of
available columns, False otherwise.
"""
required_cols = self.features + ([self.get_target()] if self.target else [])
missing_cols = set(required_cols) - set(available_cols)

if raise_ and len(missing_cols) > 0:
raise ValueError(
f"The following required task columns were not found in the dataset: "
f"{list(missing_cols)};"
)

return len(missing_cols) == 0 # Return True if all columns are present

def get_target(self) -> str:
"""Resolves the name of the target column depending on `self.target_threshold`."""
if self.target is None:
Expand Down

0 comments on commit cd1b660

Please sign in to comment.