Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add threshold for categorical tasks #27

Merged
merged 7 commits on Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions gene_benchmark/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
model_name=None,
sub_task=None,
multi_label_th=0,
cat_label_th=0,
overlap_entities=False,
) -> None:
"""
Expand Down Expand Up @@ -214,6 +215,7 @@ def __init__(
frac=frac,
sub_task=sub_task,
multi_label_th=multi_label_th,
cat_label_th=cat_label_th,
)
else:
self.task_definitions = task
Expand Down Expand Up @@ -398,6 +400,7 @@ def load_task_definition(
frac=1,
sub_task=None,
multi_label_th=0,
cat_label_th=0,
):
"""
Loads and returns the task definition object.
Expand All @@ -413,6 +416,8 @@ def load_task_definition(
frac (float): load a unique fraction of the rows in the task, default 1
tasks_folder(str|None): Use an alternative task repository (default repository if None)
sub_task(str|None): Use only one of the columns of the outcome as a binary task
multi_label_th (float): Threshold for multi-label task outcomes
cat_label_th (float): Threshold for categorical task outcomes
edenjenzohar marked this conversation as resolved.
Show resolved Hide resolved


Returns:
Expand Down Expand Up @@ -442,6 +447,12 @@ def load_task_definition(
if multi_label_th != 0:
outcomes = filter_low_threshold_features(outcomes, threshold=multi_label_th)

if cat_label_th != 0:
percent_outcome = outcomes.value_counts(normalize=True)
high_th_outcomes = percent_outcome[percent_outcome > cat_label_th].index
outcomes = outcomes[outcomes.apply(lambda x: x in high_th_outcomes)]
entities = entities.loc[outcomes.index]

return TaskDefinition(
name=task_name,
entities=entities,
Expand Down
13 changes: 11 additions & 2 deletions scripts/run_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def expand_task_list(task_list):
)
@click.option(
"--include-symbols-file",
"-e",
"-i",
type=click.STRING,
help="A path to a yaml file containing symbols to be excluded",
help="A path to a yaml file containing symbols to be included",
default=None,
)
@click.option(
Expand Down Expand Up @@ -137,6 +137,13 @@ def expand_task_list(task_list):
help="threshold of imbalance of labels in multi class tasks",
default=0.0,
)
@click.option(
"--cat-label-th",
"-cth",
type=click.FLOAT,
help="threshold of imbalance of labels in category tasks",
default=0.0,
)
def main(
tasks_folder,
task_names,
Expand All @@ -149,6 +156,7 @@ def main(
sub_sample,
scoring_type,
multi_label_th,
cat_label_th,
):
if tasks_folder is None:
tasks_folder = Path(os.environ["GENE_BENCHMARK_TASKS_FOLDER"])
Expand Down Expand Up @@ -251,6 +259,7 @@ def main(
encoding_post_processing=post_processing,
sub_task=sub_task,
multi_label_th=multi_label_th,
cat_label_th=cat_label_th,
)
_ = task.run()
report_df = get_report(task, output_file_name, append_results)
Expand Down
Loading