diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 0796da6dc1959..bc1ed7b31d538 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -1340,7 +1340,13 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=True)
+            intermediate_tensors = None
+            if not get_pp_group().is_first_rank:
+                intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                    batch_size=batch_size,
+                    dtype=self.model_config.dtype,
+                    device=self.device)
+            self.execute_model(inputs, kv_caches, intermediate_tensors=intermediate_tensors, warmup_mode=True)
             torch.hpu.synchronize()
             if profiler:
                 profiler.step()
@@ -1813,12 +1819,6 @@ def execute_model(
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
 
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=batch_size,
-                dtype=self.model_config.dtype,
-                device=self.device)
-
         execute_model_kwargs = {
             "input_ids": input_tokens,
             "positions": input_positions,