From 26355f33f26fb52bd0f67b30628bfc35a9ed8b6c Mon Sep 17 00:00:00 2001
From: Alejandro Velez-Arce
Date: Mon, 11 Nov 2024 19:28:36 -0500
Subject: [PATCH 1/2] parse out the embeddings from geneformer in test case

---
 tdc/test/test_model_server.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

NOTE(review): the reviewed copy of this series had its line breaks and
  indentation collapsed. The layout below (including the indentation of
  context lines) is reconstructed from Python syntax and the hunk line
  counts; verify against the commit before applying.
NOTE(review): quant_layers() guards with `"layer" in name` but splits on
  "layer."; a parameter name containing "layer" without "layer." makes the
  `[1]` index raise IndexError. If no name matches, max() is called on an
  empty list and raises ValueError. The int() around max() is redundant
  (the list already holds ints) and the loop variable `parameter` is unused.
NOTE(review): the second hunk reads `outputs.hidden_states`, but the model
  call added in this patch does not pass `output_hidden_states=True`
  (PATCH 2/2 adds it). Presumably hidden states are not returned without
  that flag -- confirm, and consider squashing 2/2 into this patch.
NOTE(review): the index used is quant_layers(model) - 1. Per the TODO this
  is meant to be the second-to-last layer; that assumes hidden_states holds
  one more entry than the number of layers -- confirm.
NOTE(review): `embs = embs_i` is a plain alias; no per-cell averaging or
  masking is done here.
NOTE(review): the print string contains the typo "sucessfully". In the
  reviewed copy that string was split by a raw line break after
  "sucessfully. "; it is written below as a `\n` escape on one line, which
  matches the three added lines implied by the "+174,9" hunk header --
  confirm against the commit.

diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py
index af0e69c2..29a1df09 100644
--- a/tdc/test/test_model_server.py
+++ b/tdc/test/test_model_server.py
@@ -79,6 +79,14 @@ def get_ensembl_id_from_chembl_id(chembl_id):
         return str(e)
 
 
+def quant_layers(model):
+    layer_nums = []
+    for name, parameter in model.named_parameters():
+        if "layer" in name:
+            layer_nums += [int(name.split("layer.")[1].split(".")[0])]
+    return int(max(layer_nums)) + 1
+
+
 class TestModelServer(unittest.TestCase):
 
     def setUp(self):
@@ -146,7 +154,14 @@ def testGeneformerTokenizer(self):
             # build an attention mask
             attention_mask = torch.tensor(
                 [[x[0] != 0, x[1] != 0] for x in batch])
-            out.append(model(batch, attention_mask=attention_mask))
+            outputs = model(batch, attention_mask=attention_mask)
+            layer_to_quant = quant_layers(model) + (
+                -1
+            )  # TODO note this can be parametrized to either 0 (extract last embedding layer) or -1 (second-to-last which is more generalized)
+            embs_i = outputs.hidden_states[layer_to_quant]
+            # there are "cls", "cell", and "gene" embeddings. we will only capture "gene", which is cell type specific. for "cell", you'd average out across unmasked gene embeddings per cell
+            embs = embs_i
+            out.append(embs)
             if ctr == 2:
                 break
             ctr += 1
@@ -159,6 +174,9 @@ def testGeneformerTokenizer(self):
             out
         ) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
             len(out), ctr + 1, out)
+        print(
+            "Geneformer ran sucessfully. \nFind batch embedding example here:\n {}"
+            .format(out[0]))
 
     def tearDown(self):
         try:

From f77297b5eb05ec3a0b8c9c214354f36dda5fe796 Mon Sep 17 00:00:00 2001
From: Alejandro Velez-Arce
Date: Mon, 11 Nov 2024 19:53:52 -0500
Subject: [PATCH 2/2] mend

---
 tdc/test/test_model_server.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

NOTE(review): this follow-up only adds `output_hidden_states=True` to the
  model call introduced in PATCH 1/2, whose result is then read via
  `outputs.hidden_states`. The subject "mend" does not describe the change.
NOTE(review): the alignment of the continuation lines in the model call is
  reconstructed (the reviewed copy had collapsed whitespace); verify against
  the commit before applying.

diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py
index 29a1df09..6a51ab5c 100644
--- a/tdc/test/test_model_server.py
+++ b/tdc/test/test_model_server.py
@@ -154,7 +154,9 @@ def testGeneformerTokenizer(self):
             # build an attention mask
             attention_mask = torch.tensor(
                 [[x[0] != 0, x[1] != 0] for x in batch])
-            outputs = model(batch, attention_mask=attention_mask)
+            outputs = model(batch,
+                            attention_mask=attention_mask,
+                            output_hidden_states=True)
             layer_to_quant = quant_layers(model) + (
                 -1
             )  # TODO note this can be parametrized to either 0 (extract last embedding layer) or -1 (second-to-last which is more generalized)