Medusa performance degrades with batch size larger than 1 #2482

Open · SoundProvider opened this issue Nov 22, 2024 · 4 comments

Labels: Performance (issue about performance number), triaged (issue has been triaged by maintainers)

Comments

SoundProvider commented Nov 22, 2024

I'm trying to use Medusa with TRT-LLM, following this page.

It works fine with Vicuna 7B and its Medusa heads, as referenced in the example page.

The example states: "Note: Increasing the batch size may have a negative impact on performance."
My understanding was that when the batch size increases, each sequence has to wait for the other sequences to reach its position, resulting in performance degradation.

But when I tested with Vicuna 7B using a batch of 4 with the same input for every sequence, the performance still dropped. This contradicts my understanding, since identical sequences should not have to wait for each other.

[Image: results table for the batch-size variation test with identical inputs (batch of 4, same input for each sequence)]

What could be the reason? It would be great if someone could explain.

Thank you

hello-11 (Collaborator) commented

@SoundProvider could you tell me how you evaluated the performance?

hello-11 added the Performance label on Nov 22, 2024

SoundProvider commented Nov 22, 2024

@hello-11 Hello.
I used the run script in the Medusa example folder:

python /app/tensorrt_llm/examples/run.py --engine_dir /app/models/medusa_test_3b/tensorrt_llm/4-gpu \
                                            --tokenizer_dir /app/models/vicuna-33b-v1.3 \
                                            --max_output_len=500 \
                                            --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
                                            --temperature 1.0 \
                                            --input_text "Once upon" \
                                            --run_profiling
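
For reference, the --medusa_choices value above is the default 63-entry tree, so every generation step verifies 63 draft tokens on top of the one regular token. A quick sanity check (a minimal, illustrative sketch; choices_str would hold the full list passed on the command line, shortened here for brevity):

import ast

# Count the entries of the --medusa_choices tree.
# choices_str is a stand-in for the full string passed above.
choices_str = "[[0], [0, 0], [1], [0, 1], [2]]"
choices = ast.literal_eval(choices_str)

num_draft_tokens = len(choices)          # 63 for the full default tree
tokens_per_step = num_draft_tokens + 1   # draft tokens + the regular token
print(num_draft_tokens, tokens_per_step)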

I just measured the .generate time

import time

import numpy as np

times = []  # per-call .generate() latencies

for _ in range(10):
    s_time = time.time()
    outputs = runner.generate(
        batch_input_ids=decoder_input_ids
        if is_enc_dec else batch_input_ids,
        encoder_input_ids=encoder_input_ids if is_enc_dec else None,
        encoder_input_features=encoder_input_features
        if is_enc_dec else None,
        encoder_output_lengths=encoder_output_lengths
        if is_enc_dec else None,
        max_new_tokens=args.max_output_len,
        max_attention_window_size=args.max_attention_window_size,
        sink_token_length=args.sink_token_length,
        end_id=end_id,
        pad_id=pad_id,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_beams=args.num_beams,
        num_return_sequences=args.num_return_sequences,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        repetition_penalty=args.repetition_penalty,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        stop_words_list=stop_words_list,
        bad_words_list=bad_words_list,
        output_cum_log_probs=(args.output_cum_log_probs_npy is not None),
        output_log_probs=(args.output_log_probs_npy is not None),
        random_seed=args.random_seed,
        lora_uids=args.lora_task_uids,
        prompt_table=args.prompt_table_path,
        prompt_tasks=args.prompt_tasks,
        streaming=args.streaming,
        output_sequence_lengths=True,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        return_dict=True,
        medusa_choices=args.medusa_choices,
        return_all_generated_tokens=args.return_all_generated_tokens,
        input_token_extra_ids=input_token_extra_ids)
    e_time = time.time()

    times.append(e_time - s_time)

print("[TIME] :: ", np.average(times))

hello-11 added the triaged label on Dec 10, 2024
hello-11 (Collaborator) commented

@SoundProvider what do Medusa X and Medusa O mean in your table?

yweng0828 commented

Hi @SoundProvider, thanks for your question.

Note: Increasing the batch size may have a negative impact on performance

This is because Medusa tries many different draft tokens to determine which of them can be accepted. E.g., with the default Medusa choices and BS=1, Medusa runs the generation stage with 64 tokens (63 draft tokens) instead of just 1 token as in the normal model generation phase.

Hence, if the GPU is underutilized (small batch size), Medusa can spend the free resources trying to predict draft tokens and potentially improve latency. But if the batch size is high, the GPU is fully saturated, and the overhead of verifying draft tokens becomes too large compared to the gains Medusa provides.

The batch size range in which Medusa gives an improvement depends on the model, the Medusa choices, the hardware, and the target dataset.
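
To make the trade-off concrete, here is a rough, illustrative back-of-the-envelope calculation (assuming the default 63-entry tree): each generation step has to process batch_size * (63 + 1) tokens with Medusa versus batch_size tokens without it, so the absolute per-step workload grows quickly with batch size and eventually saturates the GPU.

# Illustrative arithmetic only (assumes the default 63-draft-token tree).
NUM_DRAFT_TOKENS = 63

for batch_size in (1, 2, 4, 8):
    without_medusa = batch_size                        # 1 token per sequence
    with_medusa = batch_size * (NUM_DRAFT_TOKENS + 1)  # draft + regular token
    print(f"BS={batch_size}: {without_medusa:2d} tokens/step without Medusa, "
          f"{with_medusa:3d} tokens/step with Medusa")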

Could you please provide some information about your hardware? It seems the problem is due to limited hardware capacity.

Thanks,
Yue
