From f3d884718a003c51b055dbba3f714905bede1351 Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Tue, 9 Jul 2024 05:44:50 -0400 Subject: [PATCH 1/7] Add threshold for categorical tasks --- gene_benchmark/tasks.py | 11 +++++++++++ scripts/run_task.py | 13 +++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/gene_benchmark/tasks.py b/gene_benchmark/tasks.py index 89f4784..3341611 100644 --- a/gene_benchmark/tasks.py +++ b/gene_benchmark/tasks.py @@ -183,6 +183,7 @@ def __init__( model_name=None, sub_task=None, multi_label_th=0, + cat_label_th=0, overlap_entities=False, ) -> None: """ @@ -214,6 +215,7 @@ def __init__( frac=frac, sub_task=sub_task, multi_label_th=multi_label_th, + cat_label_th=cat_label_th, ) else: self.task_definitions = task @@ -398,6 +400,7 @@ def load_task_definition( frac=1, sub_task=None, multi_label_th=0, + cat_label_th=0, ): """ Loads and returns the task definition object. @@ -413,6 +416,8 @@ def load_task_definition( frac (float): load a unique fraction of the rows in the task, default 1 tasks_folder(str|None): Use an alternative task repository (default repository if None) sub_task(str|None): Use only one of the columns of the outcome as a binary task + multi_label_th (float): Threshold for multi label tasks outcomes + cat_label_th (float): Threshold for categorical tasks outcomes Returns: @@ -442,6 +447,12 @@ def load_task_definition( if multi_label_th != 0: outcomes = filter_low_threshold_features(outcomes, threshold=multi_label_th) + if cat_label_th != 0: + percent_outcome = outcomes.value_counts(normalize=True) + high_th_outcomes = percent_outcome[percent_outcome > cat_label_th].index + outcomes = outcomes[outcomes.apply(lambda x: x in high_th_outcomes)] + entities = entities.loc[outcomes.index] + return TaskDefinition( name=task_name, entities=entities, diff --git a/scripts/run_task.py b/scripts/run_task.py index 631fb32..b828443 100644 --- a/scripts/run_task.py +++ b/scripts/run_task.py @@ -94,9 +94,9 @@ def expand_task_list(task_list): ) @click.option( "--include-symbols-file", - "-e", + "-i", type=click.STRING, - help="A path to a yaml file containing symbols to be excluded", + help="A path to a yaml file containing symbols to be included", default=None, ) @click.option( @@ -137,6 +137,13 @@ def expand_task_list(task_list): help="threshold of imbalance of labels in multi class tasks", default=0.0, ) +@click.option( + "--cat-label-th", + "-cth", + type=click.FLOAT, + help="threshold of imbalance of labels in category tasks", + default=0.0, +) def main( tasks_folder, task_names, @@ -149,6 +156,7 @@ def main( sub_sample, scoring_type, multi_label_th, + cat_label_th, ): if tasks_folder is None: tasks_folder = Path(os.environ["GENE_BENCHMARK_TASKS_FOLDER"]) @@ -251,6 +259,7 @@ def main( encoding_post_processing=post_processing, sub_task=sub_task, multi_label_th=multi_label_th, + cat_label_th=cat_label_th, ) _ = task.run() report_df = get_report(task, output_file_name, append_results) From ec202a1e1453badb062c27577cb83f915bf50e0f Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Tue, 16 Jul 2024 11:00:44 -0400 Subject: [PATCH 2/7] Add documentation --- gene_benchmark/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gene_benchmark/tasks.py b/gene_benchmark/tasks.py index 3341611..aad23d3 100644 --- a/gene_benchmark/tasks.py +++ b/gene_benchmark/tasks.py @@ -417,7 +417,7 @@ def load_task_definition( tasks_folder(str|None): Use an alternative task repository (default repository if None) sub_task(str|None): Use only one of the columns of the outcome as a binary task multi_label_th (float): Threshold for multi label tasks outcomes - cat_label_th (float): Threshold for categorical tasks outcomes + cat_label_th (float): Threshold for categorical tasks outcomes, only categories that have rates above the threshold will be included. Returns: From 4199d157ad4fa94aaeca214bc5e8358e148e4180 Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Wed, 17 Jul 2024 04:54:28 -0400 Subject: [PATCH 3/7] Fix click documentation --- scripts/run_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/run_task.py b/scripts/run_task.py index b828443..18fbd74 100644 --- a/scripts/run_task.py +++ b/scripts/run_task.py @@ -73,7 +73,7 @@ def expand_task_list(task_list): "--task-names", "-t", type=click.STRING, - help="The output file name.", + help="The path to the task yamls, or the task name", default=["long vs short range TF"], multiple=True, ) @@ -81,7 +81,7 @@ def expand_task_list(task_list): "--model-config-files", "-m", type=click.STRING, - help="Append results to the files", + help="path to model config files", default=[str(Path(__file__).parent / "models" / "ncbi_multi_class.yaml")], multiple=True, ) From 24ffc1ddff163f8c89310af5f48ba2105f8fcd9a Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Wed, 17 Jul 2024 06:10:02 -0400 Subject: [PATCH 4/7] Add unittest for categorical threshold --- .../tasks/imbalanced_cat/entities.csv | 111 ++++++++++++++++++ .../tasks/imbalanced_cat/outcomes.csv | 111 ++++++++++++++++++ gene_benchmark/tests/test_tasks.py | 19 +++ 3 files changed, 241 insertions(+) create mode 100644 gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv create mode 100644 gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv diff --git a/gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv b/gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv new file mode 100644 index 0000000..5b764ac --- /dev/null +++ b/gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv @@ -0,0 +1,111 @@ +symbol +HOXA11 +PAX6 +ZIC2 +COASY +PAX2 +HOXA9 +HOXA1 +HOXA2 +HOXA3 +HOXA5 +HOXA6 +HOXA13 +EVX1 +TLX1 +KAZALD1 +FGF8 +PITX3 +GBF1 +HOXB6 +HSD17B1 +CNTNAP1 +SOX6 +IRX4 +DLX2 +PROX1 +ELOVL3 +HOXB8 +HOXB5 +HOXB3 +HOXB1 +HOXA7 +SOX21 +FOXA2 +PAX1 +NKX2-4 +NKX2-2 +FOXP2 +HOXD1 +HOXD3 +HOXD9 +HOXD10 +HOXD11 +HOXD13 +FOXA1 +NFATC1 +NKX2-8 +LMX1B +SIX3 +ZIC5 +LMO4 +ACTA1 +DLX1 +HTR7 +NKX6-2 +POU4F2 +POU4F1 +ZIC1 +HOXB13 +IRX6 +EIF4E3 +PROK2 +NKX6-1 +EBF1 +TLX3 +SHH +EN2 +OTX2 +LMO1 +GBX2 +SOX14 +ZFPM2 +HOXD4 +HOXD12 +IRX1 +IRX2 +SIX2 +HOXB9 +HOXB2 +EVX2 +ZIC4 +HOXD8 +IRX5 +IRX3 +MAF +SATB1 +HOXB4 +NR2F2 +FOXD3 +PAX5 +HOXA4 +PAX9 +HOXA10 +SALL3 +HOXB7 +DACH1 +BRCA1 +ATP6V0A1 +TUBG2 +MRPL43 +DHX8 +ST7 +MEOX1 +CIAPIN1 +CREBBP +MPO +SOX8 +ABCC8 +TMEM132A +ZNF263 +KDM7A diff --git a/gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv b/gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv new file mode 100644 index 0000000..aa0e9fd --- /dev/null +++ b/gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv @@ -0,0 +1,111 @@ +Outcomes +class_4 +class_2 +class_2 +class_2 +class_2 +class_0 +class_1 +class_2 +class_3 +class_0 +class_0 +class_2 +class_3 +class_0 +class_2 +class_3 +class_1 +class_0 +class_3 +class_1 +class_0 +class_0 +class_2 +class_2 +class_1 +class_1 +class_1 +class_3 +class_2 +class_1 +class_3 +class_1 +class_1 +class_2 +class_0 +class_2 +class_0 +class_1 +class_0 +class_2 +class_2 +class_1 +class_3 +class_5 +class_5 +class_1 +class_3 +class_3 +class_3 +class_0 +class_2 +class_2 +class_1 +class_0 +class_2 +class_2 +class_0 +class_1 +class_2 +class_3 +class_5 +class_2 +class_2 +class_2 +class_2 +class_0 +class_1 +class_1 +class_1 +class_3 +class_5 +class_3 +class_1 +class_2 +class_0 +class_3 +class_3 +class_2 +class_2 +class_3 +class_0 +class_0 +class_0 +class_1 +class_2 +class_2 +class_2 +class_3 +class_0 +class_2 +class_3 +class_2 +class_2 +class_0 +class_2 +class_3 +class_1 +class_2 +class_2 +class_2 +class_3 +class_2 +class_2 +class_0 +class_2 +class_3 +class_1 +class_2 +class_2 +class_2 diff --git a/gene_benchmark/tests/test_tasks.py b/gene_benchmark/tests/test_tasks.py index d349727..91412f8 100644 --- a/gene_benchmark/tests/test_tasks.py +++ b/gene_benchmark/tests/test_tasks.py @@ -553,3 +553,22 @@ def test_entities_task_inclusion(self): include_symbols=["ATP6V0A1", "TUBG2", "MRPL43", "DHX8"], ) assert full_entity_task.task_definitions.entities.shape[0] == 4 + + def test_multiclass_task_with_th(self): + task_name = "imbalanced_cat" + mpnet_name = "sentence-transformers/all-mpnet-base-v2" + full_entity_task = EntitiesTask( + task=task_name, + encoder=mpnet_name, + description_builder=NCBIDescriptor(), + base_model=LogisticRegression(max_iter=5000), + cv=5, + scoring=["roc_auc_ovr_weighted"], + tasks_folder=_get_test_tasks_folder(), + cat_label_th=0.04, + ) + full_entity_task.run() + this_run_df = pd.DataFrame.from_dict(full_entity_task.summary(), orient="index") + test_scores = this_run_df["test_roc_auc_ovr_weighted"].split(",") + test_scores = list(map(float, test_scores)) + assert np.nan not in test_scores From b723508c0166d16ddf15c4d701faacb3a6db04d8 Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Wed, 17 Jul 2024 06:17:16 -0400 Subject: [PATCH 5/7] comment out humantf task creation because of site maintenance --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d32d5da..bccfe11 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -35,7 +35,7 @@ jobs: python scripts/tasks_retrieval/Genecorpus_tasks_creation.py --allow-downloads True python scripts/tasks_retrieval/HLA_task_creation.py --allow-downloads True python scripts/tasks_retrieval/HPA_tasks_creation.py --allow-downloads True - python scripts/tasks_retrieval/humantfs_task_creation.py --allow-downloads True + # python scripts/tasks_retrieval/humantfs_task_creation.py --allow-downloads True python scripts/tasks_retrieval/Reactome_tasks_creation.py --allow-downloads True - name: Test with pytest From 46432e20094267535b4f71ff36fd686aac92010c Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Wed, 17 Jul 2024 06:30:49 -0400 Subject: [PATCH 6/7] Fix unittest bug --- gene_benchmark/tests/test_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gene_benchmark/tests/test_tasks.py b/gene_benchmark/tests/test_tasks.py index 91412f8..5a11204 100644 --- a/gene_benchmark/tests/test_tasks.py +++ b/gene_benchmark/tests/test_tasks.py @@ -568,7 +568,7 @@ def test_multiclass_task_with_th(self): cat_label_th=0.04, ) full_entity_task.run() - this_run_df = pd.DataFrame.from_dict(full_entity_task.summary(), orient="index") + this_run_df = full_entity_task.summary() test_scores = this_run_df["test_roc_auc_ovr_weighted"].split(",") test_scores = list(map(float, test_scores)) assert np.nan not in test_scores From fe06b81194884d2f94f3a40df27d01591231a84f Mon Sep 17 00:00:00 2001 From: Eden-Zohar Date: Wed, 17 Jul 2024 06:41:33 -0400 Subject: [PATCH 7/7] change number of total tasks in pytest --- gene_benchmark/tests/test_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gene_benchmark/tests/test_tasks.py b/gene_benchmark/tests/test_tasks.py index 5a11204..7cc68da 100644 --- a/gene_benchmark/tests/test_tasks.py +++ b/gene_benchmark/tests/test_tasks.py @@ -412,7 +412,7 @@ def test_list_subtests(self): def test_get_task_names(self): tasks_folder = _get_tasks_folder() names = list(get_tasks_definition_names(tasks_folder)) - assert len(names) >= 70 + assert len(names) >= 65 assert "RNA cancer distribution" in names assert "bivalent vs non-methylated" in names