
[bug] forwardAsync assertion failed #2494

Open
akhoroshev opened this issue Nov 25, 2024 · 6 comments
Labels: Generic Runtime, triaged (Issue has been triaged by maintainers)

akhoroshev (Contributor) commented Nov 25, 2024

My version

Assertion fails under load

[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] Assertion failed: Input length (6973) + max new tokens (4095) + draft tokens (0) must be less than max sequence length (8192). (/sources/contrib/tensorrt-llm/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp:444)
1       0x7fa8df465992 tensorrt_llm::common::throwRuntimeError(char const*, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 78
2       0x7fa8df66b693 tensorrt_llm::runtime::GptDecoderBatched::newRequest(int, tensorrt_llm::runtime::decoder_batch::Request const&, tensorrt_llm::runtime::SamplingConfig const&) + 4307
3       0x7fa8df66d7cc tensorrt_llm::runtime::GptDecoderBatched::newRequests(std::vector<int, std::allocator<int> > const&, std::vector<tensorrt_llm::runtime::decoder_batch::Request, std::allocator<tensorrt_llm::runtime::decoder_batch::Request> > const&, std::vector<tensorrt_llm::runtime::SamplingConfig, std::allocator<tensorrt_llm::runtime::SamplingConfig> > const&) + 172
4       0x7fa8e15f93c5 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::setupDecoderStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 725
5       0x7fa8e15fbb90 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 3792
6       0x7fa8e1625a71 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 353
7       0x7fa8e162a97f tensorrt_llm::executor::Executor::Impl::executionLoop() + 895
8       0x7fa8bafaba80 /opt/wmcore/lib/libtensorrt_llm_nvrtc_wrapper.so(+0x32c5a80) [0x7fa8bafaba80]
9       0x7fa8720d01ca /lib64/libpthread.so.0(+0x81ca) [0x7fa8720d01ca]
10      0x7fa87140de73 clone + 67
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 256
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 256
[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1
[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 8192
[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (8192) * 28
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1
[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 4096
[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 8191  = maxSequenceLen - 1 since chunked context is enabled

I don't know how this is possible, because:

  1. for all my requests input_length <= 7168
  2. for all my requests max_new_tokens=min(4096, 8192 - input_length)

Moreover, the Executor additionally checks this invariant.

My only idea is that tensorrt_llm::batch_manager::TrtGptModelInflightBatching::setupDecoderStep sets a wrong max_new_tokens on the decoder_batch::Request under certain conditions.
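(A minimal standalone sketch of the arithmetic involved, built from the log values and items 1-2 above; the constant names and the clamp helper are illustrative, not TensorRT-LLM code, and the check is written as non-strict because the exact comparison used by the assertion is not visible in the log.)

// Standalone sketch, not TensorRT-LLM code. It combines the reporter's clamp
// "max_new_tokens = min(4096, 8192 - input_length)" with the check described
// by the assertion message. kMaxSeqLen, kMaxNewTokensCap and clampMaxNewTokens
// are illustrative names.
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int32_t kMaxSeqLen = 8192;       // TRTGptModel maxSequenceLen from the log
constexpr int32_t kMaxNewTokensCap = 4096; // cap used for max_new_tokens

int32_t clampMaxNewTokens(int32_t inputLen)
{
    return std::min(kMaxNewTokensCap, kMaxSeqLen - inputLen);
}

int main()
{
    // Requests built with the clamp never exceed maxSequenceLen...
    for (int32_t inputLen : {4097, 6973, 7168})
    {
        int32_t maxNewTokens = clampMaxNewTokens(inputLen);
        std::printf("input=%d new=%d sum=%d fits=%s\n", inputLen, maxNewTokens,
                    inputLen + maxNewTokens,
                    inputLen + maxNewTokens <= kMaxSeqLen ? "yes" : "no");
    }
    // ...while the pair printed by the failed assertion clearly exceeds it:
    std::printf("asserted pair: 6973 + 4095 = %d > %d\n", 6973 + 4095, kMaxSeqLen);
    return 0;
}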

hello-11 added the triaged (Issue has been triaged by maintainers) and runtime labels on Nov 25, 2024
nekorobov (Collaborator) commented:

Hi @akhoroshev, thank you for taking the time to report the issue. From just looking at the code, the logic seems correct to me: I see no way max_new_tokens could be equal to 4095. The check in GenericLlmRequest::validate is called only via the Executor API; the old GptManager API does not call it.

Could you share a reproducer, please?

nekorobov self-assigned this on Nov 25, 2024
akhoroshev (Contributor, Author) commented Nov 25, 2024

@nekorobov

From just looking at the code, the logic seems correct to me: I see no way max_new_tokens could be equal to 4095

It happens under load; for example, it's possible to have two (or more) requests like these:

  1. input_length=4097, max_new_tokens=4095
  2. input_length=6973, max_new_tokens=1219

They are both valid (GenericLlmRequest::validate was called, since I use the Executor API).

But the assertion still fails.
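(An illustrative sketch of the mismatch, not TensorRT-LLM code: with a non-strict check both example requests fit on their own, while the values printed by the failed assertion, input length 6973 and max new tokens 4095, look like request 2's input length paired with request 1's max_new_tokens, i.e. the kind of cross-request mix-up hypothesized in the opening comment.)

// Sketch only, not TensorRT-LLM code: the two example requests each pass a
// per-request check against maxSequenceLen, but pairing request 2's input
// length with request 1's max_new_tokens reproduces exactly the numbers from
// the failed assertion.
#include <cstdint>
#include <cstdio>

constexpr int32_t kMaxSeqLen = 8192;

struct Req
{
    int32_t inputLen;
    int32_t maxNewTokens;
};

bool fits(int32_t inputLen, int32_t maxNewTokens)
{
    return inputLen + maxNewTokens <= kMaxSeqLen;
}

int main()
{
    Req r1{4097, 4095}; // request 1 from the comment above
    Req r2{6973, 1219}; // request 2 from the comment above

    std::printf("r1 fits: %s\n", fits(r1.inputLen, r1.maxNewTokens) ? "yes" : "no"); // yes
    std::printf("r2 fits: %s\n", fits(r2.inputLen, r2.maxNewTokens) ? "yes" : "no"); // yes

    // Hypothetical mix-up: request 2's input length with request 1's max_new_tokens.
    std::printf("mixed pair fits: %s (6973 + 4095 = %d)\n",
                fits(r2.inputLen, r1.maxNewTokens) ? "yes" : "no",
                r2.inputLen + r1.maxNewTokens); // no, 11068 > 8192
    return 0;
}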

akhoroshev (Contributor, Author) commented:

Could you share a reproducer, please?

I can't because it's a closed model.

akhoroshev (Contributor, Author) commented:

Hi! Any updates here?

akhoroshev (Contributor, Author) commented:

@nekorobov @nv-guomingz

TriLoo commented Dec 10, 2024

meet "Encountered an error in forwardAsync function: std::bad_cast" error when running BERT/Roberta,
install tensorrt-llm from source code

  • commit id: 340a1b6
  • GPU: A100
  • CUDA-12.4

File "/home/aiscuser/.conda/envs/py10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/home/aiscuser/.conda/envs/py10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/home/aiscuser/.conda/envs/py10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/home/aiscuser/DistillLLM/inference/tensorrt_llm/timing_rc2.py", line 137, in timing outputs = runner.generate( File "/home/aiscuser/.conda/envs/py10/lib/python3.10/site-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 747, in generate return self._initialize_and_fill_output( File "/home/aiscuser/.conda/envs/py10/lib/python3.10/site-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 886, in _initialize_and_fill_output return self._fill_output( File "/home/aiscuser/.conda/envs/py10/lib/python3.10/site-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 990, in _fill_output raise RuntimeError(response.error_msg) RuntimeError: Encountered an error in forwardAsync function: std::bad_cast
