[BUG fix] Rebase caused spec decode fix #613

Merged
2 changes: 1 addition & 1 deletion .jenkins/test_config.yaml
@@ -57,4 +57,4 @@ stages:
command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
- name: gsm8k_small_g2_tp1_eagle_spec_decode
flavor: g2
-    command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
+    command: VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -102,7 +102,9 @@ def __init__(

def prepare_cos_sin(self,
positions: torch.Tensor,
-                        offsets: Optional[torch.Tensor] = None):
+                        offsets: Optional[torch.Tensor] = None,
+                        recompute_cos_sin: bool = False):
+        self.recompute_cos_sin = recompute_cos_sin

I think if you set
self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true']
in the __init__ method, you don't have to pass the recompute_cos_sin parameter here (see my other comment on hpu_model_runner.py:753).
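For reference, a minimal sketch of what that suggestion would look like (illustrative only; the real RotaryEmbedding constructor takes many more arguments than shown here, and the env-var name is the one introduced by this PR):

import os

class RotaryEmbedding:  # sketch, not the full vLLM class
    def __init__(self):
        # Read the flag once at construction time; prepare_cos_sin() and
        # forward_hpu() could then consult self.recompute_cos_sin directly
        # instead of receiving it as a call argument.
        self.recompute_cos_sin = os.getenv(
            'VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true']

The trade-off, raised in the replies below, is that this puts HPU-specific behaviour into a constructor shared by all hardware back ends.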

Author

I think RotaryEmbedding.__init__() is generic for all hardware, so I'd rather only pass this new argument in prepare_cos_sin(), which is the method we added.


It's a valid point; it might be easier to upstream if we don't touch the constructor.

if offsets is not None:
offsets = offsets.view(positions.shape[0], -1)
positions = positions + offsets
@@ -237,6 +239,8 @@ def forward_hpu(
# forward, since the offset information wasn't available previously
if hasattr(self, "scaling_factors") or self.sin is None:
self.prepare_cos_sin(positions, offsets)
+        if self.recompute_cos_sin:
+            self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
num_tokens = positions.shape[0] * positions.shape[1]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
3 changes: 2 additions & 1 deletion vllm/sequence.py
@@ -1182,7 +1182,8 @@ def update(self,
second_last_token_hidden_states: Optional[torch.Tensor] = None):
"""Update hidden states from target model invocation. Only used for
decode steps"""
-        assert len(seq_group_metadata_list) == len(hidden_states)
+        if len(seq_group_metadata_list) < len(hidden_states):
+            hidden_states = hidden_states[:len(seq_group_metadata_list)]
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
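To make the new guard concrete, here is a tiny standalone illustration (made-up shapes and names, not the real spec-decode data structures): when the target model returns more hidden-state rows than there are sequence groups left in the batch, the extra rows are simply trimmed off instead of tripping an assertion.

import torch

seq_group_metadata_list = ['seq0', 'seq1', 'seq2']  # 3 surviving groups
hidden_states = torch.randn(4, 8)                   # 4 rows came back

# Same guard as above: trim instead of asserting equal lengths.
if len(seq_group_metadata_list) < len(hidden_states):
    hidden_states = hidden_states[:len(seq_group_metadata_list)]

print(hidden_states.shape)  # torch.Size([3, 8])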

5 changes: 4 additions & 1 deletion vllm/worker/hpu_model_runner.py
@@ -209,6 +209,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'1').lower() in ['1', 'true'] \
and not is_fake_hpu()
+        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
+                                           'false').lower() in ['1', 'true']
self.block_size = block_size
self.dtype = dtype
self.layer_names = layer_names
@@ -370,7 +372,8 @@ def _prepare_cos_sin(self, positions):

# At the end, we should be at the RotaryEmbedding layer.
if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(positions)
+            current_module.prepare_cos_sin(
+                positions, recompute_cos_sin=self.recompute_cos_sin)
else:
raise AttributeError(
"The module at the end of the path does not have \
4 changes: 1 addition & 3 deletions vllm/worker/hpu_worker.py
@@ -78,9 +78,7 @@ def __init__(

is_encoder_decoder_model = self._is_encoder_decoder_model()
ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
-        if model_runner_cls is not None:
-            ModelRunnerClass = model_runner_cls
-        elif is_encoder_decoder_model:
+        if is_encoder_decoder_model:
ModelRunnerClass = HPUEncoderDecoderModelRunner
self.model_runner: HPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config,