diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py
index af0e69c2..6a51ab5c 100644
--- a/tdc/test/test_model_server.py
+++ b/tdc/test/test_model_server.py
@@ -79,6 +79,14 @@ def get_ensembl_id_from_chembl_id(chembl_id):
         return str(e)
 
 
+def quant_layers(model):
+    layer_nums = []
+    for name, parameter in model.named_parameters():
+        if "layer" in name:
+            layer_nums += [int(name.split("layer.")[1].split(".")[0])]
+    return int(max(layer_nums)) + 1
+
+
 class TestModelServer(unittest.TestCase):
 
     def setUp(self):
@@ -146,7 +154,16 @@ def testGeneformerTokenizer(self):
             # build an attention mask
             attention_mask = torch.tensor(
                 [[x[0] != 0, x[1] != 0] for x in batch])
-            out.append(model(batch, attention_mask=attention_mask))
+            outputs = model(batch,
+                            attention_mask=attention_mask,
+                            output_hidden_states=True)
+            layer_to_quant = quant_layers(model) + (
+                -1
+            )  # TODO: this can be parametrized to either 0 (extract the last embedding layer) or -1 (second-to-last, which generalizes better)
+            embs_i = outputs.hidden_states[layer_to_quant]
+            # There are "cls", "cell", and "gene" embeddings; we capture only "gene", which is cell-type specific. For a "cell" embedding, average across the unmasked gene embeddings per cell.
+            embs = embs_i
+            out.append(embs)
             if ctr == 2:
                 break
             ctr += 1
@@ -159,6 +176,9 @@ def testGeneformerTokenizer(self):
                 out
             ) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
                 len(out), ctr + 1, out)
+        print(
+            "Geneformer ran successfully. Find batch embedding example here:\n {}"
+            .format(out[0]))
 
     def tearDown(self):
         try:
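
A minimal sketch (not part of the diff) of the layer-counting trick behind the new quant_layers helper and the hidden-state indexing it feeds. It assumes Hugging Face-style parameter names of the form "layer.<i>.", which BERT-derived models such as Geneformer expose; ToyEncoder is an invented stand-in for the actual served model.

# Sketch of the quant_layers technique on a hypothetical module whose
# parameters are registered under "layer.<i>.", mirroring the Hugging Face
# BERT naming convention.
import torch.nn as nn


class ToyEncoder(nn.Module):

    def __init__(self, num_layers=6, hidden=8):
        super().__init__()
        # named_parameters() will yield names like "layer.0.weight",
        # "layer.0.bias", ..., "layer.5.bias"
        self.layer = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers))


def quant_layers(model):
    layer_nums = []
    for name, _ in model.named_parameters():
        if "layer" in name:
            layer_nums.append(int(name.split("layer.")[1].split(".")[0]))
    return max(layer_nums) + 1


model = ToyEncoder(num_layers=6)
assert quant_layers(model) == 6
# With output_hidden_states=True, Hugging Face models return
# num_layers + 1 hidden states (the embedding output plus one per layer),
# so quant_layers(model) + (-1) indexes the second-to-last encoder layer
# and quant_layers(model) + 0 would index the last one, matching the TODO
# in the patch.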
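
And a hedged illustration of the pooling mentioned in the patch's "cell" embedding comment: averaging gene-level hidden states over unmasked positions. Shapes and values here are stand-ins, not real Geneformer outputs.

# Hypothetical mean-pooling of gene embeddings into a per-cell embedding.
# embs is (batch, genes, hidden); in the mask, 1 marks a real gene token
# and 0 marks padding.
import torch

embs = torch.randn(2, 4, 8)                    # stand-in gene embeddings
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # stand-in mask
mask = attention_mask.unsqueeze(-1)            # (batch, genes, 1)
cell_embs = (embs * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, hidden)
print(cell_embs.shape)  # torch.Size([2, 8])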