Skip to content

Commit

Permalink
Merge pull request #3 from BiomedSciAI/remove_one_hot
Browse files Browse the repository at this point in the history
Removed the conversion of binary string outcomes to integer-based labels
  • Loading branch information
yoavkt authored Jun 18, 2024
2 parents d53a9d0 + 7c97f09 commit ab03942
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
python scripts/tasks_retrival/HLA_task_creation.py --allow-downloads True
python scripts/tasks_retrival/HPA_tasks_creation.py --allow-downloads True
python scripts/tasks_retrival/humantfs_task_creation.py --allow-downloads True
python scripts/tasks_retrival/Reactome_tasks_creation.py --use-local-files False
python scripts/tasks_retrival/Reactome_tasks_creation.py --allow-downloads True
- name: Test with pytest
run: |
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.vscode/settings.json
.DS_Store
32 changes: 7 additions & 25 deletions gene_benchmark/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,6 @@
)


def is_binary_outcomes(outcomes: pd.Series | pd.DataFrame):
"""
Checks if a vector represents a binary prediction task.
Args:
----
outcomes (pd.series): a series containing the labels for prediction
Returns:
-------
bool: True if the series represents binary classification
"""
if isinstance(outcomes, pd.Series):
return outcomes.nunique() == 2
else:
return False


def convert_to_mat(data: pd.Series | pd.DataFrame):
"""
Convert a 1d series or df with np arrays as values to a 2D/3D np array.
Expand Down Expand Up @@ -238,11 +219,7 @@ def run(self, error_score=np.nan):
descriptions_df = self._create_encoding()
encodings_df = self.encoder.encode(descriptions_df)
encodings = self._post_processing_mat(encodings_df)

if is_binary_outcomes(self.task_definitions.outcomes):
outcomes = pd.get_dummies(self.task_definitions.outcomes).iloc[:, 0]
else:
outcomes = self.task_definitions.outcomes
outcomes = self.task_definitions.outcomes

cs_val = cross_validate(
self.base_prediction_model,
Expand Down Expand Up @@ -272,7 +249,12 @@ def summary(self):
summary_dict["sub_task"] = self.task_definitions.sub_task
summary_dict["base_prediction_model"] = str(self.base_prediction_model)
summary_dict["sample_size"] = self.task_definitions.outcomes.shape[0]
if is_binary_outcomes(self.task_definitions.outcomes):
is_bin = (
self.task_definitions.outcomes.nunique() == 2
if isinstance(self.task_definitions.outcomes, pd.Series)
else False
)
if is_bin:
summary_dict["class_sizes"] = ",".join(
[str(v) for v in self.task_definitions.outcomes.value_counts().values]
)
Expand Down
11 changes: 0 additions & 11 deletions gene_benchmark/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
convert_to_mat,
filter_exclusion,
get_tasks_definition_names,
is_binary_outcomes,
load_task_definition,
sub_sample_task_frames,
)
Expand Down Expand Up @@ -352,16 +351,6 @@ def test_task_modification_3D_to_2D_concat(self):
)
assert full_entity_task._post_processing_mat(threeDmat).shape == (3, 6)

def test_is_binary_outcomes(self):
    """Check is_binary_outcomes on series with exactly two vs. more distinct values."""
    # Label lists with exactly two distinct values: expected to be binary.
    two_valued = [
        ["Blue", "Blue", "Blue", "Green", "Green"],
        [0, 0, 1, 1, 0, 1],
        [True, False, True, False],
    ]
    # Label lists with three distinct values: expected not to be binary.
    more_valued = [
        ["Blue", "Blue", "Blue", "Green", "Green", "yellow"],
        [0, 0, 1, 1, 0, 1, 2, 2],
        [True, False, True, False, 5],
    ]
    for labels in two_valued:
        assert is_binary_outcomes(pd.Series(labels))
    for labels in more_valued:
        assert not is_binary_outcomes(pd.Series(labels))

def test_load_multilabel(self):
task_name = "Protein class"
mpnet_name = "sentence-transformers/all-mpnet-base-v2"
Expand Down

0 comments on commit ab03942

Please sign in to comment.