Skip to content

Commit

Permalink
Merge pull request #30 from BiomedSciAI/pair_base_des
Browse files Browse the repository at this point in the history
Pair base description
  • Loading branch information
yoavkt authored Aug 6, 2024
2 parents 8db69ee + 52309a1 commit 2e7ffc9
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
63 changes: 63 additions & 0 deletions gene_benchmark/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@

import mygene
import pandas as pd
import requests


def _fetch_ensembl_sequence(ensembl_gene_id):
"""
retries the base pair sequence of a given ensemble id REST ensenmbl API
available at https://rest.ensembl.org/.
Args:
----
ensembl_gene_id (str): an ensemble id
Returns:
-------
str: base pair sequence of the gene
"""
if not ensembl_gene_id:
return None
url = f"https://rest.ensembl.org/sequence/id/{ensembl_gene_id}?content-type=text/plain"
response = requests.get(url)
if response.status_code == 200:
return response.text
else:
return None


def missing_col_or_nan(df_series, indx):
Expand Down Expand Up @@ -650,3 +675,41 @@ def has_missing_columns(df_row, column_names):
"""
return any(missing_col_or_nan(df_row, col) for col in column_names)


class BasePairDescriptor(NCBIDescriptor):
"""A descriptor designated to describe each symbol by its base pair sequence."""

def __init__(self, allow_partial=False, description_col: str = "bp") -> None:
"""
Initialize descriptor class.
Args:
----
allow_partial (bool, optional):if true a partial description can be returned if false it will return None if the row is
missing name, symbol or summary. Defaults to False.
is_partial_row_function (callable): function to identify if a row only has partial knowledge.
description_col (str): column name for the base pair sequence column Defaults to False.
"""
super().__init__(allow_partial=allow_partial)
self.ensemble_des = NCBIDescriptor(allow_partial=allow_partial)
self.ensemble_des.needed_columns = ["ensembl.gene", "symbol"]
self.description_col = description_col

def _retrieve_dataframe_for_entities(
self, entities: list, first_description_only=False
):
ensembles = self.ensemble_des._retrieve_dataframe_for_entities(
entities, first_description_only=first_description_only
)
ensembles[self.description_col] = ensembles.apply(
lambda x: _fetch_ensembl_sequence(x["ensembl.gene"]), axis=1
)
return ensembles

def row_to_description(self, df_row: pd.Series) -> str:
return df_row[self.description_col]

def is_partial_description_row(self, df_row: pd.Series) -> bool:
return False
2 changes: 2 additions & 0 deletions gene_benchmark/deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sklearn.linear_model import LinearRegression, LogisticRegression

from gene_benchmark.descriptor import (
BasePairDescriptor,
CSVDescriptions,
MultiEntityTypeDescriptor,
NaiveDescriptor,
Expand Down Expand Up @@ -78,4 +79,5 @@ def get_gene_disease_multi_encoder(
"RandomForestClassifier": RandomForestClassifier,
"RandomForestRegressor": RandomForestRegressor,
"get_gene_disease_multi_encoder": get_gene_disease_multi_encoder,
"BasePairDescriptor": BasePairDescriptor,
}
18 changes: 18 additions & 0 deletions gene_benchmark/tests/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import pandas as pd

from gene_benchmark.descriptor import (
BasePairDescriptor,
CSVDescriptions,
MultiEntityTypeDescriptor,
NaiveDescriptor,
NCBIDescriptor,
_fetch_ensembl_sequence,
add_prefix_to_dict,
missing_col_or_nan,
)
Expand Down Expand Up @@ -334,3 +336,19 @@ def test_test_NaiveDescriptor(self):
entities = pd.Series(data=["PLAC4", "IAMNOTAGENE", "C3orf18"])
des = descriptor.describe(entities)
assert all(des == entities)

def test_ensemble_bp(self):
base_pair_seq = _fetch_ensembl_sequence("ENSG00000146648")
bp_org = "AGACGTCCGGGCAGCCCCCGGCGCAGCGCGGCCGCAGCAGCCTCCGCCCCCCGCACGGTGTGAGCGCCCGACGCGGCCGA"
assert base_pair_seq[: len(bp_org)] == bp_org

def test_base_pair_descriptor_describe(self):
gene_symbols = ["BRCA1", "TP53", "EGFR", "NOTGENE"]
bp = BasePairDescriptor()
bp_df = bp.describe(pd.Series(gene_symbols, index=gene_symbols))
bp.missing_entities
assert bp_df.shape[0] == 4
assert bp.missing_entities == ["NOTGENE"]
assert not bp_df["BRCA1"] is None
assert bp_df["NOTGENE"] is None
assert set(bp_df["BRCA1"]) == {"A", "C", "G", "T"}
1 change: 1 addition & 0 deletions scripts/extract_gene_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def do_gene_embedding(
output_folder / f"descriptions_{gene_symbols_name}_{model_name}.csv"
)
descriptions.to_csv(descriptions_ofname)
logger.info(f"Saved descriptions for {descriptions.shape} symbols")

logger.info(f"Getting encodings for {len(gene_symbol_list)} symbols")
encoded = encoder.encode(descriptions, randomize_missing=True, random_len=1024)
Expand Down

0 comments on commit 2e7ffc9

Please sign in to comment.