Skip to content

Commit

Permalink
Merge branch 'main' into mean_context_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
yoavkt committed Aug 21, 2024
2 parents 388311a + 7cea557 commit 2e881e7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
29 changes: 24 additions & 5 deletions gene_benchmark/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@
import requests


def _get_ensemble(ens_data):
if "ensembl.gene" in ens_data.index and "ensembl" in ens_data.index:
if pd.isna(ens_data["ensembl.gene"]):
return ens_data["ensembl"][0]["gene"]
else:
return ens_data["ensembl.gene"]
if "ensembl" in ens_data.index:
return ens_data["ensembl"][0]["gene"]
if "ensembl.gene" in ens_data.index:
return ens_data["ensembl.gene"]
warnings.warn(f"Unknown ensemble format {ens_data}")
return None


def _fetch_ensembl_sequence(ensembl_gene_id):
"""
retries the base pair sequence of a given ensemble id REST ensenmbl API
Expand All @@ -24,10 +38,15 @@ def _fetch_ensembl_sequence(ensembl_gene_id):
if not ensembl_gene_id:
return None
url = f"https://rest.ensembl.org/sequence/id/{ensembl_gene_id}?content-type=text/plain"
response = requests.get(url)
if response.status_code == 200:
return response.text
else:

try:
response = requests.get(url)
if response.status_code == 200:
return response.text
else:
raise (f"Request failed for {ensembl_gene_id}")
except:
warnings.warn(f"Request failed for {ensembl_gene_id}")
return None


Expand Down Expand Up @@ -704,7 +723,7 @@ def _retrieve_dataframe_for_entities(
entities, first_description_only=first_description_only
)
ensembles[self.description_col] = ensembles.apply(
lambda x: _fetch_ensembl_sequence(x["ensembl.gene"]), axis=1
lambda x: _fetch_ensembl_sequence(_get_ensemble(x)), axis=1
)
return ensembles

Expand Down
12 changes: 12 additions & 0 deletions gene_benchmark/tests/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,15 @@ def test_base_pair_descriptor_describe(self):
assert not bp_df["BRCA1"] is None
assert bp_df["NOTGENE"] is None
assert set(bp_df["BRCA1"]) == {"A", "C", "G", "T"}

def test_base_pair_descriptor_describe_mul_ens(self):
gene_symbols = ["OR5V1", "SLC12A7"]
bp = BasePairDescriptor()
bp_df = bp.describe(pd.Series(gene_symbols, index=gene_symbols))
assert all(bp_df.notna())

def test_base_pair_descriptor_describe_mix_ens(self):
gene_symbols = ["OR5V1", "SLC12A7", "BRCA1", "TP53"]
bp = BasePairDescriptor()
bp_df = bp.describe(pd.Series(gene_symbols, index=gene_symbols))
assert all(bp_df.notna())

0 comments on commit 2e881e7

Please sign in to comment.