Skip to content

Commit

Permalink
yapf and cleanup
Browse files: browse the repository at this point in the history
  • Loading branch information
amva13 committed Oct 27, 2024
1 parent 9c2f7d9 commit 7c2ccf0
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

def get_ensembl_id(gene_symbols):
    """Query MyGene.info for the Ensembl gene IDs of the given gene symbols.

    Args:
        gene_symbols: Gene symbols forwarded unchanged to
            ``MyGeneInfo.querymany``.

    Returns:
        The result of ``querymany`` for a symbol-scoped lookup restricted to
        human, requesting only the ``ensembl.gene`` field.
    """
    mg = mygene.MyGeneInfo()
    # A single return: the rendered diff listed both the pre- and post-yapf
    # forms of this call, which left a duplicate, unreachable statement.
    return mg.querymany(gene_symbols,
                        scopes='symbol',
                        fields='ensembl.gene',
                        species='human')


def get_target_from_chembl(chembl_id):
Expand Down Expand Up @@ -85,14 +88,21 @@ def setUp(self):
def testGeneformerTokenizer(self):

adata = self.resource.get_anndata(
var_value_filter = "feature_id in ['ENSG00000161798', 'ENSG00000188229']",
obs_value_filter = "sex == 'female' and cell_type in ['microglial cell', 'neuron']",
column_names = {"obs": ["assay", "cell_type", "tissue", "tissue_general", "suspension_type", "disease"]},
var_value_filter=
"feature_id in ['ENSG00000161798', 'ENSG00000188229']",
obs_value_filter=
"sex == 'female' and cell_type in ['microglial cell', 'neuron']",
column_names={
"obs": [
"assay", "cell_type", "tissue", "tissue_general",
"suspension_type", "disease"
]
},
)
print("initializing tokenizer")
tokenizer = GeneformerTokenizer()
print("testing tokenizer")
x = tokenizer.tokenize_cell_vectors(adata, ensembl_id="feature_id", ncounts="n_measured_vars")
x = tokenizer.tokenize_cell_vectors(adata,
ensembl_id="feature_id",
ncounts="n_measured_vars")
assert x[0]

# test Geneformer can serve the request
Expand All @@ -103,21 +113,21 @@ def testGeneformerTokenizer(self):
import torch
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()
# tokenized_data = tokenizer.create_dataset(cells, metadata)
print("using very few genes for these test cases so expecting empties... let's pad...")

# using very few genes for these test cases so expecting empties... let's pad...
for idx in range(len(cells)):
x = cells[idx]
for j in range(len(x)):
v = x[j]
if len(v) < 2:
out = None
for _ in range(2-len(v)):
for _ in range(2 - len(v)):
if out is None:
out = np.append(v, 0) # pad with 0
else:
out = np.append(out, 0)
cells[idx][j] = out
if len(cells[idx]) < 512: # batch size
if len(cells[idx]) < 512: # batch size
array = cells[idx]
# Calculate how many rows need to be added
n_rows_to_add = 512 - len(array)
Expand All @@ -129,34 +139,26 @@ def testGeneformerTokenizer(self):
cells[idx] = np.vstack((array, padding))

input_tensor = torch.tensor(cells)
# input_tensor = torch.squeeze(input_tensor)
out = []
try:
ctr = 0 # stop after some passes to avoid failure
ctr = 0 # stop after some passes to avoid failure
for batch in input_tensor:
# build an attention mask
attention_mask = torch.tensor([[x[0]!=0, x[1]!=0] for x in batch])
attention_mask = torch.tensor(
[[x[0] != 0, x[1] != 0] for x in batch])
out.append(model(batch, attention_mask=attention_mask))
if ctr == 2:
break
ctr += 1
except Exception as e:
# raise Exception("tensor shape is", input_tensor.shape, "exception was:", e, "\n cells was\n", cells)
raise Exception(e)

# input_tensor = torch.tensor(cells)
# input_tensor_squeezed = torch.squeeze(input_tensor)
# x = input_tensor_squeezed.shape[0]
# y = input_tensor_squeezed.shape[1]
# out = None # try-except block
# try:
# input_tensor_squeezed = input_tensor_squeezed.reshape(x, y)
# out = model(input_tensor_squeezed)
# except Exception as e:
# raise Exception("tensor shape is", input_tensor.shape, "exception was: {}".format(e), "input_tensor_squeezed is\n", input_tensor, "\n\ninput_tensor normal is: {}".format(input_tensor))
assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(out)
print(out)
assert len(out) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(len(out), ctr + 1, out)

assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(
out)
assert len(
out
) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
len(out), ctr + 1, out)

def tearDown(self):
try:
Expand Down

0 comments on commit 7c2ccf0

Please sign in to comment.