Issue with converting custom encoder model #2535

Open · 2 of 4 tasks
AvivSham opened this issue Dec 4, 2024 · 9 comments
Labels: bug (Something isn't working)

AvivSham commented Dec 4, 2024

System Info

  • Ubuntu Linux
  • GPU NVIDIA A10G
  • ENV:
    • tensorrt 10.6.0.post1
    • tensorrt-cu12 10.6.0.post1
    • tensorrt-cu12-bindings 10.6.0.post1
    • tensorrt-cu12-libs 10.6.0.post1
    • tensorrt_llm 0.16.0.dev2024111900

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

We are trying to convert, build, and run a custom encoder-decoder model; it differs from Whisper by only a single linear layer.
We followed this guide and created a custom encoder model by adding a single linear layer:

class CustomEncoder(WhisperEncoder):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.lin = Linear(in_features=1280, out_features=1280)

    def forward(self,
                input_features: Tensor,
                input_lengths=None,
                position_ids=None):
        if default_net().plugin_config.remove_input_padding:
            # BXT,D -> 1,BxT,D -> 1,D,BxT
            input_features = unsqueeze(input_features, 0)
            input_features = transpose(input_features, 1, 2)
        # Encoder conv needs to run in fp32 on Volta/Turing
        x_type = input_features.dtype
        input_features = cast(input_features, self._conv_dtype)
        x = self.conv1(input_features)
        x = gelu(x)
        x = self.conv2(x)
        x = cast(x, x_type)
        x = gelu(x)
        x = transpose(x, 2, 1)
        x = x + cast(self.position_embedding(position_ids), x.dtype)

        if default_net().plugin_config.remove_input_padding:
            #B,T,D -> BxT,D
            x = x.view([-1, self.config.hidden_size])
        hidden_states = x
        input_lengths = input_lengths // self.downsample_factor
        for encoder_layer in self.encoder_layers:
            hidden_states = encoder_layer(hidden_states,
                                          input_lengths=input_lengths)

        x = hidden_states
        x = self.lin(x)
        x = self.ln_post(x)
        x.mark_output('encoder_output', self._dtype)
        return x

We needed to incorporate a few additional changes.

  1. We also wrote a new convert_checkpoint.py. As a sanity check, we added the following lines at line 246 of the convert_checkpoint.py file in the whisper example, since the added linear layer is not included in the whisper-v3.pt file from the example:

weights.update(get_tllm_linear_weight(
    torch.rand(1280, 1280, dtype=torch.float16), "lin",
    torch.rand(1280, dtype=torch.float16), True, torch.int8))
  2. We created a new config.py file (it does not exist in the original repo) to match the new model:
from tensorrt_llm.models import PretrainedConfig


class CustomEncoderPretrainedConfig(PretrainedConfig):
    @classmethod
    def from_dict(cls, config: dict):
        # Maybe we need AutoConfig for this
        from . import NAME2ARCH
        model_cls = NAME2ARCH[config['architecture']]
        config_cls = getattr(model_cls, 'config_class', cls)
        return config_cls(**config)
  3. We added the model to the __init__.py file (a rough sketch of what this registration might look like is shown after this list).
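
For reference, here is a minimal sketch of step 3, assuming the NAME2ARCH map that config.py imports above lives in the package's __init__.py; the module names below are illustrative placeholders, not the exact upstream symbols:

# Sketch only: register the new architecture name so from_dict() can resolve it.
# The module name 'custom_encoder' is a hypothetical placeholder.
from .custom_encoder import CustomEncoder

NAME2ARCH = {
    'CustomEncoder': CustomEncoder,  # new entry for the custom model
    # ... existing architectures stay registered as before ...
}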

When trying to run the converted model, we get the following error:

[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: Input tensor 'input_features' not found; expected shape: (-1, 128) (/home/jenkins/agent/workspace/LLM/main/L0_Test-x86_64/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:484)
1       0x7f23d14fed5a tensorrt_llm::runtime::TllmRuntime::setInputTensorsImpl(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&, bool) + 1370
2       0x7f23d14ff455 tensorrt_llm::runtime::TllmRuntime::setInputTensors(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&) + 53
3       0x7f23d189b28b tensorrt_llm::batch_manager::TrtEncoderModel::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 107
4       0x7f23d189f862 tensorrt_llm::batch_manager::TrtEncoderModel::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1010
5       0x7f23d18f1217 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 183
6       0x7f23d18f750d tensorrt_llm::executor::Executor::Impl::executionLoop() + 1325
7       0x7f25d17d85c0 /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so(+0x145c0) [0x7f25d17d85c0]
8       0x7f25e7246ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f25e7246ac3]
9       0x7f25e72d7a04 clone + 68
Traceback (most recent call last):
  File "/root/aiola-asr-serving/../TensorRT-LLM/examples/whisper/run.py", line 579, in <module>
    results, total_duration = decode_wav_file(
  File "/root/aiola-asr-serving/../TensorRT-LLM/examples/whisper/run.py", line 460, in decode_wav_file
    predictions = model.process_batch(mel, features_input_lengths, text_prefix,
  File "/root/aiola-asr-serving/../TensorRT-LLM/examples/whisper/run.py", line 414, in process_batch
    outputs = self.model_runner_cpp.generate(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 732, in generate
    return self._initialize_and_fill_output(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 853, in _initialize_and_fill_output
    return self._fill_output(responses, output_ids, end_id, return_dict,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 922, in _fill_output
    raise RuntimeError(response.error_msg)
RuntimeError: Encountered an error in forwardAsync function: Input tensor 'input_features' not found; expected shape: (-1, 128) (/home/jenkins/agent/workspace/LLM/main/L0_Test-x86_64/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:484)
1       0x7f23d14fed5a tensorrt_llm::runtime::TllmRuntime::setInputTensorsImpl(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&, bool) + 1370
2       0x7f23d14ff455 tensorrt_llm::runtime::TllmRuntime::setInputTensors(int, std::unordered_map<std::string, std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::shared_ptr<tensorrt_llm::runtime::ITensor> > > > const&) + 53
3       0x7f23d189b28b tensorrt_llm::batch_manager::TrtEncoderModel::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 107
4       0x7f23d189f862 tensorrt_llm::batch_manager::TrtEncoderModel::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1010
5       0x7f23d18f1217 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 183
6       0x7f23d18f750d tensorrt_llm::executor::Executor::Impl::executionLoop() + 1325
7       0x7f25d17d85c0 /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so(+0x145c0) [0x7f25d17d85c0]
8       0x7f25e7246ac3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f25e7246ac3]
9       0x7f25e72d7a04 clone + 68
[TensorRT-LLM][INFO] Refreshed the MPI local session

However, if we make the same changes in a new class named WhisperEncoder (i.e., adding the linear layer but keeping the original class name), it works as expected!

Expected behavior

We expect the converted model to function as WhisperEncoder.

Actual behavior

See above. The model has an issue with the input dimension.

Additional notes

Some of the steps we took are not documented in the guide linked above, and it is not well explained how one should convert and run a custom model with architectural changes (not just different weight values).

AvivSham added the bug label on Dec 4, 2024

AvivSham (Author) commented:

@yuekaizhang maybe you can help here?
It relates to #2491

yuekaizhang commented:

[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: Input tensor 'input_features' not found; expected shape: (-1, 128)

Can you check the mel tensor here https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/whisper/run.py#L416 ?

However if we make the following changes in new class named WhisperEncoder (by adding the linear layer) it works as expected!

I am a little bit confused. Could you explain more?

AvivSham (Author) commented:

Can you check the mel tensor here https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/whisper/run.py#L416 ?

The shape is (1, 3000, 128), which is batch_size x seq_len x freq_features, as expected.

You can see that we named the class CustomEncoder; if we change the name to WhisperEncoder, it works like a charm, so maybe there is a hard-coded condition on the class name.

yuekaizhang commented:

@AvivSham You're right. We currently hard-code WhisperEncoder in some closed-source C++ code. Would you mind renaming CustomEncoder so that it overrides WhisperEncoder?
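
For anyone else hitting this, a minimal sketch of the suggested workaround, assuming the stock encoder can be imported under an alias (the import path below is an assumption and may differ between TensorRT-LLM versions); the custom class simply reuses the name WhisperEncoder so the hard-coded lookup in the C++ runtime still matches:

# Workaround sketch: keep the class/architecture name "WhisperEncoder".
# The import path is an assumption; adjust to wherever WhisperEncoder lives
# in your TensorRT-LLM version.
from tensorrt_llm.models.enc_dec.model import WhisperEncoder as _BaseWhisperEncoder
from tensorrt_llm.layers import Linear


class WhisperEncoder(_BaseWhisperEncoder):
    """Same name as the stock encoder, but with the extra linear layer."""

    def __init__(self, config):
        super().__init__(config)
        self.lin = Linear(in_features=1280, out_features=1280)

The 'architecture' field in the converted checkpoint's config would presumably also need to stay 'WhisperEncoder' so the runtime looks up the expected input/output tensors.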


AvivSham commented Dec 11, 2024

OK, we will do that. However, we think this should be flexible and support renaming.

We have a follow-up question. We want to pass an additional tensor to the decoder in addition to encoder_output: instead of a single output, the encoder will have two or more outputs that we want to pass to the decoder's forward function. For the discussion, let's call this tensor encoder_output_prime. What is the most natural way to do this? (We assume the C++ backend would not support it, so let's assume we use the Python backend.)


yuekaizhang commented Dec 11, 2024

OK, we will do that. However, we think this should be flexible and support renaming.

Yeah; however, there is currently no way to change it. It will be changed in the future to support more multi-modal models.

We have a follow-up question. We want to pass an additional tensor to the decoder in addition to encoder_output: instead of a single output, the encoder will have two or more outputs that we want to pass to the decoder's forward function. For the discussion, let's call this tensor encoder_output_prime. What is the most natural way to do this? (We assume the C++ backend would not support it, so let's assume we use the Python backend.)

Increasing the number of inputs/outputs for the encoder is easy. You can check https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/whisper/run.py#L188-L215. However, I think it is more complicated for the decoder. You may need to modify https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/generation.py if there is an extra input for an LLM/Whisper-based decoder.

You may also check https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/model/test_llama.py to see how to make forward() work first.
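
To make the encoder-side part concrete, here is a rough sketch following the mark_output pattern already used in the forward() above; encoder_output_prime is just the placeholder name from this discussion, and some_extra_projection is a hypothetical second branch (the decoder/runtime plumbing would still require the generation.py changes mentioned above):

# Sketch: expose a second encoder output alongside 'encoder_output'.
# 'some_extra_projection' is a hypothetical layer standing in for whatever
# produces the additional tensor.
x = self.lin(hidden_states)
x_prime = self.some_extra_projection(hidden_states)

x = self.ln_post(x)
x.mark_output('encoder_output', self._dtype)
x_prime.mark_output('encoder_output_prime', self._dtype)
return (x, x_prime)

The Python runtime session would then also need to fetch encoder_output_prime from the session outputs and pass it into the decoder step, which is the part the generation.py link above covers.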

AvivSham (Author) commented:

Thanks for the tips, @yuekaizhang.
One final question: we observe that the input to the model is flattened. For example, for batch_size=4 the input shape is (12000, 128), where the first axis equals batch_size x seq_len; however, we expected (batch_size, seq_len, mel_features) (as in the Whisper paper). In addition, we see that decoder_input_ids is flattened to a single dimension. How does decoding work in this case?

Given the questions above, how can we debug the shapes in the forward passes? During build, TRT-LLM uses dummy tensors with dynamic dims, which are not informative. Additionally, at run time we use the compiled computational graph (so we do not have access to tensor shapes).

yuekaizhang commented:

Thanks for the tips, @yuekaizhang. One final question: we observe that the input to the model is flattened. For example, for batch_size=4 the input shape is (12000, 128), where the first axis equals batch_size x seq_len; however, we expected (batch_size, seq_len, mel_features) (as in the Whisper paper). In addition, we see that decoder_input_ids is flattened to a single dimension. How does decoding work in this case?

Given the questions above, how can we debug the shapes in the forward passes? During build, TRT-LLM uses dummy tensors with dynamic dims, which are not informative. Additionally, at run time we use the compiled computational graph (so we do not have access to tensor shapes).

You may check https://nvidia.github.io/TensorRT-LLM/performance/perf-best-practices.html#remove-input-padding.

See also https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/models/enc_dec/model.py#L1964 and https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/models/enc_dec/model.py#L2005.
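
In other words, with remove_input_padding enabled, the batch dimension is packed into the sequence dimension and an input_lengths tensor tracks each sequence's length. A tiny PyTorch-only illustration of the layout difference (the numbers just mirror the example above; this is not TRT-LLM code):

import torch

batch_size, seq_len, mel_features = 4, 3000, 128

# Padded layout, as described in the Whisper paper:
padded = torch.randn(batch_size, seq_len, mel_features)       # (4, 3000, 128)

# Packed ("remove_input_padding") layout: sequences concatenated along dim 0,
# with a separate lengths tensor recording where each sequence ends.
packed = padded.reshape(batch_size * seq_len, mel_features)   # (12000, 128)
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)

decoder_input_ids is flattened the same way; the runtime uses the accompanying lengths/offsets to know where each request's tokens start and end.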

AvivSham (Author) commented:

What will be affected by disabling remove_input_padding? Will it result in increased latency? Will it affect model quality (i.e., WER)?
