Skip to content

Commit

Permalink
reshape fix
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Oct 23, 2024
1 parent 8ccdb01 commit 44f672d
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,17 @@ def testGeneformerTokenizer(self):
model = geneformer.load()
input_tensor = torch.tensor(cells)
input_tensor = torch.squeeze(input_tensor)
raise Exception("shape is", input_tensor.shape, "values are\n", input_tensor)
out = model(input_tensor)
x = input_tensor.shape[0]
y = input_tensor.shape[1]
input_tensor = input_tensor.reshape(x, y)
out = None # try-except block
try:
out = model(input_tensor)
except Exception as e:
raise Exception("shape is", input_tensor.shape, "exception was: {}".format(e), "values are\n", input_tensor)
assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(out)
assert len(out) == len(cells), "FAILURE: Geneformer output and cells input don't have the same length. {} vs {}".format(len(out), len(cells))
assert out.shape[0] == input_tensor.shape[0], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(out.shape[0], input_tensor.shape[0])
assert out.shape[0] == len(cells), "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(out.shape[0], len(cells))

def tearDown(self):
try:
Expand Down

0 comments on commit 44f672d

Please sign in to comment.