From f0ebf4dc553e1ecd479c57af61f93ba1600ef1b1 Mon Sep 17 00:00:00 2001
From: Kunal Vaishnavi
Date: Wed, 20 Mar 2024 21:03:22 +0000
Subject: [PATCH] Add changes suggested by linter

---
 .../python/tools/transformers/models/llama/benchmark_e2e.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 75e2044ae6784..3d93de482ca66 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -383,7 +383,11 @@ def main():
             inputs["attention_mask"] = torch.cat(
                 [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
             )
-            inputs["position_ids"] = None if "position_ids" not in inputs else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
+            inputs["position_ids"] = (
+                None
+                if "position_ids" not in inputs
+                else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
+            )
 
             # Set logits to zeros for next inference run and re-use memory buffer
             if outputs["logits"].shape[1] != 1:
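
Note (not part of the patch): the change is layout-only and does not alter behavior. For context, a minimal standalone sketch of what the else branch computes during incremental decoding, using made-up values; batch_size and the variable name next_position_ids are illustrative, not taken from benchmark_e2e.py:

    import torch

    batch_size = 2
    # Hypothetical position ids after a 4-token prompt, shape (batch_size, sequence_length).
    position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])

    # torch.max(..., dim=1) returns (values, indices); [0] keeps the values.
    # The next position id for each sequence is its current maximum plus one,
    # reshaped to (batch_size, 1) so it lines up with the single newly generated token.
    next_position_ids = torch.max(position_ids, dim=1)[0].reshape(batch_size, 1) + 1
    print(next_position_ids)  # tensor([[4], [4]])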