diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 4d0d2e68e8983..47b7f35cbdd7c 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -400,11 +400,7 @@ def main():
             sampling_times.append(sampling_end_time - sampling_start_time)
 
             all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
-
-            # Return early if all batch entries have reached EOS token id
             current_length += 1
-            if torch.all(has_eos) or current_length > max_length:
-                break
 
             # Update inputs for next inference run
             inputs["input_ids"] = tokens_to_add