From 05f9e9f3c0ca15ca051d8a580507b81e269a6003 Mon Sep 17 00:00:00 2001 From: Alejandro Velez-Arce Date: Sat, 26 Oct 2024 19:30:16 -0400 Subject: [PATCH] process per batch model(batch) --- tdc/test/test_model_server.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 9af83f67..6707eb97 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -130,12 +130,13 @@ def testGeneformerTokenizer(self): input_tensor = torch.tensor(cells) # input_tensor = torch.squeeze(input_tensor) + out = [] try: - # input_tensor.squeeze(2) # last dim is zero - out = model(input_tensor) + for batch in input_tensor: + out.append(model(batch)) except Exception as e: raise Exception("tensor shape is", input_tensor.shape, "exception was:", e, "\n cells was\n", cells) - # raise Exception(e) + # input_tensor = torch.tensor(cells) # input_tensor_squeezed = torch.squeeze(input_tensor) # x = input_tensor_squeezed.shape[0] @@ -147,8 +148,8 @@ def testGeneformerTokenizer(self): # except Exception as e: # raise Exception("tensor shape is", input_tensor.shape, "exception was: {}".format(e), "input_tensor_squeezed is\n", input_tensor, "\n\ninput_tensor normal is: {}".format(input_tensor)) assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(out) - assert out.shape[0] == input_tensor.shape[0], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(out.shape[0], input_tensor.shape[0]) - assert out.shape[0] == len(cells), "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(out.shape[0], len(cells)) + assert len(out) == input_tensor.shape[0], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(len(out), input_tensor.shape[0]) + assert len(out) == len(cells), "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(len(out), len(cells)) def tearDown(self): try: