Skip to content

Commit

Permalink
Merge pull request #326 from mims-harvard/geneformer_server
Browse files: browse the repository at this point in the history.
Geneformer server <> cellxgene test
  • Loading branch information
amva13 authored Oct 27, 2024
2 parents 7ab1fc1 + 7c2ccf0 commit 58c236b
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 33 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies:
- torchvision==0.16.1
- transformers==4.43.4
- yapf==0.40.2
- git+https://github.com/amva13/geneformer.git@main#egg=geneformer
- git+https://huggingface.co/ctheodoris/Geneformer.git@main#egg=geneformer

variables:
KMP_DUPLICATE_LIB_OK: "TRUE"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ tiledbsoma>=1.7.2,<2.0.0
yapf>=0.40.2,<1.0.0

# github packages
git+https://github.com/amva13/geneformer.git@main#egg=geneformer
git+https://huggingface.co/ctheodoris/Geneformer.git@main#egg=geneformer
12 changes: 6 additions & 6 deletions tdc/model_server/tokenizers/geneformer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np
import scipy.sparse as sp

from geneformer import TranscriptomeTokenizer
from ...utils.load import pd_load, download_wrapper


class GeneformerTokenizer(TranscriptomeTokenizer):
class GeneformerTokenizer:
"""
Uses Geneformer Utils to parse zero-shot model server requests for tokenizing single-cell gene expression data.
Expand Down Expand Up @@ -53,7 +52,8 @@ def tokenize_cell_vectors(self,
cell_vector_adata,
target_sum=10_000,
chunk_size=512,
ensembl_id="ensembl_id"):
ensembl_id="ensembl_id",
ncounts="ncounts"):
"""
Tokenizing single-cell gene expression vectors formatted as anndata types.
Expand Down Expand Up @@ -96,16 +96,16 @@ def tokenize_cell_vectors(self,
for i in range(0, len(filter_pass_loc), chunk_size):
idx = filter_pass_loc[i:i + chunk_size]

n_counts = adata[idx].obs['ncounts'].values[:, None]
n_counts = adata[idx].obs[ncounts].values[:, None]
X_view = adata[idx, coding_miRNA_loc].X
X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
X_norm = sp.csr_matrix(X_norm)

tokenized_cells += [
tokenized_cells.append([
self.rank_genes(X_norm[i].data,
coding_miRNA_tokens[X_norm[i].indices])
for i in range(X_norm.shape[0])
]
])

# add custom attributes for subview to dict
if self.custom_attr_name_dict is not None:
Expand Down
108 changes: 83 additions & 25 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import unittest
import shutil
import pytest
import mygene
import numpy as np

# temporary solution for relative imports in case TDC is not installed
# if TDC is installed, no need to use the following line
Expand All @@ -19,9 +21,17 @@
import requests


def get_ensembl_id(gene_symbols):
    """Look up Ensembl gene entries for human gene symbols via mygene.

    Args:
        gene_symbols: Gene symbols, forwarded unchanged to ``querymany``.

    Returns:
        The unmodified result of ``MyGeneInfo.querymany`` for the lookup.
    """
    # Treat inputs as symbols, request only the Ensembl gene field,
    # and restrict matches to human.
    query_options = {
        "scopes": "symbol",
        "fields": "ensembl.gene",
        "species": "human",
    }
    client = mygene.MyGeneInfo()
    return client.querymany(gene_symbols, **query_options)


def get_target_from_chembl(chembl_id):
# Query ChEMBL API for target information
chembl_url = f"https://www.ebi.ac.uk/chembl/api/data/target/{chembl_id}.json"
chembl_url = f"https://www.ebi.ac.uk/chembl/api/data/{chembl_id}.json"
response = requests.get(chembl_url)

if response.status_code == 200:
Expand Down Expand Up @@ -76,31 +86,79 @@ def setUp(self):
self.resource = cellxgene_census.CensusResource()

def testGeneformerTokenizer(self):
import anndata
from tdc.multi_pred.perturboutcome import PerturbOutcome
test_loader = PerturbOutcome(
name="scperturb_drug_AissaBenevolenskaya2021")
adata = test_loader.adata
print("swapping obs and var because scperturb violated convention...")
adata_flipped = anndata.AnnData(adata.X.T)
adata_flipped.obs = adata.var
adata_flipped.var = adata.obs
adata = adata_flipped
print("swap complete")
print("adding ensembl ids...")
adata.var["ensembl_id"] = adata.var["chembl-ID"].apply(
get_ensembl_id_from_chembl_id)
print("added ensembl_id column")

print(type(adata.var))
print(adata.var.columns)
print(type(adata.obs))
print(adata.obs.columns)
print("initializing tokenizer")

adata = self.resource.get_anndata(
var_value_filter=
"feature_id in ['ENSG00000161798', 'ENSG00000188229']",
obs_value_filter=
"sex == 'female' and cell_type in ['microglial cell', 'neuron']",
column_names={
"obs": [
"assay", "cell_type", "tissue", "tissue_general",
"suspension_type", "disease"
]
},
)
tokenizer = GeneformerTokenizer()
print("testing tokenizer")
x = tokenizer.tokenize_cell_vectors(adata)
assert x
x = tokenizer.tokenize_cell_vectors(adata,
ensembl_id="feature_id",
ncounts="n_measured_vars")
assert x[0]

# test Geneformer can serve the request
cells, _ = x
assert cells, "FAILURE: cells false-like. Value is = {}".format(cells)
assert len(cells) > 0, "FAILURE: length of cells <= 0 {}".format(cells)
from tdc import tdc_hf_interface
import torch
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()

# These test cases use very few genes, so empty token vectors are expected; pad them.
for idx in range(len(cells)):
x = cells[idx]
for j in range(len(x)):
v = x[j]
if len(v) < 2:
out = None
for _ in range(2 - len(v)):
if out is None:
out = np.append(v, 0) # pad with 0
else:
out = np.append(out, 0)
cells[idx][j] = out
if len(cells[idx]) < 512: # batch size
array = cells[idx]
# Calculate how many rows need to be added
n_rows_to_add = 512 - len(array)

# Create a padding array with [0, 0] for the remaining rows
padding = np.tile([0, 0], (n_rows_to_add, 1))

# Concatenate the original array with the padding array
cells[idx] = np.vstack((array, padding))

input_tensor = torch.tensor(cells)
out = []
try:
ctr = 0 # stop after some passes to avoid failure
for batch in input_tensor:
# build an attention mask
attention_mask = torch.tensor(
[[x[0] != 0, x[1] != 0] for x in batch])
out.append(model(batch, attention_mask=attention_mask))
if ctr == 2:
break
ctr += 1
except Exception as e:
raise Exception(e)

assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(
out)
assert len(
out
) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
len(out), ctr + 1, out)

def tearDown(self):
try:
Expand Down

0 comments on commit 58c236b

Please sign in to comment.