Skip to content

Commit

Permalink
Making target optional for use of folktexts for encoding data only (no predictions)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Nov 17, 2024
1 parent 3e33fce commit 546c791
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion folktexts/acs/acs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _parse_task_data(cls, full_df: pd.DataFrame, task: ACSTaskMetadata) -> pd.Da
parsed_df = full_df

# Threshold the target column if necessary
if task.target_threshold is not None and task.get_target() not in parsed_df.columns:
if task.target is not None and task.target_threshold is not None and task.get_target() not in parsed_df.columns:
parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target])

return parsed_df
8 changes: 6 additions & 2 deletions folktexts/acs/acs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_task(
cls,
name: str,
features: list[str],
target: str,
target: str = None,
sensitive_attribute: str = None,
target_threshold: Threshold = None,
multiple_choice_qa: MultipleChoiceQA = None,
Expand All @@ -53,7 +53,11 @@ def make_task(
) -> ACSTaskMetadata:
# Validate columns mappings exist
if not all(col in acs_columns_map for col in (features + [target])):
raise ValueError("Not all columns have mappings to text descriptions.")
missing_cols = {col for col in (features + [target]) if col not in acs_columns_map}
raise ValueError(
f"Not all columns have mappings to textual descriptions. "
f"Missing columns: {missing_cols}."
)

# Resolve target column name
target_col_name = (
Expand Down
30 changes: 17 additions & 13 deletions folktexts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,8 @@ def __init__(
f"Invalid `task` type: {type(self._task)}. "
f"Expected `TaskMetadata`.")

if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
raise ValueError(
f"The following task columns were not found in the dataset: "
f"{list(set(task.features + [task.get_target()]) - set(self.data.columns))};"
)
# Validate data for this task
self._check_task_columns_are_in_df(task, data)

self._test_size = test_size
self._val_size = val_size or 0
Expand All @@ -86,17 +83,27 @@ def __init__(
if subsampling is not None:
self.subsample(subsampling)

@staticmethod
def _check_task_columns_are_in_df(task: "TaskMetadata", df: pd.DataFrame, raise_: bool = True) -> bool:
    """Check that every column required by `task` is present in `df`.

    The required columns are the task's feature columns plus, when the
    task defines a target, the column name returned by `task.get_target()`.

    Parameters
    ----------
    task : TaskMetadata
        Task whose `features`, `target` and `get_target()` define the
        required columns.
    df : pd.DataFrame
        Data whose `columns` are checked.
    raise_ : bool, optional
        Whether to raise when a required column is missing (default), or
        to only report the outcome through the return value.

    Returns
    -------
    bool
        True if all required columns are available, False otherwise
        (False is only reachable when `raise_` is falsy).

    Raises
    ------
    ValueError
        If `raise_` is truthy and at least one required column is missing.
    """
    required_cols = list(task.features)

    # The target is optional; only resolve its column name when one is set,
    # since `get_target()` is not meaningful for target-less tasks.
    if task.target is not None:
        required_cols.append(task.get_target())

    available_cols = df.columns

    # Single pass over the required columns; `dict.fromkeys` drops duplicates
    # while keeping the order deterministic for the error message.
    missing_cols = list(dict.fromkeys(
        col for col in required_cols if col not in available_cols
    ))

    if raise_ and missing_cols:
        raise ValueError(
            f"The following required task columns were not found in the dataset: "
            f"{missing_cols};"
        )

    return not missing_cols

@property
def data(self) -> pd.DataFrame:
# Read accessor: returns the DataFrame currently stored in `self._data`.
return self._data

@data.setter
def data(self, new_data) -> pd.DataFrame:
# Check if task columns are in the data
if not all(col in new_data.columns for col in (self.task.features + [self.task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={self.task.features}, target={self.task.get_target()}")
self._check_task_columns_are_in_df(self.task, new_data)

# Update data
self._data = new_data
Expand All @@ -120,10 +127,7 @@ def task(self) -> TaskMetadata:
@task.setter
def task(self, new_task: TaskMetadata):
# Check if task columns are in the data
if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
raise ValueError(
f"Task columns not found in dataset: "
f"features={new_task.features}, target={new_task.get_target()}")
self._check_task_columns_are_in_df(new_task, self.data)

self._task = new_task

Expand Down
17 changes: 15 additions & 2 deletions folktexts/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ def __post_init__(self):
# Add this task to the class-level dictionary
TaskMetadata._tasks[self.name] = self

# Check target is provided
if self.target is None:
logging.warning(
f"No target column provided for task '{self.name}'. "
f"Will not be able to generate predictions or use task Q&A prompts. "
f"Will still be able to generate row descriptions."
)
return

# If no question is explicitly provided, use the question from the target column
if self.multiple_choice_qa is None and self.direct_numeric_qa is None:
if self.multiple_choice_qa is None and self.direct_numeric_qa is None and self.target is not None:
logging.warning(
f"No question was explicitly provided for task '{self.name}'. "
f"Inferring from target column's default question ({self.get_target()}).")
Expand All @@ -85,6 +94,10 @@ def __hash__(self) -> int:

def get_target(self) -> str:
"""Resolves the name of the target column depending on `self.target_threshold`."""
if self.target is None:
logging.critical(f"No target column provided for task {self.name}.")
return None

if self.target_threshold is None:
return self.target
else:
Expand Down Expand Up @@ -160,7 +173,7 @@ def question(self) -> QAInterface:
q = self.multiple_choice_qa

if q is None:
raise ValueError(f"Invalid Q&A interface configured for task {self.name}.")
logging.critical(f"No Q&A interface provided for task {self.name}.")
return q

def get_row_description(self, row: pd.Series) -> str:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]

version = "0.0.23"
version = "0.0.24"
requires-python = ">=3.8"
dynamic = [
"readme",
Expand Down

0 comments on commit 546c791

Please sign in to comment.