Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix enc_dec bug and Make several improvements to whisper #992

Closed
wants to merge 1 commit into from

Conversation

Eddie-Wang1120
Copy link
Contributor

@Eddie-Wang1120 Eddie-Wang1120 commented Jan 28, 2024

Thanks to the brilliant work for NVIDIA team!
I made some changes to Tensorrt-LLM and hope to get some advice!

Pull Request Intro

This Pull Request include several points:

  • fix a bug in enc_dec model which will lead to a build error when the model has cross_attention and use weight_only_gemm_plugin at the same time.
  • ban layernorm plugin otherwise it will brings a severe memory usage increase for whisper fp16 inference (16000MiB to 8000MiB).
  • add int4 weight-only support to whisper
  • gives a base implementation for whisper int8_kv_cache (half-way finished due to an internal error)

What is the bug and What I do

Bug intro

The bug can be make a reproduction in previous version when add weight_only_gemm_plugin to whisper decoder model. The expected behaviour is to pass building correctly. However, when it comes to profiling in building step, errors as below will show in log, and build will ended up as failing.

[TensorRT-LLM][WARNING] Cannot profile configuration 59 (for m=0, n=3840, k=1280), reason: "[TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 60 (for m=0, n=3840, k=1280), reason: "[TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 61 (for m=0, n=3840, k=1280), reason: "Temp assertion: k must be multiple of threadblockK". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 62 (for m=0, n=3840, k=1280), reason: "Temp assertion: k must be multiple of threadblockK". Skipped
[TensorRT-LLM][WARNING] Have not found any valid GEMM config for shape (m=0, n=3840, k=1280). Will try to use default or fail at runtime

How to solve

After I rebuild the whisper decoder model (which Inherits from enc_dec DecoderModel) layer by layer, I find the error only happens when the model has a cross attention. More suspiciously, when checking the prepare_inputs function in DecoderModel, a variable called encoder_input_len_range caught my eyes, for it is a dim range be used by several special inputs for cross_attention and the min range is 0 which exactly explains why there are m=0 logs in building process.

encoder_input_len_range = [
    0, (max_encoder_input_len + 1) // 2, max_encoder_input_len
]

In my opinion, the min value of encoder_input_len_range does not have to be 0 because it is not like kv-cache which needs to be concatenate. After I change it to 1, the building process passed successfully and the results maintain correction.
Now, the enc_dec model all can use weight_only_gemm_plugin and enjoy the performance improvements freely.

About LayerNorm plugin

Banning LayerNorm plugin is always a top mission for it is going to be deprecated. A main reason why it still be retained in the previous version is because simply banning it will lead to a building failure. In this version, banning it no longer bring any errors and brings multiple benefits. Most clearly, the memory usage of whisper fp16 inference decreases from 16030MiB to 8000MiB, means the whisper can be inference by Tensorrt-LLM in more devices.

About int8_kv_cache

It's a pity that the int8_kv_cache for whisper model still not finished. The building process seems correctly. When it comes to the inference step, an internal error occurs. After I tried all ways I can imagined, it still preserved. I create an issue for this bug #993 and display detailed bug information in it. Anyone is interested and has an idea please let me know, I sincerely hopes this error can be solved at an early date, thanks you all in advance.

Performance

\ float16 (with layernorm plugin) float16 int8 weight-only int4 weight-only
GPU memory usage 16030MiB 8186MiB 6717MiB 6036MiB
RTF 0.1962 0.0542 0.0488 0.0473
processing time 94.397s 26.066s 23.492s 22.741s
batch_size 4 4 4 4
num_beams 1 1 1 1
WER 2.99 2.99 2.82 6.33

Environment

  • intel i5 13500
  • nvidia 4060ti 16G
  • Tensorrt-LLM commitID b57221b
  • container nvidia-docker run --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04

@yuekaizhang
Copy link

This is really nice work! Many thanks to you @Eddie-Wang1120. I would import this into internal gitlab and hopefully it could be done this week.

@Eddie-Wang1120
Copy link
Contributor Author

This is really nice work! Many thanks to you @Eddie-Wang1120. I would import this into internal gitlab and hopefully it could be done this week.

Thanks a lot!

@shashikg
Copy link

Thanks for the awesome contributions from you two!

Adding some of my minor observations relevant to this:

  1. Layer norm issue: +1, I observed similar behavior.

  2. bert_attention_plugin weird behavior:

    • On A30 GPU, setting network.plugin_config.set_bert_attention_plugin(dtype=args.use_bert_attention_plugin) negatively impacts performance.
    • Following the bert_attention_plugin docs, I added network.plugin_config.set_context_fmha(ContextFMHAType.enabled) to the builder config, which improved inference speed but increased memory usage.
    • Interestingly, simply disabling bert_attention_plugin achieves similar speed with lower memory usage.
    • On T4 GPU, unlike A30, disabling bert_attention_plugin outperforms using both bert_attention_plugin and context_fmha.

@Eddie-Wang1120
Copy link
Contributor Author

Thanks for your advices! @shashikg
Following your observations, I disabled bert_attention_plugin and got some results:

with bert_attention_plugin

\ float16 int8 weight-only int4 weight-only
GPU memory usage 8186MiB 6717MiB 6036MiB
RTF 0.0542 0.0488 0.0473
processing time 26.066s 23.492s 22.741s
batch_size 4 4 4
num_beams 1 1 1
WER 2.99 2.82 6.33

disable bert_attention_plugin

\ float16 int8 weight-only int4 weight-only
GPU memory usage 6065MiB 5696MiB 5024MiB
RTF 0.0464 0.0489 0.0465
processing time 22.314s 23.530s 22.379s
batch_size 4 4 4
num_beams 1 1 1
WER 2.99 2.82 6.33

The results shows that disable bert_attention_plugin indeed decrease memory usage, and may improve inference speed at some situations. Maybe we should consider using this plugin cautiously.

@symphonylyh
Copy link
Collaborator

symphonylyh commented Jan 31, 2024

@Eddie-Wang1120 great to know about the encoder_input_len_range issue when used together with weight only gemm plugin, I agree with your fix that the min value doesn't need to be 0 in all cases.

@Eddie-Wang1120 @shashikg general guidance on layernorm and bert plugin usage:

  1. For LayerNorm/RMSNorm plugin, it's in deprecation mode, so it's recommended to do without these normalization plugins
  2. For BERT plugin,
    First, it should be always used together with --enable_context_fmha, otherwise the comparison is not fair because it's using the unfused multi-head attention implementation
    Second, regarding w/ and w/o BERT plugin, we have done some investigation and 3 takeaways:
  • Peak memory usage wise, w/o BERT plugin is indeed better than w/ BERT plugin. If peak memory is a restriction, consider use w/o BERT plugin
  • Performance wise, on BERT example itself, w/o and w/ plugin paths are on par based on our benchmark. However, we mainly tested on newer GPUs such as Ampere and Hopper. It's possible on older ones like T4 you observed a different trend. In that case, it's recommended to try both on your specific GPU and decide
  • Lastly, from a practical perspective, w/o BERT plugin path has a limitation on padding removal -- that is, when you have ragged input, e.g., batch size = 2, text1 is length 10, text2 is length 100, w/ BERT plugin path can do padding removal by effectively doing a computation of length 10+100=110 text (the BERT example currently doesn't demonstrate this, which I plan to add and clarify this point), while the w/o BERT plugin path can only do computation on the padded one, so equivalently 100+100=200 text. This could make a big difference in real deployment. If this is a concern, this last point would become a deciding factor to favor the w/ plugin path.

@shashikg
Copy link

shashikg commented Jan 31, 2024

Thank you so much @symphonylyh for the guidelines!

Lastly, from a practical perspective, w/o BERT plugin path has a limitation on padding removal -- that is, when you have ragged input, e.g., batch size = 2, text1 is length 10, text2 is length 100, w/ BERT plugin path can do padding removal by effectively doing a computation of length 10+100=110 text (the BERT example currently doesn't demonstrate this, which I plan to add and clarify this point), while the w/o BERT plugin path can only do computation on the padded one, so equivalently 100+100=200 text. This could make a big difference in real deployment. If this is a concern, this last point would become a deciding factor to favor the w/ plugin path.

I see... Based on this I think now it make sense why w/ BERT plugin, performance on Whisper model does not improves (because I was running the inference on fixed 30 seconds input). So the whisper model is trained on fixed 30 seconds audios and during inference as well it expects to receive a 30 seconds audio. Even if an audio is smaller than 30 seconds and if we run the whisper's encoder on it without padding the input audio to 30 seconds, whisper's decoder falls more frequently in generating hallucinated outputs/ or repeated texts. So basically the inputs to whisper's encoder will always be of same length.

@yuekaizhang
Copy link

Thank you so much @symphonylyh for the guidelines!

Lastly, from a practical perspective, w/o BERT plugin path has a limitation on padding removal -- that is, when you have ragged input, e.g., batch size = 2, text1 is length 10, text2 is length 100, w/ BERT plugin path can do padding removal by effectively doing a computation of length 10+100=110 text (the BERT example currently doesn't demonstrate this, which I plan to add and clarify this point), while the w/o BERT plugin path can only do computation on the padded one, so equivalently 100+100=200 text. This could make a big difference in real deployment. If this is a concern, this last point would become a deciding factor to favor the w/ plugin path.

I see... Based on this I think now it make sense why w/ BERT plugin, performance on Whisper model does not improves (because I was running the inference on fixed 30 seconds input). So the whisper model is trained on fixed 30 seconds audios and during inference as well it expects to receive a 30 seconds audio. Even if an audio is smaller than 30 seconds and if we run the whisper's encoder on it without padding the input audio to 30 seconds, whisper's decoder falls more frequently in generating hallucinated outputs/ or repeated texts. So basically the inputs to whisper's encoder will always be of same length.

@shashikg We actually could remove the padding 30s restriction of encoder, see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py#L15. It would save cross kv cache VRAM usage as well. However, there is a bug now if we set conv subsampling layers in encoder with dynamic seq_len dim.

@Eddie-Wang1120
Copy link
Contributor Author

@Eddie-Wang1120 great to know about the encoder_input_len_range issue when used together with weight only gemm plugin, I agree with your fix that the min value doesn't need to be 0 in all cases.

@Eddie-Wang1120 @shashikg general guidance on layernorm and bert plugin usage:

  1. For LayerNorm/RMSNorm plugin, it's in deprecation mode, so it's recommended to do without these normalization plugins
  2. For BERT plugin,
    First, it should be always used together with --enable_context_fmha, otherwise the comparison is not fair because it's using the unfused multi-head attention implementation
    Second, regarding w/ and w/o BERT plugin, we have done some investigation and 3 takeaways:
  • Peak memory usage wise, w/o BERT plugin is indeed better than w/ BERT plugin. If peak memory is a restriction, consider use w/o BERT plugin
  • Performance wise, on BERT example itself, w/o and w/ plugin paths are on par based on our benchmark. However, we mainly tested on newer GPUs such as Ampere and Hopper. It's possible on older ones like T4 you observed a different trend. In that case, it's recommended to try both on your specific GPU and decide
  • Lastly, from a practical perspective, w/o BERT plugin path has a limitation on padding removal -- that is, when you have ragged input, e.g., batch size = 2, text1 is length 10, text2 is length 100, w/ BERT plugin path can do padding removal by effectively doing a computation of length 10+100=110 text (the BERT example currently doesn't demonstrate this, which I plan to add and clarify this point), while the w/o BERT plugin path can only do computation on the padded one, so equivalently 100+100=200 text. This could make a big difference in real deployment. If this is a concern, this last point would become a deciding factor to favor the w/ plugin path.

Thanks for the guidelines! @symphonylyh

@shashikg
Copy link

shashikg commented Feb 1, 2024

@shashikg We actually could remove the padding 30s restriction of encoder, see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py#L15. It would save cross kv cache VRAM usage as well.

Hey yes, I agree and most probably this should improve the inference time. I have tested dynamic seq_len in my project "WhisperS2T" (https://github.com/shashikg/WhisperS2T/blob/main/whisper_s2t/backends/__init__.py#L35) with CTranslate2 backend but currently it's in experimental phase (so can break thus not included in docs).

So my concern is not in whether we can run it or not. If we infer with dynamic seq len , what I observed is that whisper's decoder makes more error in generated text output (mostly non-stopping repeated text tokens). Definitely there are various heuristics we can use to work around. But after adding those heuristics inference time will increase. Moreover specifically these non-stopping repeated tokens also increase the generation time significantly. Definitely this issue can be avoided by fine-tuning the model on dynamic seq_len which openai didn't do for some reason.

However, there is a bug now if we set conv subsampling layers in encoder with dynamic seq_len dim.

I am curious what's the exact issue, normally the patch should work. I have tried out a similar thing in past. One issue I can think of is because of detect_language function, check this: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L51 -- this check will create issue if you use detect language function with dynamic seq len.

@yuekaizhang
Copy link

yuekaizhang commented Feb 1, 2024

So my concern is not in whether we can run it or not. If we infer with dynamic seq len , what I observed is that whisper's decoder makes more error in generated text output (mostly non-stopping repeated text tokens). Definitely there are various heuristics we can use to work around. But after adding those heuristics inference time will increase. Moreover specifically these non-stopping repeated tokens also increase the generation time significantly. Definitely this issue can be avoided by fine-tuning the model on dynamic seq_len which openai didn't do for some reason.

Yes, one of the heuristics is to pad 50 frames at the end. k2-fsa/sherpa-onnx#471

I am curious what's the exact issue, normally the patch should work. I have tried out a similar thing in past. One issue I can think of is because of detect_language function, check this: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L51 -- this check will create issue if you use detect language function with dynamic seq len.

https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/functional.py#L2813-L2814 This view operation has some issue. I think it should be a small fix to handle it.

@symphonylyh
Copy link
Collaborator

symphonylyh commented Feb 1, 2024

@yuekaizhang I have less background on the Whisper discussion here, but do you mean the current functional.py::conv2d() cannot handle dynamic axes due to the output.view(concat([output.size(1), output.size(2), output.size(3)])) call?

If I understand correctly, this call is doing a squeeze call to remove the 1st dimension, as symmetric to the unsqueeze(input) call before. In that case, do you think select(output, dim=0, index=0) can do the same and meanwhile support dynamic axis?

Update: please use more general squeeze implementation for now, add to functional.py

def squeeze(input: Tensor, dim: Union[int, Sequence[int]] = None):
    if dim is None:
        dim = list(range(input.ndim()))

    if isinstance(dim, int):
        dim = (dim, )

    new_shape = []
    for i, s in enumerate(input.shape):
        if s == 1 and i in dim:
            continue
        new_shape.append(shape(input, i))

    input = input.view(concat(new_shape))
    return input

@yuekaizhang
Copy link

@yuekaizhang I have less background on the Whisper discussion here, but do you mean the current functional.py::conv2d() cannot handle dynamic axes due to the output.view(concat([output.size(1), output.size(2), output.size(3)])) call?

If I understand correctly, this call is doing a squeeze call to remove the 1st dimension, as symmetric to the unsqueeze(input) call before. In that case, do you think select(output, dim=0, index=0) can do the same and meanwhile support dynamic axis?

Thanks, I would try your suggestion and give feedback to you. @shashikg @symphonylyh

@yuekaizhang
Copy link

Added a data point using A16 GPU.
Batch_size 4, num_beam 1

FP16 Weight-only-quant int8
35 secs Decoding Time 33 secs Decoding Time
2.48% Word Error Rate 2.48% Word Error Rate
5.5 GB VRAM 4 GB VRAM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants