
Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] CUDA runtime error in cudaMemcpyAsync(dst, src.data(), src.getSizeInBytes(), cudaMemcpyDefault, mStream->get()): invalid argument #2358

Open
zhaocc1106 opened this issue Oct 20, 2024 · 8 comments
Labels: bug (Something isn't working), triaged (Issue has been triaged by maintainers)

Comments

@zhaocc1106

ENV:

CPU: Intel(R) Xeon(R) Gold 6230 CPU @ 2.10GHz
GPU: 2080Ti * 4
tensorrt-llm: 0.12.0
tensorrt: 10.3.0

ISSUE:
I use the C++ API in "tensorrt_llm/batch_manager/" to deploy a multi-modal LLM. I build the TensorRT-LLM engine with --tp 4 and deploy the service on 4 GPUs. If the first request is an image request, the following error occurs.

[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] CUDA runtime error in cudaMemcpyAsync(dst, src.data(), src.getSizeInBytes(), cudaMemcpyDefault, mStream->get()): invalid argument (/workspace/tensorrt_llm/cpp/tensorrt_llm/runtime/bufferManager.cpp:146)
1       0x7fbc549104df void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 143
2       0x7fbbec78cce8 tensorrt_llm::batch_manager::PromptTuningBuffers::fill(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, tensorrt_llm::runtime::BufferManager const&, bool) + 1320
3       0x7fbbec791eb7 tensorrt_llm::batch_manager::RuntimeBuffers::setFromInputs(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 9223
4       0x7fbbec794ab4 tensorrt_llm::batch_manager::RuntimeBuffers::prepareStep[abi:cxx11](std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 276
5       0x7fbbec7b61df tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 191
6       0x7fbbec7b63d7 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 295
7       0x7fbbec7b7237 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 2439
8       0x7fbbec7d8b83 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 419
9       0x7fbbec7dd8df tensorrt_llm::executor::Executor::Impl::executionLoop() + 975
10      0x7fbc516b0253 /lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7fbc516b0253]
11      0x7fbc5143fac3 /lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7fbc5143fac3]
12      0x7fbc514d1850 /lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7fbc514d1850]
[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] CUDA runtime error in cudaMemcpyAsync(dst, src.data(), src.getSizeInBytes(), cudaMemcpyDefault, mStream->get()): invalid argument (/workspace/tensorrt_llm/cpp/tensorrt_llm/runtime/bufferManager.cpp:146)
1       0x7f3b1f5d14df void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 143
2       0x7f3ab738cce8 tensorrt_llm::batch_manager::PromptTuningBuffers::fill(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, tensorrt_llm::runtime::BufferManager const&, bool) + 1320
3       0x7f3ab7391eb7 tensorrt_llm::batch_manager::RuntimeBuffers::setFromInputs(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 9223
4       0x7f3ab7394ab4 tensorrt_llm::batch_manager::RuntimeBuffers::prepareStep[abi:cxx11](std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 276
5       0x7f3ab73b61df tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 191
6       0x7f3ab73b63d7 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 295
7       0x7f3ab73b7237 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 2439
8       0x7f3ab73d8b83 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 419
9       0x7f3ab73dd8df tensorrt_llm::executor::Executor::Impl::executionLoop() + 975
10      0x7f3b1c4b0253 /lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7f3b1c4b0253]
11      0x7f3b1c21fac3 /lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f3b1c21fac3]
12      0x7f3b1c2b1850 /lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f3b1c2b1850]
[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] CUDA runtime error in cudaMemcpyAsync(dst, src.data(), src.getSizeInBytes(), cudaMemcpyDefault, mStream->get()): invalid argument (/workspace/tensorrt_llm/cpp/tensorrt_llm/runtime/bufferManager.cpp:146)
1       0x7f9fcea024df void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 143
2       0x7f9f6678cce8 tensorrt_llm::batch_manager::PromptTuningBuffers::fill(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, tensorrt_llm::runtime::BufferManager const&, bool) + 1320
3       0x7f9f66791eb7 tensorrt_llm::batch_manager::RuntimeBuffers::setFromInputs(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 9223
4       0x7f9f66794ab4 tensorrt_llm::batch_manager::RuntimeBuffers::prepareStep[abi:cxx11](std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 276
5       0x7f9f667b61df tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 191
6       0x7f9f667b63d7 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 295
7       0x7f9f667b7237 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 2439
8       0x7f9f667d8b83 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 419
9       0x7f9f667dd8df tensorrt_llm::executor::Executor::Impl::executionLoop() + 975
10      0x7f9fcb8b0253 /lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7f9fcb8b0253]
11      0x7f9fcb63fac3 /lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7f9fcb63fac3]
12      0x7f9fcb6d1850 /lib/x86_64-linux-gnu/libc.so.6(+0x126850) [0x7f9fcb6d1850]

But if the first request has no image, subsequent image requests work fine. Likewise, if I deploy on a single GPU, the first image request also works. An "image request" means the request carries a prompt-tuning table input, built with code like this:

static std::optional<executor::PromptTuningConfig> BuildPromptTuningForImages(const std::vector<std::string>& img_urls,
                                                                              VIT* vit,
                                                                              std::string& prompt) {
  if (img_urls.empty()) {
    return std::nullopt;
  }
  if (vit == nullptr) {
    throw std::invalid_argument("There is no vit model.");
  }
  std::shared_ptr<NamedTensor> vit_embeddings = vit->Encode(img_urls, prompt); // The NamedTensor holds the image p-tuning embeddings in GPU memory.
  if (vit_embeddings == nullptr) {
    throw std::runtime_error("Encode image to embeddings failed.");
  }
  auto e_tensor = executor::detail::ofITensor(vit_embeddings->tensor);
  // CLOG4(INFO, "e_tensor shape: [" << e_tensor.getShape()[0] << ", " << e_tensor.getShape()[1]
  //                                 << "], dtype: " << int(e_tensor.getDataType()) << ", size: " <<
  //                                 e_tensor.getSize());
  return executor::PromptTuningConfig(std::move(e_tensor));
}

auto p_tuning_config = BuildPromptTuningForImages(image_urls, vit, prompt);
executor::Request{***, p_tuning_config, ***};
@Superjomn added the bug and triaged labels on Oct 20, 2024
@MartinMarciniszyn
Collaborator

@symphonylyh, could you please take a look at this?

@symphonylyh
Collaborator

Hi @zhaocc1106, where is this BuildPromptTuningForImages call? Did you implement it yourself?
In parallel, we currently plan to add end-to-end executor support for multimodal models. Once that is done, I think your case should work fine as well.

@zhaocc1106
Author

zhaocc1106 commented Oct 21, 2024

Hi @zhaocc1106, where is this BuildPromptTuningForImages call? Did you implement it yourself?
In parallel, we currently plan to add end-to-end executor support for multimodal models. Once that is done, I think your case should work fine as well.

Yes, BuildPromptTuningForImages is my own function. I use the C++ API.

@akhoroshev
Contributor

akhoroshev commented Oct 23, 2024

I encountered the same problem:

[TensorRT-LLM][ERROR] tensorrt_llm::common::TllmException: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaFreeHost(ptr): unspecified launch failure (/home/askhoroshev/tensorrt-llm/cpp/tensorrt_llm/runtime/tllmBuffers.h:177)
1       0x7feb8af5e962 /home/askhoroshev/tensorrt-llm/cpp/build/tensorrt_llm/libtensorrt_llm.so(+0x70b962) [0x7feb8af5e962]
2       0x7feb8cc99fad virtual thunk to tensorrt_llm::runtime::GenericTensor<tensorrt_llm::runtime::PinnedAllocator>::~GenericTensor() + 125
3       0x7feb8d1453be tensorrt_llm::batch_manager::PromptTuningBuffers::fill(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, tensorrt_llm::runtime::BufferManager const&, bool) + 3758
4       0x7feb8d14d78f tensorrt_llm::batch_manager::RuntimeBuffers::setFromInputs(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 8063
5       0x7feb8d14e052 tensorrt_llm::batch_manager::RuntimeBuffers::prepareStep[abi:cxx11](std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int, int, tensorrt_llm::batch_manager::DecoderBuffers&, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager*, tensorrt_llm::batch_manager::rnn_state_manager::RnnStateManager*, std::map<unsigned long, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > >, std::less<unsigned long>, std::allocator<std::pair<unsigned long const, std::shared_ptr<std::vector<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig, std::allocator<tensorrt_llm::runtime::LoraCache::TaskLayerModuleConfig> > > > > > const&, tensorrt_llm::runtime::TllmRuntime const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&) + 178
6       0x7feb8d16f6d4 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 164
7       0x7feb8d16f8de tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 222
8       0x7feb8d17003c tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1788
9       0x7feb8d19c231 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::__cxx11::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 353
10      0x7feb8d1a113f tensorrt_llm::executor::Executor::Impl::executionLoop() + 895
11      0x7feb772d1a80 /home/askhoroshev/tensorrt-llm/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/libtensorrt_llm_nvrtc_wrapper.so(+0x32c5a80) [0x7feb772d1a80]
12      0x7feb2e2941ca /lib64/libpthread.so.0(+0x81ca) [0x7feb2e2941ca]
13      0x7feb2d5c08d3 clone + 67

This is a TP 4 LLaMA-like model using the executor API.

I'm passing a pinned tensor as the embeddings. If I pass a kCPU tensor instead, everything works fine.

I guess a synchronization point is missing in the batch_manager code, because a transfer from kCPU to kGPU implies an implicit sync, while a transfer from kPINNED to kGPU (and from kGPU to kGPU) doesn't.
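
For reference, a minimal CUDA sketch (plain runtime API, not TensorRT-LLM code) illustrating the difference: cudaMemcpyAsync from pageable host memory only returns after the source has been staged, so the host buffer can be reused right away, whereas a copy from pinned (or device) memory is truly asynchronous and needs an explicit synchronization before the source is freed or reused.

#include <cuda_runtime.h>
#include <vector>

void copy_sync_behavior(cudaStream_t stream) {
  constexpr size_t n = 1 << 20;
  float* dst = nullptr;
  cudaMalloc(&dst, n * sizeof(float));

  // Pageable host source: the async copy is staged internally, so the call
  // behaves synchronously with respect to the host source buffer.
  std::vector<float> pageable(n, 1.0f);
  cudaMemcpyAsync(dst, pageable.data(), n * sizeof(float), cudaMemcpyDefault, stream);
  // 'pageable' may be modified or destroyed here without corrupting the copy.

  // Pinned host source: the copy is truly asynchronous; the buffer must stay
  // valid and unchanged until the stream has consumed it.
  float* pinned = nullptr;
  cudaMallocHost(&pinned, n * sizeof(float));
  cudaMemcpyAsync(dst, pinned, n * sizeof(float), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);  // explicit sync point before freeing the source
  cudaFreeHost(pinned);

  cudaFree(dst);
}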

@MartinMarciniszyn @symphonylyh

@akhoroshev
Contributor

akhoroshev commented Oct 23, 2024

@zhaocc1106 try passing a kCPU tensor instead of a kGPU one as a workaround
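
A minimal sketch of that workaround (untested; it assumes the executor::Tensor::cpu factory, getData(), getSizeInBytes() and DataType::kFP16 are available as in recent TensorRT-LLM releases, and that the ViT embeddings are FP16 - adjust to your actual API and dtype):

#include <cuda_runtime.h>
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

// embeddings_dev: device pointer produced by the ViT engine
// rows, hidden: prompt-table shape [rows, hidden]
tle::PromptTuningConfig BuildCpuPromptTable(void const* embeddings_dev, int64_t rows, int64_t hidden) {
  auto host = tle::Tensor::cpu(tle::DataType::kFP16, {rows, hidden});
  // Synchronous D2H copy; the executor then uploads from a kCPU tensor,
  // which gives the implicit sync point discussed above.
  cudaMemcpy(host.getData(), embeddings_dev, host.getSizeInBytes(), cudaMemcpyDeviceToHost);
  return tle::PromptTuningConfig(std::move(host));
}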

@zhaocc1106
Author

@zhaocc1106 try passing a kCPU tensor instead of a kGPU one as a workaround

But my vit_embedding is the output of TensorRT. It's in GPU device memory, and I copy it into TRT-LLM's GPU memory with a D2D copy. Copying to the CPU would waste time.

It's strange that if the first request has no image, subsequent image requests work fine.
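
If the missing-synchronization hypothesis above is right, one possible alternative to the CPU copy (untested, a sketch only) would be to keep the D2D path but block the host on the ViT stream before the device tensor is handed to the executor, so the prompt table is fully written when the batch manager copies it. Here vit_stream is assumed to be the stream the ViT engine and the D2D copy run on:

#include <cuda_runtime.h>

// Force completion of the ViT embedding copy before enqueueing the request.
void SyncBeforeEnqueue(cudaStream_t vit_stream) {
  cudaEvent_t done;
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
  cudaEventRecord(done, vit_stream);  // mark the end of the D2D embedding copy
  cudaEventSynchronize(done);         // host waits until it has finished
  cudaEventDestroy(done);
}
// ...then build executor::PromptTuningConfig from the device tensor as before.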

@akhoroshev
Contributor

akhoroshev commented Oct 23, 2024

But my vit_embedding is the output of TensorRT. It's in GPU device memory, and I copy it into TRT-LLM's GPU memory with a D2D copy. Copying to the CPU would waste time.

I know, but it works for me :)

@zhaocc1106
Author

But my vit_embedding is the output of TensorRT. It's in GPU device memory, and I copy it into TRT-LLM's GPU memory with a D2D copy. Copying to the CPU would waste time.

I know, but it works for me :)

Thanks, I will try.
