diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d32d5da..bccfe11 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -35,7 +35,7 @@ jobs: python scripts/tasks_retrieval/Genecorpus_tasks_creation.py --allow-downloads True python scripts/tasks_retrieval/HLA_task_creation.py --allow-downloads True python scripts/tasks_retrieval/HPA_tasks_creation.py --allow-downloads True - python scripts/tasks_retrieval/humantfs_task_creation.py --allow-downloads True + # python scripts/tasks_retrieval/humantfs_task_creation.py --allow-downloads True python scripts/tasks_retrieval/Reactome_tasks_creation.py --allow-downloads True - name: Test with pytest diff --git a/gene_benchmark/tasks.py b/gene_benchmark/tasks.py index 89f4784..aad23d3 100644 --- a/gene_benchmark/tasks.py +++ b/gene_benchmark/tasks.py @@ -183,6 +183,7 @@ def __init__( model_name=None, sub_task=None, multi_label_th=0, + cat_label_th=0, overlap_entities=False, ) -> None: """ @@ -214,6 +215,7 @@ def __init__( frac=frac, sub_task=sub_task, multi_label_th=multi_label_th, + cat_label_th=cat_label_th, ) else: self.task_definitions = task @@ -398,6 +400,7 @@ def load_task_definition( frac=1, sub_task=None, multi_label_th=0, + cat_label_th=0, ): """ Loads and returns the task definition object. @@ -413,6 +416,8 @@ def load_task_definition( frac (float): load a unique fraction of the rows in the task, default 1 tasks_folder(str|None): Use an alternative task repository (default repository if None) sub_task(str|None): Use only one of the columns of the outcome as a binary task + multi_label_th (float): Threshold for multi label tasks outcomes + cat_label_th (float): Threshold for categorical tasks outcomes, only categories that have rates above the threshold will be included. Returns: @@ -442,6 +447,12 @@ def load_task_definition( if multi_label_th != 0: outcomes = filter_low_threshold_features(outcomes, threshold=multi_label_th) + if cat_label_th != 0: + percent_outcome = outcomes.value_counts(normalize=True) + high_th_outcomes = percent_outcome[percent_outcome > cat_label_th].index + outcomes = outcomes[outcomes.apply(lambda x: x in high_th_outcomes)] + entities = entities.loc[outcomes.index] + return TaskDefinition( name=task_name, entities=entities, diff --git a/gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv b/gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv new file mode 100644 index 0000000..5b764ac --- /dev/null +++ b/gene_benchmark/tests/resources/tasks/imbalanced_cat/entities.csv @@ -0,0 +1,111 @@ +symbol +HOXA11 +PAX6 +ZIC2 +COASY +PAX2 +HOXA9 +HOXA1 +HOXA2 +HOXA3 +HOXA5 +HOXA6 +HOXA13 +EVX1 +TLX1 +KAZALD1 +FGF8 +PITX3 +GBF1 +HOXB6 +HSD17B1 +CNTNAP1 +SOX6 +IRX4 +DLX2 +PROX1 +ELOVL3 +HOXB8 +HOXB5 +HOXB3 +HOXB1 +HOXA7 +SOX21 +FOXA2 +PAX1 +NKX2-4 +NKX2-2 +FOXP2 +HOXD1 +HOXD3 +HOXD9 +HOXD10 +HOXD11 +HOXD13 +FOXA1 +NFATC1 +NKX2-8 +LMX1B +SIX3 +ZIC5 +LMO4 +ACTA1 +DLX1 +HTR7 +NKX6-2 +POU4F2 +POU4F1 +ZIC1 +HOXB13 +IRX6 +EIF4E3 +PROK2 +NKX6-1 +EBF1 +TLX3 +SHH +EN2 +OTX2 +LMO1 +GBX2 +SOX14 +ZFPM2 +HOXD4 +HOXD12 +IRX1 +IRX2 +SIX2 +HOXB9 +HOXB2 +EVX2 +ZIC4 +HOXD8 +IRX5 +IRX3 +MAF +SATB1 +HOXB4 +NR2F2 +FOXD3 +PAX5 +HOXA4 +PAX9 +HOXA10 +SALL3 +HOXB7 +DACH1 +BRCA1 +ATP6V0A1 +TUBG2 +MRPL43 +DHX8 +ST7 +MEOX1 +CIAPIN1 +CREBBP +MPO +SOX8 +ABCC8 +TMEM132A +ZNF263 +KDM7A diff --git a/gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv b/gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv new file mode 100644 index 0000000..aa0e9fd --- /dev/null +++ b/gene_benchmark/tests/resources/tasks/imbalanced_cat/outcomes.csv @@ -0,0 +1,111 @@ +Outcomes +class_4 +class_2 +class_2 +class_2 +class_2 +class_0 +class_1 +class_2 +class_3 +class_0 +class_0 +class_2 +class_3 +class_0 +class_2 +class_3 +class_1 +class_0 +class_3 +class_1 +class_0 +class_0 +class_2 +class_2 +class_1 +class_1 +class_1 +class_3 +class_2 +class_1 +class_3 +class_1 +class_1 +class_2 +class_0 +class_2 +class_0 +class_1 +class_0 +class_2 +class_2 +class_1 +class_3 +class_5 +class_5 +class_1 +class_3 +class_3 +class_3 +class_0 +class_2 +class_2 +class_1 +class_0 +class_2 +class_2 +class_0 +class_1 +class_2 +class_3 +class_5 +class_2 +class_2 +class_2 +class_2 +class_0 +class_1 +class_1 +class_1 +class_3 +class_5 +class_3 +class_1 +class_2 +class_0 +class_3 +class_3 +class_2 +class_2 +class_3 +class_0 +class_0 +class_0 +class_1 +class_2 +class_2 +class_2 +class_3 +class_0 +class_2 +class_3 +class_2 +class_2 +class_0 +class_2 +class_3 +class_1 +class_2 +class_2 +class_2 +class_3 +class_2 +class_2 +class_0 +class_2 +class_3 +class_1 +class_2 +class_2 +class_2 diff --git a/gene_benchmark/tests/test_tasks.py b/gene_benchmark/tests/test_tasks.py index d349727..7cc68da 100644 --- a/gene_benchmark/tests/test_tasks.py +++ b/gene_benchmark/tests/test_tasks.py @@ -412,7 +412,7 @@ def test_list_subtests(self): def test_get_task_names(self): tasks_folder = _get_tasks_folder() names = list(get_tasks_definition_names(tasks_folder)) - assert len(names) >= 70 + assert len(names) >= 65 assert "RNA cancer distribution" in names assert "bivalent vs non-methylated" in names @@ -553,3 +553,22 @@ def test_entities_task_inclusion(self): include_symbols=["ATP6V0A1", "TUBG2", "MRPL43", "DHX8"], ) assert full_entity_task.task_definitions.entities.shape[0] == 4 + + def test_multiclass_task_with_th(self): + task_name = "imbalanced_cat" + mpnet_name = "sentence-transformers/all-mpnet-base-v2" + full_entity_task = EntitiesTask( + task=task_name, + encoder=mpnet_name, + description_builder=NCBIDescriptor(), + base_model=LogisticRegression(max_iter=5000), + cv=5, + scoring=["roc_auc_ovr_weighted"], + tasks_folder=_get_test_tasks_folder(), + cat_label_th=0.04, + ) + full_entity_task.run() + this_run_df = full_entity_task.summary() + test_scores = this_run_df["test_roc_auc_ovr_weighted"].split(",") + test_scores = list(map(float, test_scores)) + assert np.nan not in test_scores diff --git a/scripts/run_task.py b/scripts/run_task.py index 631fb32..18fbd74 100644 --- a/scripts/run_task.py +++ b/scripts/run_task.py @@ -73,7 +73,7 @@ def expand_task_list(task_list): "--task-names", "-t", type=click.STRING, - help="The output file name.", + help="The path to the task yamls, or the task name", default=["long vs short range TF"], multiple=True, ) @@ -81,7 +81,7 @@ def expand_task_list(task_list): "--model-config-files", "-m", type=click.STRING, - help="Append results to the files", + help="path to model config files", default=[str(Path(__file__).parent / "models" / "ncbi_multi_class.yaml")], multiple=True, ) @@ -94,9 +94,9 @@ def expand_task_list(task_list): ) @click.option( "--include-symbols-file", - "-e", + "-i", type=click.STRING, - help="A path to a yaml file containing symbols to be excluded", + help="A path to a yaml file containing symbols to be included", default=None, ) @click.option( @@ -137,6 +137,13 @@ def expand_task_list(task_list): help="threshold of imbalance of labels in multi class tasks", default=0.0, ) +@click.option( + "--cat-label-th", + "-cth", + type=click.FLOAT, + help="threshold of imbalance of labels in category tasks", + default=0.0, +) def main( tasks_folder, task_names, @@ -149,6 +156,7 @@ def main( sub_sample, scoring_type, multi_label_th, + cat_label_th, ): if tasks_folder is None: tasks_folder = Path(os.environ["GENE_BENCHMARK_TASKS_FOLDER"]) @@ -251,6 +259,7 @@ def main( encoding_post_processing=post_processing, sub_task=sub_task, multi_label_th=multi_label_th, + cat_label_th=cat_label_th, ) _ = task.run() report_df = get_report(task, output_file_name, append_results)