diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 75e2044ae6784..3d93de482ca66 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -383,7 +383,11 @@ def main():
             inputs["attention_mask"] = torch.cat(
                 [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
             )
-            inputs["position_ids"] = None if "position_ids" not in inputs else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
+            inputs["position_ids"] = (
+                None
+                if "position_ids" not in inputs
+                else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
+            )
 
             # Set logits to zeros for next inference run and re-use memory buffer
             if outputs["logits"].shape[1] != 1:
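
For context, the expression being reformatted here computes the position ids fed to the next decoding step: each row's new position is one past the maximum position seen so far, and models that take no position_ids keep the entry as None. Below is a minimal sketch of that logic only; the standalone helper name next_position_ids and the function framing are illustrative and not part of benchmark_e2e.py.

# Sketch (assumption: not part of the patch) of the per-step position_ids update
# performed by the reformatted ternary expression above.
import torch

def next_position_ids(inputs: dict, batch_size: int):
    # Models without a position_ids input keep the entry as None.
    if "position_ids" not in inputs:
        return None
    # Per-row max over the current position ids, reshaped to (batch_size, 1), plus one.
    return torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1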