From 54ba9f1b0b8155988d422e2865a83e60aab9f3f9 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Wed, 11 Dec 2024 05:15:52 +0000
Subject: [PATCH] Fix incorrect cos/sin pre-compute

For spec decode eagle mode, VLLM_COS_SIN_RECOMPUTE=true needs to be set.

Signed-off-by: Chendi.Xue
---
 .jenkins/test_config.yaml                      | 2 +-
 vllm/model_executor/layers/rotary_embedding.py | 8 ++++++--
 vllm/worker/hpu_model_runner.py                | 5 ++++-
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/.jenkins/test_config.yaml b/.jenkins/test_config.yaml
index 3d8b2416506c7..0b9a2231d59a8 100644
--- a/.jenkins/test_config.yaml
+++ b/.jenkins/test_config.yaml
@@ -57,4 +57,4 @@ stages:
     command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
   - name: gsm8k_small_g2_tp1_eagle_spec_decode
     flavor: g2
-    command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
+    command: VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 6344c3d39eb7e..2f8d434e82024 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -102,7 +102,9 @@ def __init__(
 
     def prepare_cos_sin(self,
                         positions: torch.Tensor,
-                        offsets: Optional[torch.Tensor] = None):
+                        offsets: Optional[torch.Tensor] = None,
+                        recompute_cos_sin: bool = False):
+        self.recompute_cos_sin = recompute_cos_sin
         if offsets is not None:
             offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
@@ -232,8 +234,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        if self.sin is None:
+        if self.sin is None or self.cos is None:
             self.prepare_cos_sin(positions, offsets)
+        if self.recompute_cos_sin:
+            self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
         num_tokens = positions.shape[0] * positions.shape[1]
         # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
         # to query hidden dimension, so the original tensors need to be
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index d3090d313d155..b80463195ced0 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -209,6 +209,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '1').lower() in ['1', 'true'] \
                                                 and not is_fake_hpu()
+        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
+                                           'false').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
         self.layer_names = layer_names
@@ -370,7 +372,8 @@ def _prepare_cos_sin(self, positions):
 
         # At the end, we should be at the RotaryEmbedding layer.
         if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(positions)
+            current_module.prepare_cos_sin(
+                positions, recompute_cos_sin=self.recompute_cos_sin)
         else:
             raise AttributeError(
                 "The module at the end of the path does not have \
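
Note on the rotary_embedding.py change: forward_hpu lazily fills a cos/sin cache on the first call and reuses it afterwards. In eagle spec decode the draft and target models drive the same RotaryEmbedding with different position tensors, so cached values go stale. The following is a minimal, self-contained sketch of the cache-then-optional-recompute pattern the patch introduces; ToyRope and every name in it are illustrative, not the vLLM implementation.

import torch


class ToyRope:
    """Caches cos/sin the way forward_hpu does; the flag forces a refresh."""

    def __init__(self, rotary_dim: int, base: float = 10000.0,
                 recompute_cos_sin: bool = False):
        self.recompute_cos_sin = recompute_cos_sin
        self.inv_freq = 1.0 / (base ** (
            torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        self.cos = None
        self.sin = None

    def prepare_cos_sin(self, positions: torch.Tensor) -> None:
        # Angles for exactly the positions being processed right now.
        freqs = positions.float().unsqueeze(-1) * self.inv_freq
        self.cos, self.sin = freqs.cos(), freqs.sin()

    def forward(self, positions: torch.Tensor):
        # Same shape as the patched forward_hpu: lazily fill the cache,
        # then optionally force a refresh so stale values are never reused.
        if self.sin is None or self.cos is None:
            self.prepare_cos_sin(positions)
        if self.recompute_cos_sin:
            self.prepare_cos_sin(positions)
        return self.cos, self.sin


# Without the flag, the second call reuses cos/sin computed for the first
# batch of positions: the stale-cache behavior the patch targets.
rope = ToyRope(rotary_dim=8)
cos1, _ = rope.forward(torch.arange(4))      # cache filled for positions 0..3
cos2, _ = rope.forward(torch.arange(4, 8))   # reused, still positions 0..3
assert torch.equal(cos1, cos2)

# With the flag (VLLM_COS_SIN_RECOMPUTE=true), every call refreshes.
rope = ToyRope(rotary_dim=8, recompute_cos_sin=True)
cos1, _ = rope.forward(torch.arange(4))
cos2, _ = rope.forward(torch.arange(4, 8))
assert not torch.equal(cos1, cos2)

The env var is deliberately opt-in: recomputing on every forward gives up the cache-reuse fast path, which the non-spec-decode configurations keep unchanged.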
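
Note on the hpu_model_runner.py change: _prepare_cos_sin walks from the model root down to the module that exposes prepare_cos_sin (the RotaryEmbedding layer) and now forwards the recompute flag read from VLLM_COS_SIN_RECOMPUTE. Below is a hedged sketch of that resolution step, assuming a reduce-over-getattr walk along a dotted layer path; resolve_rotary_module and the example path are hypothetical, not the actual helper.

from functools import reduce

import torch.nn as nn


def resolve_rotary_module(root: nn.Module, dotted_path: str) -> nn.Module:
    # getattr also resolves ModuleList children, which PyTorch registers
    # under their index as the attribute name ("0", "1", ...).
    module = reduce(getattr, dotted_path.split('.'), root)
    if not hasattr(module, 'prepare_cos_sin'):
        raise AttributeError(
            f"{dotted_path} does not end at a module with prepare_cos_sin")
    return module


# Hypothetical usage for a Llama-style layout:
# rope = resolve_rotary_module(model, "model.layers.0.self_attn.rotary_emb")
# rope.prepare_cos_sin(positions, recompute_cos_sin=True)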