Commit 54ba9f1
Fix incorrect cos/sin pre-compute
For spec decode eagle mode, VLLM_COS_SIN_RECOMPUTE=true needs to be set.

Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Dec 19, 2024
1 parent 66e1af8 commit 54ba9f1
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .jenkins/test_config.yaml
@@ -57,4 +57,4 @@ stages:
       command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
   - name: gsm8k_small_g2_tp1_eagle_spec_decode
     flavor: g2
-    command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
+    command: VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
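The CI change above only prepends VLLM_COS_SIN_RECOMPUTE=true to the eagle test's command line. Since the HPU model runner reads the flag once at construction time (see the hpu_model_runner.py hunk below), it has to be in the environment before the engine is created. A minimal sketch of doing the same from a Python launcher (a hypothetical setup, using the exact truthiness rule the runner applies):

    import os

    # Hypothetical launcher: set the flag before any vLLM HPU components
    # are constructed, since it is read once in __init__.
    os.environ.setdefault('VLLM_COS_SIN_RECOMPUTE', 'true')

    # Same parsing rule the runner uses: '1' or 'true' enable it.
    enabled = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                        'false').lower() in ['1', 'true']
    print('cos/sin recompute enabled:', enabled)
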
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -102,7 +102,9 @@ def __init__(
 
     def prepare_cos_sin(self,
                         positions: torch.Tensor,
-                        offsets: Optional[torch.Tensor] = None):
+                        offsets: Optional[torch.Tensor] = None,
+                        recompute_cos_sin: bool = False):
+        self.recompute_cos_sin = recompute_cos_sin
         if offsets is not None:
             offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
@@ -232,8 +234,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        if self.sin is None:
+        if self.sin is None or self.cos is None:
             self.prepare_cos_sin(positions, offsets)
+        if self.recompute_cos_sin:
+            self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
         num_tokens = positions.shape[0] * positions.shape[1]
         # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
         # to query hidden dimension, so the original tensors need to be
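Taken together, forward_hpu now has two triggers for (re)building the tables: a one-time fill when the cache is empty, and an unconditional per-forward refresh when recompute is requested, presumably because successive eagle draft steps run with position tensors different from the ones the cache was built for. A self-contained sketch of that cache-or-recompute pattern (RotaryCache and its angle math are illustrative stand-ins, not vLLM code):

    from typing import Optional, Tuple
    import torch

    class RotaryCache:
        # Illustrative stand-in for the cos/sin caching in RotaryEmbedding.
        def __init__(self, recompute_cos_sin: bool = False):
            self.cos: Optional[torch.Tensor] = None
            self.sin: Optional[torch.Tensor] = None
            self.recompute_cos_sin = recompute_cos_sin

        def prepare_cos_sin(self, positions: torch.Tensor) -> None:
            # Placeholder angle math; the real layer derives angles from the
            # rotary base/dim, then caches cos/sin for these positions.
            angles = positions.float().unsqueeze(-1)
            self.cos, self.sin = torch.cos(angles), torch.sin(angles)

        def forward(self, positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            if self.sin is None or self.cos is None:
                self.prepare_cos_sin(positions)  # first call: fill cache once
            if self.recompute_cos_sin:
                self.prepare_cos_sin(positions)  # eagle mode: always refresh
            return self.cos, self.sin

Without the second branch, the first warmup batch pins the tables for the lifetime of the layer, which appears to be what this fix targets.
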
5 changes: 4 additions & 1 deletion vllm/worker/hpu_model_runner.py
@@ -209,6 +209,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '1').lower() in ['1', 'true'] \
             and not is_fake_hpu()
+        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
+                                           'false').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
         self.layer_names = layer_names
@@ -370,7 +372,8 @@ def _prepare_cos_sin(self, positions):
 
         # At the end, we should be at the RotaryEmbedding layer.
         if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(positions)
+            current_module.prepare_cos_sin(
+                positions, recompute_cos_sin=self.recompute_cos_sin)
         else:
             raise AttributeError(
                 "The module at the end of the path does not have \
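_prepare_cos_sin reaches the rotary layer by walking an attribute path down the module tree and now forwards the flag to it. A hedged sketch of that resolution pattern (resolve_module, the Fake* modules, and the path string are assumptions for illustration, not the runner's actual code):

    import functools
    import os

    import torch
    import torch.nn as nn

    def resolve_module(root: nn.Module, path: str) -> nn.Module:
        # Follow a dotted attribute path; getattr also works on an
        # nn.ModuleList with a string index such as "0".
        return functools.reduce(getattr, path.split('.'), root)

    class FakeRotary(nn.Module):
        def prepare_cos_sin(self, positions, recompute_cos_sin=False):
            print('prepare_cos_sin called, recompute =', recompute_cos_sin)

    class FakeLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.rotary_emb = FakeRotary()

    model = nn.Module()
    model.layers = nn.ModuleList([FakeLayer()])

    recompute = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                          'false').lower() in ['1', 'true']
    rotary = resolve_module(model, 'layers.0.rotary_emb')
    if hasattr(rotary, 'prepare_cos_sin'):
        rotary.prepare_cos_sin(torch.arange(4).unsqueeze(0),
                               recompute_cos_sin=recompute)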
