
[BUG fix] Rebase caused spec decode fix #613

Open · xuechendi wants to merge 3 commits into habana_main from rebase_caused_spec_decode_fix
Conversation

@xuechendi xuechendi commented Dec 11, 2024

Error reported in https://jira.habana-labs.com/browse/SW-212516

Two recently merged PRs broke Spec Decode functionality:

  1. Support mllama (llama 3.2) model for HPU #491 overrides the existing WorkerWrapperBase design for speculative decoding. The override

         if model_runner_cls is not None:
             ModelRunnerClass = model_runner_cls

     is no longer needed, since we now initialize model_runner_cls as below to follow the upstream design:

         if model_runner_cls is not None:
             self.model_runner = model_runner_cls(self.model_runner)

  2. Prepare sin/cos buffers for rope outside model forward #566 does not work in Spec Decode Eagle mode.
     The input tensors no longer match the assumption that decode_fwd provides only one token per sequence; Spec Decode passes multiple candidate tokens as q.
     To fix this, a new env var, VLLM_COS_SIN_RECOMPUTE=true, is added; it must be set to trigger recomputation of cos and sin for spec decode (see the sketch after this list).
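
For illustration, here is a minimal, self-contained sketch of how the recompute flag is meant to behave; the helper function, caching scheme, and tensor shapes below are assumptions made for the sketch, not the actual patch.

# Sketch: recompute cos/sin when spec decode passes multiple candidate
# tokens per sequence, gated by the new VLLM_COS_SIN_RECOMPUTE env var.
import os

import torch

RECOMPUTE_COS_SIN = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                              'false').lower() in ['1', 'true']

_cached_cos_sin = None  # buffers prepared once outside model forward (per #566)


def get_cos_sin(positions: torch.Tensor,
                inv_freq: torch.Tensor,
                recompute: bool = RECOMPUTE_COS_SIN):
    """Return (cos, sin) for the given positions.

    With #566 the buffers are prepared assuming one decode token per
    sequence. Spec decode provides several candidate tokens as q, so the
    cached buffers no longer match and must be recomputed when the flag
    is set.
    """
    global _cached_cos_sin
    if _cached_cos_sin is None or recompute:
        freqs = torch.einsum('i,j->ij', positions.flatten().float(), inv_freq)
        _cached_cos_sin = (freqs.cos(), freqs.sin())
    return _cached_cos_sin


if __name__ == '__main__':
    inv_freq = 1.0 / (10000**(torch.arange(0, 64, 2).float() / 64))
    positions = torch.tensor([[5, 6, 7]])  # 3 candidate tokens in one sequence
    cos, sin = get_cos_sin(positions, inv_freq)
    print(cos.shape, sin.shape)  # torch.Size([3, 32]) torch.Size([3, 32])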

@xuechendi
Author

@michalkuligowski, please help to review.

@xuechendi
Author

@kzawora-intel, please check the fix here:
the previous mllama PR breaks spec decode, so I added a fix in
de79b5c

@@ -741,6 +749,8 @@ def load_model(self) -> None:
get_decoder_layer_suffix(model_config.model_type if
model_config is not None else None),
hidden_layer_markstep_interval)
recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
'false').lower() == 'true'


I'd rather have "in ['1', 'true']" instead of "== 'true'"


On another note, do you need to get the value of this env var here? Can it be done in __init__ of RotaryEmbedding instead? I don't think it's necessary to pass this value to rope.prepare_cos_sin from here.

Author


The reason I want to pass the value from hpu_model_runner is that it is easier to notice there. Also, I think RotaryEmbedding.__init__() is generic for all HW, so I don't want to make any change there.

-                          offsets: Optional[torch.Tensor] = None):
+                          offsets: Optional[torch.Tensor] = None,
+                          recompute_cos_sin: bool = False):
+        self.recompute_cos_sin = recompute_cos_sin


I think if you set
self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true']
in the __init__ method, you don't have to pass the recompute_cos_sin parameter here (see my other comment on hpu_model_runner.py:753).

Author


I think RotaryEmbedding.__init__() is generic for all HW, so I'm thinking of only passing this new argument in prepare_cos_sin(), which is added by us?


It's a valid point; it might be easier to upstream if we don't touch the constructor.


@@ -741,6 +749,8 @@ def load_model(self) -> None:
get_decoder_layer_suffix(model_config.model_type if
model_config is not None else None),
hidden_layer_markstep_interval)
recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                              'false').lower() == 'true'


You explained why you don't want to touch the __init__ method of RotaryEmbedding, but maybe we can at least move this getter to the __init__ of HpuModelAdapter?

Author


Makes sense, I've moved it into the HpuModelAdapter __init__ function.
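
A rough, illustrative sketch of what moving the getter into the adapter's constructor could look like (the class layout and the 'rotary_emb' attribute are assumptions, not the actual commit):

import os


class HpuModelAdapter:

    def __init__(self, model):
        self.model = model
        # Read the flag once in the adapter's constructor, instead of in
        # load_model() of hpu_model_runner.
        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
                                           'false').lower() in ['1', 'true']

    def _prepare_cos_sin(self, positions):
        # Forward the flag to the HPU-specific rope helper added by this fork,
        # so the RotaryEmbedding constructor itself stays untouched.
        rope = getattr(self.model, 'rotary_emb', None)
        if rope is not None:
            rope.prepare_cos_sin(positions,
                                 recompute_cos_sin=self.recompute_cos_sin)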

For spec decode Eagle mode, need to set VLLM_COS_SIN_RECOMPUTE=true

Signed-off-by: Chendi.Xue <[email protected]>
@xuechendi xuechendi force-pushed the rebase_caused_spec_decode_fix branch from ef256b5 to 54ba9f1 on December 19, 2024 21:12
@xuechendi xuechendi requested a review from vivekgoe as a code owner December 19, 2024 21:12
@xuechendi
Author

@michalkuligowski, since I'll be taking a long leave starting next week, I'd like to check with you whether we can get this fix merged.
I rebased this PR today and tested it locally with the two scripts below; both passed.

test_spec.sh

VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_mlp_correctness.py::test_mlp_e2e_greedy_correctness
VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
VLLM_COS_SIN_RECOMPUTE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness

qa.sh

VLLM_SKIP_WARMUP=true \
VLLM_CONTIGUOUS_PA=false \
python3 benchmarks/benchmark_throughput.py --model=meta-llama/Llama-2-13b-chat-hf --device=hpu --seed=2024 --backend=vllm --input_len=1024 --num-prompts=128 --output_len=128 --dtype=bfloat16 --num_scheduler_steps=1 --gpu-memory-util=0.9 --tensor_parallel_size=1 --max-model-len=4096 --speculative_model=ibm-fms/llama-13b-accelerator --use-v2-block-manager
