diff --git a/README.md b/README.md
index 8f581897b..c2ab3f3bc 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ TensorRT-LLM
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.5.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-0.14.0.dev-green)](./tensorrt_llm/version.py)
+[![version](https://img.shields.io/badge/release-0.15.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
@@ -17,12 +17,15 @@ TensorRT-LLM
## Latest News
-* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
-[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
+* [2024/10/07] 🚀🚀🚀Optimizing Microsoft Bing Visual Search with NVIDIA Accelerated Libraries
+[➡️ link](https://developer.nvidia.com/blog/optimizing-microsoft-bing-visual-search-with-nvidia-accelerated-libraries/)
-
+
+* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
+[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
+
* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)
diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp
index b901a17bc..585579755 100644
--- a/benchmarks/cpp/gptManagerBenchmark.cpp
+++ b/benchmarks/cpp/gptManagerBenchmark.cpp
@@ -426,6 +426,7 @@ class Recorder
void initialize()
{
mStart = std::chrono::steady_clock::now();
+ mRequestsQueueingLatencies.clear();
}
void finalize()
@@ -433,6 +434,11 @@ class Recorder
mEnd = std::chrono::steady_clock::now();
}
+ void recordQueueLatency(std::vector<float> const& latencies)
+ {
+ mRequestsQueueingLatencies.insert(mRequestsQueueingLatencies.end(), latencies.begin(), latencies.end());
+ }
+
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
auto const inputLength = request->getInputIds()->getSize();
@@ -677,6 +683,16 @@ class Recorder
mMaxGenT2TLatency = genT2TLatencies.back();
mMinGenT2TLatency = genT2TLatencies.front();
}
+
+ mAvgReqQueueingLatency
+ = std::accumulate(mRequestsQueueingLatencies.begin(), mRequestsQueueingLatencies.end(), 0.F)
+ / mRequestsQueueingLatencies.size();
+ std::sort(mRequestsQueueingLatencies.begin(), mRequestsQueueingLatencies.end());
+ mP99ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 99);
+ mP90ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 90);
+ mP50ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 50);
+ mMaxReqQueueingLatency = mRequestsQueueingLatencies.back();
+ mMinReqQueueingLatency = mRequestsQueueingLatencies.front();
}
}
@@ -713,6 +729,13 @@ class Recorder
printf("[BENCHMARK] p99_inter_token_latency(ms) %.2f\n", mP99GenT2TLatency);
printf("[BENCHMARK] p90_inter_token_latency(ms) %.2f\n", mP90GenT2TLatency);
printf("[BENCHMARK] p50_inter_token_latency(ms) %.2f\n\n", mP50GenT2TLatency);
+
+ printf("[BENCHMARK] avg_request_queueing_latency(ms) %.2f\n", mAvgReqQueueingLatency);
+ printf("[BENCHMARK] max_request_queueing_latency(ms) %.2f\n", mMaxReqQueueingLatency);
+ printf("[BENCHMARK] min_request_queueing_latency(ms) %.2f\n", mMinReqQueueingLatency);
+ printf("[BENCHMARK] p99_request_queueing_latency(ms) %.2f\n", mP99ReqQueueingLatency);
+ printf("[BENCHMARK] p90_request_queueing_latency(ms) %.2f\n", mP90ReqQueueingLatency);
+ printf("[BENCHMARK] p50_request_queueing_latency(ms) %.2f\n\n", mP50ReqQueueingLatency);
}
}
@@ -820,6 +843,13 @@ class Recorder
float mP50GenT2TLatency{};
float mMaxGenT2TLatency{};
float mMinGenT2TLatency{};
+ float mAvgReqQueueingLatency{};
+ float mP99ReqQueueingLatency{};
+ float mP90ReqQueueingLatency{};
+ float mP50ReqQueueingLatency{};
+ float mMaxReqQueueingLatency{};
+ float mMinReqQueueingLatency{};
+ std::vector<float> mRequestsQueueingLatencies{};
std::string mOpCsvFile;
bool mStreaming;
@@ -846,6 +876,7 @@ class ExecutorServer
, mActiveCount(0)
, mNumFinished(0)
, mShutdown(false)
+ , mLogIterationData(logIterationData)
{
texec::SchedulerConfig schedulerConfig(capacitySchedulerPolicy);
@@ -899,7 +930,9 @@ class ExecutorServer
TLLM_LOG_ERROR("not a supported executor model type in executor server.");
}
- if (logIterationData)
+ auto const& world = tensorrt_llm::mpi::MpiComm::world();
+ auto worldRank = world.getRank();
+ if (worldRank == 0)
{
mCollectStatsThread = std::thread(&ExecutorServer::collectStats, this);
}
@@ -988,7 +1021,18 @@ class ExecutorServer
auto iterStats = mExecutor->getLatestIterationStats();
for (auto const& iterStat : iterStats)
{
- TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat));
+ SizeType32 numNewActiveRequests = iterStat.numNewActiveRequests;
+ if (numNewActiveRequests > 0)
+ {
+ float avgQueueingTime
+ = static_cast<float>(iterStat.newActiveRequestsQueueLatencyMS / numNewActiveRequests);
+ std::vector<float> requestsQueueLatencyMS(numNewActiveRequests, avgQueueingTime);
+ mRecorder->recordQueueLatency(requestsQueueLatencyMS);
+ }
+ if (mLogIterationData)
+ {
+ TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat));
+ }
}
auto const waitSleep = std::chrono::milliseconds(50);
std::this_thread::sleep_for(waitSleep);
@@ -1005,6 +1049,7 @@ class ExecutorServer
std::atomic mActiveCount;
std::atomic mNumFinished;
std::atomic mShutdown;
+ bool mLogIterationData;
}; // class ExecutorServer
class GptServer
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 475970b7b..a323a15f9 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -201,6 +201,7 @@ class GenericLlmRequest
, mDecodingIter(0)
, mPriority(req.getPriority())
, mFinishReasons(mSamplingConfig.beamWidth)
+ , mEncoderInputFeatures(std::nullopt)
, mEncoderOutputLength(req.getEncoderOutputLength())
, mContextPhaseParams(req.getContextPhaseParams())
, mInputTokenExtraIds(std::nullopt)
@@ -263,7 +264,8 @@ class GenericLlmRequest
auto pTuningConfig = req.getPromptTuningConfig();
if (pTuningConfig)
{
- mPromptEmbeddingTable = executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable());
+ mPromptEmbeddingTable = tensorrt_llm::runtime::ITensor::view(
+ executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable()));
TLLM_CHECK(mPromptEmbeddingTable.value()->getShape().nbDims == 2);
mPromptVocabSize = mPromptEmbeddingTable.value()->getShape().d[0];
mPromptEmbeddingTable.value()->unsqueeze(0);
@@ -1438,6 +1440,36 @@ class GenericLlmRequest
0.0, std::chrono::duration<double, std::milli>(mKvCacheTransferEnd - mKvCacheTransferStart).count());
}
+ void updateAllocTotalBlocksPerRequest(SizeType32 allocTotalBlocksPerRequest)
+ {
+ mAllocTotalBlocksPerRequest += allocTotalBlocksPerRequest;
+ }
+
+ [[nodiscard]] SizeType32 getAllocTotalBlocksPerRequest() const
+ {
+ return mAllocTotalBlocksPerRequest;
+ }
+
+ void updateAllocNewBlocksPerRequest(SizeType32 allocNewBlocksPerRequest)
+ {
+ mAllocNewBlocksPerRequest += allocNewBlocksPerRequest;
+ }
+
+ [[nodiscard]] SizeType32 getAllocNewBlocksPerRequest() const
+ {
+ return mAllocNewBlocksPerRequest;
+ }
+
+ void updateReusedBlocksPerRequest(SizeType32 reusedBlocksPerRequest)
+ {
+ mReusedBlocksPerRequest += reusedBlocksPerRequest;
+ }
+
+ [[nodiscard]] SizeType32 getReusedBlocksPerRequest() const
+ {
+ return mReusedBlocksPerRequest;
+ }
+
RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
@@ -1545,6 +1577,10 @@ class GenericLlmRequest
std::chrono::time_point<std::chrono::steady_clock> mKvCacheTransferStart;
std::chrono::time_point<std::chrono::steady_clock> mKvCacheTransferEnd;
+ SizeType32 mAllocTotalBlocksPerRequest{0};
+ SizeType32 mAllocNewBlocksPerRequest{0};
+ SizeType32 mReusedBlocksPerRequest{0};
+
private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h
index 5a8525caf..c9ff1e099 100644
--- a/cpp/include/tensorrt_llm/executor/types.h
+++ b/cpp/include/tensorrt_llm/executor/types.h
@@ -297,6 +297,8 @@ struct IterationStats
double iterLatencyMS;
/// @brief The total time spent in queue by the requests that became active in this iteration (ms)
double newActiveRequestsQueueLatencyMS;
+ /// @brief Number of new fetched active requests
+ SizeType32 numNewActiveRequests;
/// @brief Number of active requests
SizeType32 numActiveRequests;
/// @brief Number of queued requests
@@ -364,6 +366,12 @@ struct RequestStats
bool paused;
/// @brief Stats specific to disaggregated serving
std::optional<DisServingRequestStats> disServingStats;
+ /// @brief Number of total allocated blocks per request
+ SizeType32 allocTotalBlocksPerRequest;
+ /// @brief Number of newly allocated blocks per request
+ SizeType32 allocNewBlocksPerRequest;
+ /// @brief Number of reused blocks per request
+ SizeType32 reusedBlocksPerRequest;
};
/// @brief Struct that holds the stats of all requests in an iteration
diff --git a/cpp/include/tensorrt_llm/runtime/gptSession.h b/cpp/include/tensorrt_llm/runtime/gptSession.h
index 46cd19902..a4b8e4cc3 100644
--- a/cpp/include/tensorrt_llm/runtime/gptSession.h
+++ b/cpp/include/tensorrt_llm/runtime/gptSession.h
@@ -115,7 +115,6 @@ class [[deprecated("Use the executor API instead.")]] GptSession
std::optional<SizeType32> genMicroBatchSize = std::nullopt;
std::optional<executor::DecodingMode> decodingMode = std::nullopt;
bool normalizeLogProbs = true;
- std::optional<std::filesystem::path> enginePath;
};
//! @brief Optional profiler class to profile the generation phase of an inference request
diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h
index b1b495e75..ce8985b56 100644
--- a/cpp/include/tensorrt_llm/runtime/modelConfig.h
+++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -127,6 +127,7 @@ class ModelConfig
, mContextFMHA(false)
, mPagedContextFMHA(false)
, mUseXQA{false}
+ , mPpReduceScatter{false}
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mUseCrossAttention(false)
@@ -468,6 +469,16 @@ class ModelConfig
return mUseXQA;
}
+ void constexpr setPpReduceScatter(bool ppReduceScatter) noexcept
+ {
+ mPpReduceScatter = ppReduceScatter;
+ }
+
+ [[nodiscard]] bool constexpr getPpReduceScatter() const noexcept
+ {
+ return mPpReduceScatter;
+ }
+
[[nodiscard]] bool constexpr useLoraPlugin() const noexcept
{
return mUseLoraPlugin;
@@ -759,6 +770,7 @@ class ModelConfig
bool mContextFMHA;
bool mPagedContextFMHA;
bool mUseXQA;
+ bool mPpReduceScatter;
bool mUseLoraPlugin;
std::vector<LoraModule> mLoraModules;
diff --git a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
index e739e8188..3b396122b 100644
--- a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
+++ b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
@@ -50,6 +50,11 @@ class SpeculativeDecodingMode
return SpeculativeDecodingMode{kExplicitDraftTokens};
}
+ static auto constexpr Eagle()
+ {
+ return SpeculativeDecodingMode{kEagle};
+ }
+
[[nodiscard]] bool constexpr isNone() const
{
return anyBitSet(kNone);
@@ -75,29 +80,34 @@ class SpeculativeDecodingMode
return anyBitSet(kExplicitDraftTokens);
}
+ [[nodiscard]] bool constexpr isEagle() const
+ {
+ return anyBitSet(kEagle);
+ }
+
[[nodiscard]] bool constexpr updatesPositionIds() const
{
- return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
+ return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr requiresAttentionMask() const
{
- return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
+ return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr predictsDraftTokens() const
{
- return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
+ return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr needsKVCacheRewind() const
{
- return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
+ return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr variableDraftLength() const
{
- return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding);
+ return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding | kEagle);
}
[[nodiscard]] bool constexpr hasDraftLogits() const
@@ -107,7 +117,7 @@ class SpeculativeDecodingMode
[[nodiscard]] bool constexpr needsDecoderPrologue() const
{
- return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding);
+ return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding | kEagle);
}
using UnderlyingType = std::uint8_t;
@@ -129,6 +139,7 @@ class SpeculativeDecodingMode
static UnderlyingType constexpr kMedusa{1U << 2U};
static UnderlyingType constexpr kLookaheadDecoding{1U << 3U};
static UnderlyingType constexpr kExplicitDraftTokens{1U << 4U};
+ static UnderlyingType constexpr kEagle{1U << 5U};
[[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const
{
@@ -173,4 +184,11 @@ static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isDraftTokensExter
static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isMedusa());
static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isLookaheadDecoding());
+static_assert(SpeculativeDecodingMode::Eagle().isEagle());
+static_assert(!SpeculativeDecodingMode::Eagle().isNone());
+static_assert(!SpeculativeDecodingMode::Eagle().isDraftTokensExternal());
+static_assert(!SpeculativeDecodingMode::Eagle().isMedusa());
+static_assert(!SpeculativeDecodingMode::Eagle().isExplicitDraftTokens());
+static_assert(!SpeculativeDecodingMode::Eagle().isLookaheadDecoding());
+
} // namespace tensorrt_llm::runtime
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index d04677b80..ddf6e4806 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:1a292517d802f2297c5d12d5d14ab597f47f46ebd31412fac044ceb9ca51a482
-size 5160586
+oid sha256:a55035628e0035141b4ea79b946f49ad77893d6e5d1ab47c402e1a9b95fbbb6c
+size 5160128
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index 462c03949..850e53457 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:8575fb58200701ae30feb4b8bd3f325f8018aac5505167fdba42e269adb3bd8c
-size 5271836
+oid sha256:ed219fad83caf000a40f0688fdb20cb8593a5fe8096316d645229ee160c42514
+size 5271480
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
index aff5e53bd..1d38b0ca3 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-954182e0c057f71f858a84f746201044 libtensorrt_llm_batch_manager_static.a
-dfe6ca360cf1d24a3dcae0a2bf8589c0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+d7508bec7b6f112a2eac04cbeaf8b5da libtensorrt_llm_batch_manager_static.a
+d8969624b327af844d9ffba910084b93 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index 4e5be000e..d11b40f7c 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:8fe84073b7ccff8dc361fdee64c3ef30bc523909e0bf9c16547f76a05a53fb5c
-size 5009886
+oid sha256:36479d1577d131e36ca03549467a6cfe4822868ca0f3dda3b5d254ee4680341f
+size 5009646
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index 46d8c1b5c..a1485d52a 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:6e565c2c3ce58656742772591d992aca91c7e46eb9fc711599d2d51928b88b48
-size 4970532
+oid sha256:b5caef410133f1552418978aa20cc1d3f7b6500b1dbc8b9f44232554b7cc8390
+size 4971234
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt
index 2c9c2852f..89f9c2b21 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-61fd34e765788884d42f4ba27f085520 libtensorrt_llm_batch_manager_static.a
-e8a64dd19a234304483ef6756e67fd40 libtensorrt_llm_batch_manager_static.pre_cxx11.a
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+7029ee9cb0a921a3603e98815da18985 libtensorrt_llm_batch_manager_static.a
+0e7fe69b6621fe6dabcc0b372c3440f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
index d1664c2e8..42a6fe97d 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:200a6721aa1d6e009c94866adab36ac686eb1beef02df267af7e18e31e11612b
-size 32436708
+oid sha256:b86e215e86c7b0f8b0c9618fb655e6e4f31cc731f778cf0ca12fde93c7afbcab
+size 32389592
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt
index 45482c43b..0679a9114 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt
@@ -1,2 +1,2 @@
-9485cfa635b17378f23d1624b3acfbaf tensorrt_llm_batch_manager_static.lib
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+afac175cfda36b14d76e17517bad8b24 tensorrt_llm_batch_manager_static.lib
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
index a4f80dc6f..f81961dee 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
@@ -92,7 +92,7 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
- /// Converter for B matrix applited immediately after the LDS
+ /// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
index 26b60736a..2fd74350e 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:809a1da76123ec4c640d63efc902209585223b66e23d887db9a198c5836986a2
-size 3349066
+oid sha256:414606be5b56f592fc7bd25f1e9fbf958c900dd2b01e01907029dfe19408ce59
+size 3349230
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
index 007fa3207..095132fac 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:6846ecefa017d03ab7d853908794c884ab4e92a500e223278b1d64eab59ed061
-size 3376088
+oid sha256:682cf952def054fce6116983a3b5686994b71744fcc85a65e3c9a6e44549c82d
+size 3377832
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
index 4a30230b9..e73a6e86b 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-5a771664fdb75d99ba5fb90249ac26f0 libtensorrt_llm_executor_static.a
-3b433ea93b7d1d6fa471b457980f2680 libtensorrt_llm_executor_static.pre_cxx11.a
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+dc9b4081af6357227886180a1b9a6d8d libtensorrt_llm_executor_static.a
+8291552cf3e8da9dc368c8c37cd35abe libtensorrt_llm_executor_static.pre_cxx11.a
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
index 7584b1fe6..5d0776dfa 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:479e86f410763445357f5d879cc666d210352dda9709ab5ab56e73591a9e8af8
-size 7851266
+oid sha256:88810c1dac205a1111fc833c0fe0d38486152b4b878fd972585eec2ac27d5160
+size 7857242
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
index 0f764244d..425d255fc 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:6473c77d18929fa75342d63ffc591df39e8aeba1dda0b920b0187d4888710559
-size 7767384
+oid sha256:c023d6bad569fb3b3c528f3e003afa6a5f11a045bdccb06ca875607a6c781ade
+size 7769728
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt
index 4baf60ba7..9ff444cfe 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-5424fb0f82076e03b5316f73aed04434 libtensorrt_llm_executor_static.a
-d0b1236baf61fc5c43383bbc1cd50fa8 libtensorrt_llm_executor_static.pre_cxx11.a
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+fd9cb10c300350266f65957475404bff libtensorrt_llm_executor_static.a
+b8b0ae2861ef66853330441752ab1e32 libtensorrt_llm_executor_static.pre_cxx11.a
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
index efd7ecf87..f9e5e12f7 100644
--- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
+++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:dee57c9257a6678833e3c0d83e8df07aff25c185bc085db75938cec6652044c0
-size 24568210
+oid sha256:baf4dd1bacd75c4eae6d98fe411bbb5d478dc5905a298d4238db3db21121ebca
+size 24630026
diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt
index 681dc3284..f46f09905 100644
--- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt
@@ -1,2 +1,2 @@
-305fac5d046a574ded2d46d968f746b0 tensorrt_llm_executor_static.lib
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+30d62c80211e4a2dc38bbe9dc5257839 tensorrt_llm_executor_static.lib
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
index 1a0f6bc65..126e761ec 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
@@ -74,7 +74,6 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
- GemmType gemm;
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
index 92ae4d99b..c19ceafee 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt
index e2ce46ae4..9fa1f5280 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
index 3f82a0827..ccb5cdd40 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:db512d533ab4e4a4abd0047a65d891dfd6e1522f2d34c90f29296c3239fd3cc1
+oid sha256:3bc495e1e677616db2756eb7d56d1161c34ae723896db34487883a955e2b3442
size 1128448
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib
index cfe4399d6..eb4782449 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e207a8f57b944529163c7ed2ab30639a5f2779c5118602c6ebd50a623d16f845
+oid sha256:1a6c03470aaa69378d4989971ab9dd00ee427f7e14a85ba5e114ea0594c4de5e
size 3488
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt
index 465df4be7..1f123d67b 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt
@@ -1,3 +1,3 @@
-b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
-d89a0a140d2d427af13c3794a4b21e2c tensorrt_llm_nvrtc_wrapper.dll
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
+de4b2f87f8eb1027f89c0f5cb05ca047 tensorrt_llm_nvrtc_wrapper.dll
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
index 70cc1d3d6..6b5ab2887 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:0814af36fed752bbe70d953cefbb78dd306c42f3d9f6848b7043a865e48f9662
+oid sha256:80dbb6e3a34380bf4e375901ad9b71df24ec97cddcaa9f226bc0a278d11cbdd6
size 25364090
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
index 84879c280..2910af2e3 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:ee46f2d1c9162f4302a1031f778fcb7c7110c84110427f97af6532ed9bd342fd
+oid sha256:31e5cd6ef9e3599d55501ab0484b81f82ef1f22a79360a2699cd4a62c4928115
size 25768990
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
index 736fddd4a..8c8438147 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-90740ead1def66f350e14c133278463d libtensorrt_llm_internal_cutlass_kernels_static.a
-b0104227ffd1ce19fc1fdb45e349df36 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+1febd9d1bf244163deb269e2bebcd1e3 libtensorrt_llm_internal_cutlass_kernels_static.a
+8fdb39f871225dedd32ca6651f1944ba libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
index 573caf92e..3ac157472 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:4d9ba0f8b95cf64227cb0b17654fb7c9bc1741fe003889658b305750b388a4dc
+oid sha256:3431f91bcb2cadb8a2641c4ea54d1f8f90c5aa7648591510e3a27865c94169ea
size 44173632
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
index daa8557bd..f9ab0f6e1 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:4f848d5beebbd69792047a96b16f7145f8e1e3e311d2a19789ce639ad8149b0e
+oid sha256:1dedd4dd1df76a57576e749b4105a5d5f5070a6f7ee30d11944105742fea9b4b
size 43561206
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
index 0c0c38e19..69baedf76 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-2aaf05cb84f52b024e89d4fa634d6900 libtensorrt_llm_internal_cutlass_kernels_static.a
-f17ce186e9105c594e39d252777ce4c7 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+8683b15e77bf62ee9f57a2507e21e6a7 libtensorrt_llm_internal_cutlass_kernels_static.a
+a065a7b6a11b079ee544664dddcf59a6 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib
index 5aa0009ca..00c671277 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:c429687e335c75f08186bcd8f629b50467cb0f2e484d755834c5b1cdbb9ecaf3
-size 88140796
+oid sha256:c7afdf2c313685b0e31f4e5572e20cd11d94227177849784ce7405e15a3587f6
+size 88140804
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt
index e14aff7e8..889c9577f 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt
@@ -1,2 +1,2 @@
-4f663be2b768088805ccec6dc33545fc tensorrt_llm_internal_cutlass_kernels_static.lib
-4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
\ No newline at end of file
+7eee845e969cfb8d589074d81288b700 tensorrt_llm_internal_cutlass_kernels_static.lib
+3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu
index 04edc841a..8811c484e 100644
--- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu
+++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu
@@ -48,7 +48,7 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
auto const tokenIdx = static_cast(blockIdx.y);
auto const batchId = bid / BLOCKS_PER_BEAM_; // row id for logProbs
- auto const batchSlot = batchSlots[batchId];
+ auto const batchSlot = batchSlots == nullptr ? batchId : batchSlots[batchId];
if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot])
{
return;
@@ -63,7 +63,6 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
auto const logBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize;
auto logProbsSlot
= logProbsPtrs == nullptr ? logProbs + logBufIndex : logProbsPtrs[batchId * maxTokensPerStep + tokenIdx];
-
auto const blockLane = bid % BLOCKS_PER_BEAM_; // block id for a beam
auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; // batchId = batch index
@@ -77,7 +76,7 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
if (finished != nullptr && finishState.isFinished())
{
- if (tid < k)
+ if (tid < k && endIds != nullptr) // if returnAllSelectedToken, endIds would not be an input
{
auto const index = tmpTopKBufIndex + tid;
if (blockLane == 0 && tid == 0)
@@ -134,7 +133,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
float const* topPs, curandState_t* curandState, TokenIdType const* endIds, SizeType32 vocabSize,
bool const* skipDecode, SizeType32 const* batchSlots, SizeType32 maxBatchSize, bool normalizeLogProbs,
bool logitHasProbs, SizeType32 const* tokensPerStep, SizeType32 maxTokensPerStep, SizeType32 maxSeqLen,
- bool returnAllTopK)
+ bool returnAllSelectedTokens)
{
bool const IS_FP16 = std::is_same::value;
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
@@ -142,7 +141,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
auto const tid = static_cast(threadIdx.x);
auto const batchIdx = static_cast(blockIdx.x);
auto const tokenIdx = static_cast(blockIdx.y);
- auto const batchSlot = batchSlots[batchIdx];
+ auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty();
if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding()))
{
@@ -215,13 +214,16 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
if (tid == 0)
{
- auto randNum = static_cast(curand_uniform(curandState + batchSlot) * probThreshold * sSum);
+ // if we want to return all top k indices, we should not do random sampling for probThreshold
+ auto randNum = (returnAllSelectedTokens || curandState == nullptr)
+ ? static_cast(probThreshold * sSum)
+ : static_cast(curand_uniform(curandState + batchSlot) * probThreshold * sSum);
auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot];
for (SizeType32 ki = 0; ki < k; ki++)
{
auto expLogit = sVal2[ki];
randNum = randNum - expLogit;
- if (randNum <= 0.0f || ki == k - 1 || returnAllTopK)
+ if (randNum <= 0.0f || ki == k - 1 || returnAllSelectedTokens)
{
auto idx = sId[ki];
// If sId is -1 here we force output token to the last from vocabulary to get vivid indicator of smth
@@ -230,10 +232,10 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize
: vocabSize - 1;
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
- auto const outIdx = returnAllTopK ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
+ auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
outputIdsRequestPtr[outIdx] = outputId;
- // cum log prob is not supported with returnAllTopK
- if (!returnAllTopK)
+ // cum log prob is not supported with returnAllSelectedTokens
+ if (!returnAllSelectedTokens)
{
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
@@ -255,9 +257,17 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
}
break;
}
+ if (returnAllSelectedTokens && randNum <= 0.0f)
+ {
+ if (ki < k - 1)
+ { // not the last k, write a -1 to to log top p tokens boundary for external draft token masking
+ outputIdsRequestPtr[outIdx + 1] = -1;
+ }
+ break;
+ }
}
}
- if (maxTokensPerStep == 1 && !returnAllTopK && sequenceLengths != nullptr && finishedOutput != nullptr
+ if (maxTokensPerStep == 1 && !returnAllSelectedTokens && sequenceLengths != nullptr && finishedOutput != nullptr
&& endIds != nullptr)
{
auto const seqLen = sequenceLengths[batchSlot];
@@ -297,7 +307,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
params.maxTopK, params.topKs, params.maxTopP, params.topPs, params.curandState, params.endIds, \
params.vocabSizePadded, params.skipDecode, params.batchSlots, params.maxBatchSize, \
params.normalizeLogProbs, params.logitsHasProbs, params.tokensPerStep, params.maxTokensPerStep, \
- params.maxSeqLen, params.returnAllTopK); \
+ params.maxSeqLen, params.returnAllSelectedTokens); \
} \
} while (0)
diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h
index 0330cad31..c14e73ab9 100644
--- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h
+++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h
@@ -34,8 +34,8 @@ struct TopKSamplingKernelParams
//! Log probabilities of each token in the vocab. If logitsHasProbs is true,
//! logProbs must contain **just** probabilities instead of log probabilities.
T const* logProbs{nullptr};
- //! input buffer [batchSize][vocabSizePadded] array of pointers to logits.
- //! If nullptr, logProbs is used. Only maxTokensPerStep == 1 is supported.
+ //! input buffer [batchSize][tokensPerStep, vocabSizePadded] array of pointers to logits.
+ //! If nullptr, logProbs is used.
T const* const* logProbsPtrs{nullptr};
//! output buffer [maxBatchSize][maxSeqLen], optional. Contains pointers to rows
@@ -82,7 +82,8 @@ struct TopKSamplingKernelParams
//! Ignored if nullptr.
float* outputLogProbs{nullptr};
- //! input buffer [maxBatchSize]. Initialized curand states
+ //! input buffer [maxBatchSize], optional. Initialized curand states.
+ //! If nullptr, 1 is always used.
curandState_t* curandState{nullptr};
//! input buffer [maxBatchSize]. K for topK sampling per request.
//! Supported K is in range [1; 1024]. Where K=1 is greedy search.
@@ -106,8 +107,8 @@ struct TopKSamplingKernelParams
bool normalizeLogProbs{false};
//! flag to highlight that logProbs contains probabilities
bool logitsHasProbs{false};
- //! flag to return all selectedTopK results
- bool returnAllTopK{false};
+ //! flag to return all selected TopK results
+ bool returnAllSelectedTokens{false};
void checkParams() const
{
@@ -131,13 +132,12 @@ struct TopKSamplingKernelParams
}
TLLM_CHECK(workspace);
- TLLM_CHECK(curandState);
- TLLM_CHECK(maxTokensPerStep != 1 || returnAllTopK || sequenceLengths);
- TLLM_CHECK(maxTokensPerStep != 1 || returnAllTopK || endIds);
+ TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || sequenceLengths);
+ TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || endIds);
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
- TLLM_CHECK(maxTokensPerStep == 1 && !returnAllTopK);
+ TLLM_CHECK(maxTokensPerStep == 1 && !returnAllSelectedTokens);
}
TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);
diff --git a/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu
index 5605dbcea..23a8db7bb 100644
--- a/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu
+++ b/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu
@@ -200,7 +200,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
SizeType32* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize,
curandState_t* curandState, float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize,
- bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllTopP, SizeType32 maxSeqLen)
+ bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllSelectedTokens, SizeType32 maxSeqLen)
{
/**
* Each block processes one request row sorted in descending order by probabilities.
@@ -244,7 +244,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
if (threadIdx.x == 0)
{
// if we want to return all top p indices, we should not do random sampling for probThreshold
- randNumS = returnAllTopP ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
+ randNumS = returnAllSelectedTokens ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
}
// if beginOffsetBuf and offsetBuf of sorting have same value,
@@ -255,7 +255,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
if (tid == 0)
{
auto offset = batchId * vocabSize;
- if (returnAllTopP)
+ if (returnAllSelectedTokens)
{
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
}
@@ -294,7 +294,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
}
}
- if (returnAllTopP)
+ if (returnAllSelectedTokens)
{
__shared__ SizeType32 sharedSelectedTokenId;
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
@@ -403,7 +403,7 @@ void invokeBatchTopPSampling(TopPSamplingKernelParams const& params, cudaStre
params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput,
params.cumLogProbs, params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded,
params.curandState, params.topPs, params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots,
- params.returnAllTopP, params.maxSeqLen);
+ params.returnAllSelectedTokens, params.maxSeqLen);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
diff --git a/cpp/tensorrt_llm/kernels/samplingTopPKernels.h b/cpp/tensorrt_llm/kernels/samplingTopPKernels.h
index 2ab025ba0..639d7d4d6 100644
--- a/cpp/tensorrt_llm/kernels/samplingTopPKernels.h
+++ b/cpp/tensorrt_llm/kernels/samplingTopPKernels.h
@@ -80,7 +80,7 @@ struct TopPSamplingKernelParams
runtime::SizeType32 vocabSizePadded{-1};
runtime::SizeType32 maxSeqLen{-1};
- bool returnAllTopP{false};
+ bool returnAllSelectedTokens{false};
void checkParams() const
{
@@ -91,7 +91,7 @@ struct TopPSamplingKernelParams
TLLM_CHECK(probs);
TLLM_CHECK(outputIds || outputIdsPtrs);
TLLM_CHECK(workspace);
- TLLM_CHECK((sequenceLength != nullptr) || returnAllTopP);
+ TLLM_CHECK((sequenceLength != nullptr) || returnAllSelectedTokens);
TLLM_CHECK(curandState);
TLLM_CHECK(topPs);
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
new file mode 100644
index 000000000..b03b674fa
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaTypeUtils.cuh"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/memoryUtils.h"
+#include "tensorrt_llm/common/reduceKernelUtils.cuh"
+#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
+#ifndef CUDART_VERSION
+#error CUDART_VERSION Undefined!
+#elif (CUDART_VERSION >= 11050)
+#include
+#else
+#include "3rdparty/cub/cub.cuh"
+#endif
+
+using namespace tensorrt_llm::common;
+using namespace tensorrt_llm::runtime;
+
+namespace tensorrt_llm::kernels::speculative_decoding
+{
+namespace
+{
+template
+__global__ void assembleTargetLogitsOffsets(T const** logitsPtrs, SizeType32* decodingTokens, T const* logits,
+ SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
+ SizeType32 vocabSizePadded)
+{
+ typedef cub::BlockScan BlockScan;
+ __shared__ typename BlockScan::TempStorage tempStorage;
+
+ auto const tix = static_cast(threadIdx.x);
+
+ SizeType32 numDecodingTokens{0};
+ if (tix < batchSize)
+ {
+ numDecodingTokens = draftDecodingTokens[tix] + 1;
+ decodingTokens[tix] = numDecodingTokens;
+ }
+
+ SizeType32 logitsOffset{0};
+ BlockScan(tempStorage).ExclusiveSum(numDecodingTokens, logitsOffset);
+
+ if (tix < batchSize)
+ {
+ for (SizeType32 ti = 0; ti < numDecodingTokens; ++ti)
+ {
+ logitsPtrs[tix * maxDecodingTokens + ti] = logits + (logitsOffset + ti) * vocabSizePadded;
+ }
+ }
+}
+} // namespace
+
+template
+void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, SizeType32* decodingTokens, T const* logits,
+ SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
+ SizeType32 vocabSizePadded, cudaStream_t stream)
+{
+ SizeType32 constexpr BLOCK_SIZE = 512;
+ TLLM_CHECK_WITH_INFO(
+ batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
+ assembleTargetLogitsOffsets<<<1, BLOCK_SIZE, 0, stream>>>(
+ logitsPtrs, decodingTokens, logits, draftDecodingTokens, batchSize, maxDecodingTokens, vocabSizePadded);
+
+ sync_check_cuda_error();
+}
+
+template void invokeAssembleTargetLogitsOffsets(float const** logitsPtrs, SizeType32* decodingTokens,
+ float const* logits, SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
+ SizeType32 vocabSizePadded, cudaStream_t stream);
+template void invokeAssembleTargetLogitsOffsets(__half const** logitsPtrs, SizeType32* decodingTokens,
+ __half const* logits, SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
+ SizeType32 vocabSizePadded, cudaStream_t stream);
+
+namespace
+{
+template
+__global__ void selectLastAccTokenAndComputeIndicesCumSum(TokenIdType* lastAcceptedTokenIds,
+ SizeType32* exclusiveSumLastAcceptedIndices, SizeType32 const* draftDecodingTokens,
+ TokenIdType const* acceptedTokenIds, SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds,
+ SizeType32 const* paths, SizeType32 batchSize, SizeType32 maxDecodingTokens, SizeType32 maxPathLen)
+{
+ typedef cub::BlockScan BlockScan;
+ __shared__ typename BlockScan::TempStorage tempStorage;
+
+ auto const tix = static_cast(threadIdx.x);
+ SizeType32 decodingTokens{0};
+ SizeType32 lastTokenId{0};
+ if (tix < batchSize)
+ {
+ auto const acceptedLen = acceptedLengths[tix];
+ lastAcceptedTokenIds[tix] = acceptedTokenIds[tix * maxPathLen + acceptedLen - 1];
+ auto const bestPathId = bestPathIds[tix];
+ auto const pathIdx = flat_index3(tix, bestPathId, acceptedLen - 1, maxDecodingTokens, maxPathLen);
+ lastTokenId = paths[pathIdx];
+ decodingTokens = draftDecodingTokens[tix] + 1;
+ }
+
+ BlockScan(tempStorage).ExclusiveSum(decodingTokens, decodingTokens);
+
+ if (tix < batchSize)
+ {
+ exclusiveSumLastAcceptedIndices[tix] = decodingTokens + lastTokenId;
+ }
+}
+} // namespace
+
+void invokeSelectLastAccTokenAndComputeIndicesCumSum(TokenIdType* lastAcceptedTokenIds,
+ SizeType32* exclusiveSumLastAcceptedIndices, SizeType32 const* draftDecodingTokens,
+ TokenIdType const* acceptedTokenIds, SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds,
+ SizeType32 const* paths, SizeType32 batchSize, SizeType32 maxDecodingTokens, SizeType32 maxPathLen,
+ cudaStream_t stream)
+{
+ SizeType32 constexpr BLOCK_SIZE = 512;
+ TLLM_CHECK_WITH_INFO(
+ batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
+ selectLastAccTokenAndComputeIndicesCumSum<<<1, BLOCK_SIZE, 0, stream>>>(lastAcceptedTokenIds,
+ exclusiveSumLastAcceptedIndices, draftDecodingTokens, acceptedTokenIds, acceptedLengths, bestPathIds, paths,
+ batchSize, maxDecodingTokens, maxPathLen);
+}
+
+} // namespace tensorrt_llm::kernels::speculative_decoding
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h
new file mode 100644
index 000000000..b8e1430eb
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/kernels/decodingCommon.h"
+#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
+#include "tensorrt_llm/runtime/common.h"
+#include
+#include
+#include
+
+namespace tensorrt_llm::kernels::speculative_decoding
+{
+
+//! \brief Sets pointers to logits in logitsPtrs according to the draftDecodingTokens.
+//! \param logitsPtrs [batchSize][vocabSizePadded]
+//! \param decodingTokens [batchSize], on GPU. draftDecodingTokens + 1.
+//! \param logits [numTokens, vocabSizePadded], on GPU. Continuous logits in memory.
+//! \param draftDecodingTokens [batchSize], on GPU. 0 for context requests, and actual draft len for gen requests
+//! \param batchSize batch size. Only batch size <= 512 is supported at the moment
+//! \param maxDecodingTokens maximum number of decoding tokens per step per request
+//! \param vocabSizePadded vocab size of the logits
+//! \param stream cuda stream
+template
+void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, runtime::SizeType32* decodingTokens, T const* logits,
+ runtime::SizeType32 const* draftDecodingTokens, runtime::SizeType32 batchSize,
+ runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
+
+//! \brief Sets last accepted token ids and computes inclusive sum of the indices of the last accepted tokens in
+//! flattened input_ids tensor.
+//! \param lastAcceptedTokenIds [batchSize], on GPU. Token ids of the last accepted tokens.
+//! \param exclusiveSumLastAcceptedIndices [batchSize], on GPU. Exclusive sum of the positions of the last accepted
+//! tokens in the original flattened draft sequence.
+//! \param draftDecodingTokens [batchSize], on GPU. 0 for context
+//! requests, and actual draft len for gen requests.
+//! \param acceptedTokenIds [batchSize, maxPathLen], on GPU. Ids of the
+//! accepted tokens per request.
+//! \param acceptedLengths [batchSize], on GPU. Lengths of the accepted draft sequences
+//! per request.
+//! \param bestPathIds [batchSize], on GPU. Selected path id per request
+//! \param paths [batchSize,
+//! maxDecodingTokens, maxPathLen], on GPU. Indices of the draft sequences
+//! \param batchSize batch size. Only batch size
+//! <= 512 is supported at the moment
+//! \param maxDecodingTokens maximum number of decoding tokens per step per request
+//! \param maxPathLen maximum path len of the draft sequence
+//! \param stream cuda stream
+void invokeSelectLastAccTokenAndComputeIndicesCumSum(runtime::TokenIdType* lastAcceptedTokenIds,
+ runtime::SizeType32* exclusiveSumLastAcceptedIndices, runtime::SizeType32 const* draftDecodingTokens,
+ runtime::TokenIdType const* acceptedTokenIds, runtime::SizeType32 const* acceptedLengths,
+ runtime::SizeType32 const* bestPathIds, runtime::SizeType32 const* paths, runtime::SizeType32 batchSize,
+ runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen, cudaStream_t stream);
+
+} // namespace tensorrt_llm::kernels::speculative_decoding
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.cu
index 19336e2ed..427f1bb6b 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.cu
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.cu
@@ -60,7 +60,7 @@ __global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchS
auto* outputIdsAfterSamplingPtr = outputIdsAfterSampling + batchSlot * vocabSize;
auto const useDraftLogits = batchUseDraftLogits[batchSlot];
- if (finishedState.isSkipDecoding())
+ if (finishedState.isSkipDecoding() || finishedState.isFinished())
{
return;
}
@@ -75,8 +75,8 @@ __global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchS
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast(blockDim.x))
{
- if (tokensToMask == 0 && outputIdsAfterSamplingPtr[vIdx] == -1)
- { // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0
+ if (outputIdsAfterSamplingPtr[vIdx] == -1)
+ { // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0 or number of topP indices < topK
tokensToMask = vIdx;
}
maskBuffer[vIdx] = false;
@@ -124,12 +124,21 @@ __global__ void acceptDraftTokensKernel(T const* draftProbs, T* targetProbs, Siz
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
auto const useDraftLogits = batchUseDraftLogits[batchSlotBeamWidth];
- if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding())
+ if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding()
+ || finishedInput[batchSlot].isFinished())
{
if (tid == 0)
{
batchIsAccepted[batchSlot] = true;
+
+ // either finished or skip decode in previous step, this step don't need decoding
finishedOutput[batchSlot].setSkipDecoding();
+
+ // if previous step is finished, write the state to next step too
+ if (finishedInput[batchSlot].isFinished())
+ {
+ finishedOutput[batchSlot] = finishedInput[batchSlot];
+ }
}
return;
}
@@ -214,7 +223,8 @@ __global__ void forwardAcceptedTokensKernel(SizeType32 batchSize, SizeType32 con
for (SizeType32 bi = index; bi < batchSize; bi += static_cast(gridDim.x * blockDim.x))
{
auto const batchSlot = batchSlots[bi];
- if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding())
+ if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding()
+ && !finishedOutput[batchSlot].isFinished())
{
auto const curSeqLen = sequenceLengths[batchSlot];
auto const draftTokenIdx = step;
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.cu
index 7a6d8540e..4c876bd96 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.cu
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.cu
@@ -46,22 +46,22 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
T const** medusaLogits, T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep,
SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen,
- SizeType32 maxNumHeads, SizeType32 maxDecodingTokens)
+ SizeType32 maxDraftPathLen, SizeType32 maxDecodingTokens)
{
auto const batchIdx = static_cast(blockIdx.x);
- auto const batchSlot = batchSlots[batchIdx];
- auto const inputLength = sequenceLengths[batchSlot];
- auto const endId = endIds[batchSlot];
- auto const numTokensPerStep = curTokensPerStep[batchSlot];
- auto const maxNumDraftTokens = maxNumHeads + 1;
+ auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
+ auto const inputLength = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
+ auto const endId = endIds == nullptr ? -1 : endIds[batchSlot];
+ auto const numTokensPerStep = curTokensPerStep == nullptr ? maxDecodingTokens : curTokensPerStep[batchSlot];
+ auto const maxPathLen = maxDraftPathLen + 1;
int4 partialMax{-1, -1, 0, 0};
// Go over different paths and construct implicit sequences
for (auto pathIdx = static_cast(threadIdx.x); pathIdx < maxDecodingTokens;
pathIdx += static_cast(blockDim.x))
{
- auto acceptedLength = maxNumDraftTokens;
- auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxDecodingTokens, maxNumDraftTokens);
+ auto acceptedLength = maxPathLen;
+ auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxDecodingTokens, maxPathLen);
bool hasEnd = false;
auto const tokenId = paths[pathOffset];
@@ -75,13 +75,14 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto nextIdx = tokenId;
// Go along the path
- for (SizeType32 ti = 1; ti < maxNumDraftTokens; ++ti)
+ for (SizeType32 ti = 1; ti < maxPathLen; ++ti)
{
auto const tokenId = paths[pathOffset + ti];
// Break if path terminates
if (tokenId == -1)
{
- hasEnd = targetToken == endId; // check if last token is EOS when path terminates.
+ hasEnd = endIds == nullptr ? false
+ : targetToken == endId; // check if last token is EOS when path terminates.
acceptedLength = hasEnd ? ti - 1 : ti;
break;
}
@@ -91,7 +92,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto const draftToken = tokenId >= numTokensPerStep ? -1 : draftIds[draftTokenIdx];
// Check if draft tokens are the same as target tokens
bool const accepted = draftToken == targetToken;
- hasEnd = targetToken == endId;
+ hasEnd = endIds == nullptr ? false : targetToken == endId;
if (!accepted || hasEnd)
{
acceptedLength = hasEnd ? ti - 1 : ti;
@@ -126,7 +127,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto const acceptedLength = totalShared.x;
auto const bestPathIdx = totalShared.y;
auto const bestNextIdx = numTokensPerStep == 1 ? 0 : totalShared.w;
- auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxDecodingTokens, maxNumDraftTokens);
+ auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxDecodingTokens, maxPathLen);
for (auto ti = static_cast(threadIdx.x); ti < acceptedLength; ti += static_cast(blockDim.x))
{
auto const tokenId = paths[pathOffset + ti];
@@ -142,15 +143,18 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
{
auto const hasEnd = totalShared.z;
// Set end condition
- if (hasEnd)
+ if (hasEnd && finishedFinal)
{
finishedFinal[batchSlot].setFinishedEOS();
}
// Make correction to the sequence length
- sequenceLengths[batchSlot] += acceptedLength;
+ if (sequenceLengths)
+ {
+ sequenceLengths[batchSlot] += acceptedLength;
+ }
acceptedLengths[batchSlot] = acceptedLength;
// In Medusa decoding step, number of draft tokens is 0 and must be updated for the next steps
- if (numTokensPerStep == 1)
+ if (curTokensPerStep && targetTokensPerStep && numTokensPerStep == 1)
{
curTokensPerStep[batchSlot] = targetTokensPerStep[batchSlot];
}
@@ -158,45 +162,33 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
}
// Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel
- for (auto hi = static_cast(threadIdx.x); hi < maxNumHeads; hi += static_cast(blockDim.x))
+ if (medusaLogits && logitsPtrs)
{
- logitsPtrs[batchIdx * maxNumHeads + hi]
- = medusaLogits[batchSlot * maxNumHeads + hi] + flat_index2(bestNextIdx, 0, vocabSize);
+ for (auto hi = static_cast(threadIdx.x); hi < maxDraftPathLen;
+ hi += static_cast(blockDim.x))
+ {
+ logitsPtrs[batchIdx * maxDraftPathLen + hi]
+ = medusaLogits[batchSlot * maxDraftPathLen + hi] + flat_index2(bestNextIdx, 0, vocabSize);
+ }
}
}
} // namespace
template
-void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds, TokenIdType const* targetIds,
- SizeType32* sequenceLengths, SizeType32* acceptedLengths, FinishedState* finishedFinal,
- SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds, T const** medusaLogits,
- T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds,
- SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads,
- SizeType32 maxDecodingTokens, cudaStream_t stream)
+void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams const& params)
{
constexpr SizeType32 BLOCK_SIZE = 256;
dim3 block(BLOCK_SIZE);
- dim3 grid(batchSize);
- acceptDraftTokensByIdsWithPaths<<>>(outputIds, draftIds, targetIds,
- sequenceLengths, acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs,
- curTokensPerStep, targetTokensPerStep, bestPathIds, batchSize, vocabSize, maxBatchSize, maxSeqLen, maxNumHeads,
- maxDecodingTokens);
+ dim3 grid(params.batchSize);
+ acceptDraftTokensByIdsWithPaths<<>>(params.outputIds, params.draftIds,
+ params.targetIds, params.sequenceLengths, params.acceptedLengths, params.finishedFinal, params.batchSlots,
+ params.paths, params.endIds, params.medusaLogits, params.logitsPtrs, params.curTokensPerStep,
+ params.targetTokensPerStep, params.bestPathIds, params.batchSize, params.vocabSize, params.maxBatchSize,
+ params.maxSeqLen, params.maxDraftPathLen, params.maxDecodingTokens);
}
-template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
- TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
- FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
- float const** medusaLogits, float const** logitsPtrs, SizeType32* curTokensPerStep,
- SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize,
- SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads, SizeType32 maxDecodingTokens,
- cudaStream_t stream);
-template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
- TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
- FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
- half const** medusaLogits, half const** logitsPtrs, SizeType32* curTokensPerStep,
- SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize,
- SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads, SizeType32 maxDecodingTokens,
- cudaStream_t stream);
+template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams const& params);
+template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<__half> const& params);
namespace
{
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h
index 6a1fae1a7..67f43c9fc 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h
@@ -26,46 +26,87 @@
namespace tensorrt_llm::kernels::speculative_decoding
{
+template <typename T>
+struct AcceptDraftTokensByIdsWithPathsParams
+{
+    //! output buffer [maxBatchSize, maxSeqLen], input tokens.
+    runtime::TokenIdType* outputIds{nullptr};
+    //! input buffer [maxBatchSize, maxDecodingTokens], draft tokens
+    runtime::TokenIdType const* draftIds{nullptr};
+    //! input buffer [maxBatchSize, maxDecodingTokens], tokens predicted from the target medusa head
+    runtime::TokenIdType const* targetIds{nullptr};
+    //! input/output buffer [maxBatchSize], optional.
+    //! Length of the data in outputIds without draft tokens.
+    //! If set, incremented according to the accepted length.
+    runtime::SizeType32* sequenceLengths{nullptr};
+    //! output buffer [maxBatchSize], length of the data accepted tokens
+    runtime::SizeType32* acceptedLengths{nullptr};
+    //! input buffer [maxBatchSize], optional. Finished states per request
+    FinishedState* finishedFinal{nullptr};
+    //! input buffer [batchSize], optional. Address map from local index
+    //! to global index [0, batchSize] -> [0, maxBatchSize].
+    //! If nullptr, batchIdx is used.
+    runtime::SizeType32 const* batchSlots{nullptr};
+    //! input buffer [maxBatchSize, maxDecodingTokens, maxDraftPathLen+1],
+    //! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not
+    //! path.
+    runtime::SizeType32 const* paths{nullptr};
+    //! input buffer [maxBatchSize], optional. EOS ids per request.
+    //! No EOS checks if nullptr.
+    runtime::TokenIdType const* endIds{nullptr};
+    //! input buffer [maxDraftPathLen, maxBatchSize, maxDecodingTokens, vocabSize], optional.
+    //! Pointer to the logits from medusa heads.
+    T const** medusaLogits{nullptr};
+    //! output buffer [batchSize, maxDraftPathLen], optional. Contains pointers to the
+    //! respective rows of the medusaLogits for the next after the accepted token
+    T const** logitsPtrs{nullptr};
+    //! current tokens to compute per step will be updated to
+    //! targetTokensPerStep if curTokensPerStep == 1
+    runtime::SizeType32* curTokensPerStep{nullptr};
+    //! target values of tokens to compute per step
+    runtime::SizeType32 const* targetTokensPerStep{nullptr};
+    //! output buffer [maxBatchSize], indices of the selected paths
+    runtime::SizeType32* bestPathIds{nullptr};
+    //! current batch size
+    runtime::SizeType32 batchSize{0};
+    //! maximum batch size
+    runtime::SizeType32 maxBatchSize{0};
+    //! vocab size
+    runtime::SizeType32 vocabSize{0};
+    //! maximum sequence length of output ids
+    runtime::SizeType32 maxSeqLen{0};
+    //! maximum number of medusa heads
+    runtime::SizeType32 maxDraftPathLen{0};
+    //! maximum number of tokens per step configured in the system
+    runtime::SizeType32 maxDecodingTokens{0};
+    //! stream
+    cudaStream_t stream{};
+
+    void checkParams() const
+    {
+        TLLM_CHECK(outputIds);
+        TLLM_CHECK(draftIds);
+        TLLM_CHECK(targetIds);
+        TLLM_CHECK(acceptedLengths);
+        TLLM_CHECK(paths);
+        TLLM_CHECK(bestPathIds);
+        TLLM_CHECK((curTokensPerStep == nullptr) == (targetTokensPerStep == nullptr)); // set both or neither
+        TLLM_CHECK((medusaLogits == nullptr) == (logitsPtrs == nullptr));              // set both or neither
+
+        TLLM_CHECK(batchSize > 0);
+        TLLM_CHECK(batchSize <= maxBatchSize);
+        TLLM_CHECK(vocabSize > 0);
+        TLLM_CHECK(maxSeqLen > 0);
+        TLLM_CHECK(maxDraftPathLen > 0);
+        TLLM_CHECK(maxDecodingTokens > 0);
+    }
+};
+
//! \brief verifies draft medusa tokens given target tokens. Modifies outputIds tensor accordingly filling it with
//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according
//! to the next after the last accepted token.
-//!
-//! \param outputIds output buffer [maxBatchSize, maxSeqLen], input tokens.
-//! \param draftIds input buffer [maxBatchSize, maxDecodingTokens], draft tokens
-//! \param targetIds input buffer [maxBatchSize, maxDecodingTokens], tokens predicted from the target medusa head
-//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens
-//! Incrememnted according to the accepted length
-//! \param acceptedLengths output buffer [maxBatchSize], length of the data accepted tokens
-//! \param finishedFinal input buffer [maxBatchSize], finished states per request
-//! \param batchSlots input buffer [batchSize], address map from local index
-//! to global index [0, batchSize] -> [0, maxBatchSize]
-//! \param paths input buffer [maxBatchSize, maxDecodingTokens, maxNumHeads+1],
-//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path.
-//! \param endIds input buffer [maxBatchSize], EOS ids per request
-//! \param medusaLogits input buffer [maxNumHeads, maxBatchSize, maxDecodingTokens, vocabSize], pointer
-//! to the logits from medusa heads
-//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the
-//! respective rows of the medusaLogits for the next after the accepted token
-//! \param curTokensPerStep current tokens to compute per step will be updated to
-//! targetTokensPerStep if curTokensPerStep == 1
-//! \param targetTokensPerStep target values of tokens to compute per step
-//! \param bestPathIds output buffer [maxBatchSize], indices of the selected paths
-//! \param batchSize current batch size
-//! \param maxBatchSize maximum batch size
-//! \param vocabSize vocab size
-//! \param maxSeqLen maximum sequence length of output ids
-//! \param maxNumHeads maximum number of medusa heads
-//! \param maxDecodingTokens maximum number of tokens per step configured in the system
-//! \param stream stream
template
-void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* draftIds,
- runtime::TokenIdType const* targetIds, runtime::SizeType32* sequenceLengths, runtime::SizeType32* acceptedLengths,
- FinishedState* finishedFinal, runtime::SizeType32 const* batchSlots, runtime::SizeType32 const* paths,
- runtime::TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
- runtime::SizeType32* curTokensPerStep, runtime::SizeType32 const* targetTokensPerStep,
- runtime::SizeType32* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
- runtime::SizeType32 vocabSize, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxNumHeads,
- runtime::SizeType32 maxDecodingTokens, cudaStream_t stream);
+void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams const&);
//! \brief assembles draft tokens to treeDraftIds from sourceDraftIds using indices of treeIds
//!
diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h
index c53510d3e..f6ecd5b72 100644
--- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h
+++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h
@@ -507,15 +507,12 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams(global_token_idx) * params.q_hidden_size + hidden_idx;
- QuantizedEltType* quantized_q_ptr = STORE_QKV
- ? reinterpret_cast(params.QuantizedQKV) + src_q_idx
- : reinterpret_cast(params.Q) + dst_q_idx;
VecType* q_ptr = STORE_QKV ? reinterpret_ptr(params.QKV, src_q_idx)
: reinterpret_ptr(params.Q, dst_q_idx);
// Cast float scale to dst data type.
using TScale = typename mmha::kv_cache_scale_type_t::Type;
- TScale scaleOrigQuant;
+ [[maybe_unused]] TScale scaleOrigQuant;
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
{
mmha::convert_from_float(
@@ -525,6 +522,9 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams(params.QuantizedQKV) + src_q_idx
+ : reinterpret_cast(params.Q) + dst_q_idx;
mmha::store_8bits_vec(quantized_q_ptr, q, 0, scaleOrigQuant);
}
else
@@ -813,15 +813,12 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams(global_token_idx) * params.q_hidden_size + hidden_idx;
- QuantizedEltType* quantized_q_ptr = STORE_QKV
- ? reinterpret_cast(params.QuantizedQKV) + src_q_idx
- : reinterpret_cast(params.Q) + dst_q_idx;
VecT* q_ptr = STORE_QKV ? reinterpret_ptr(params.QKV, src_q_idx)
: reinterpret_ptr(params.Q, dst_q_idx);
// Cast float scale to dst data type.
using TScale = typename mmha::kv_cache_scale_type_t::Type;
- TScale scaleOrigQuant;
+ [[maybe_unused]] TScale scaleOrigQuant;
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
{
mmha::convert_from_float(&scaleOrigQuant, params.kvScaleOrigQuant ? params.kvScaleOrigQuant[0] : 1.0f);
@@ -830,6 +827,9 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams(params.QuantizedQKV) + src_q_idx
+ : reinterpret_cast(params.Q) + dst_q_idx;
mmha::store_8bits_vec(quantized_q_ptr, q, 0, scaleOrigQuant);
}
else
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h
index db0762351..c8228f7d1 100644
--- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h
@@ -32,6 +32,8 @@ namespace weight_only
{
enum class KernelType
{
+ FP16Int8Groupwise,
+ BF16Int8Groupwise,
FP16Int4Groupwise,
BF16Int4Groupwise,
FP16Int8PerChannel,
@@ -49,6 +51,8 @@ struct kernel_type_traits;
static constexpr bool isGroupwise = _isGroupwise; \
static constexpr bool isInt4 = _isInt4; \
};
+KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8Groupwise, true, false);
+KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8Groupwise, true, false);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false);
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu
new file mode 100644
index 000000000..7fa33376f
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorFalse.cu
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+namespace weight_only
+{
+INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
+ KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
+} // namespace weight_only
+} // namespace kernels
+} // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorInterleavedTrue.cu
new file mode 100644
index 000000000..6c718b24a
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherBf16Int8GroupwiseColumnMajorInterleavedTrue.cu
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+namespace weight_only
+{
+INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
+ KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64);
+} // namespace weight_only
+} // namespace kernels
+} // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu
new file mode 100644
index 000000000..118032999
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorFalse.cu
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+namespace weight_only
+{
+INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
+ KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
+} // namespace weight_only
+} // namespace kernels
+} // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorInterleavedTrue.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorInterleavedTrue.cu
new file mode 100644
index 000000000..fa5002ae0
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcherFp16Int8GroupwiseColumnMajorInterleavedTrue.cu
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+namespace weight_only
+{
+INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
+ KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64);
+} // namespace weight_only
+} // namespace kernels
+} // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h
index 7ff08a19e..e047d1235 100644
--- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h
+++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h
@@ -61,6 +61,8 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
{
EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
}
+ EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
+ EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
@@ -70,6 +72,8 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
}
else if (arch >= 90)
{
+ EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
+ EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
@@ -98,6 +102,8 @@ inline bool is_supported(int arch, KernelType kernel_type)
}
else if (arch >= 80 && arch < 90)
{
+ SUPPORT(KernelType::FP16Int8Groupwise);
+ SUPPORT(KernelType::BF16Int8Groupwise);
SUPPORT(KernelType::FP16Int4Groupwise);
SUPPORT(KernelType::BF16Int4Groupwise);
SUPPORT(KernelType::FP16Int8PerChannel);
@@ -107,6 +113,8 @@ inline bool is_supported(int arch, KernelType kernel_type)
}
else if (arch >= 90)
{
+ SUPPORT(KernelType::FP16Int8Groupwise);
+ SUPPORT(KernelType::BF16Int8Groupwise);
SUPPORT(KernelType::FP16Int4Groupwise);
SUPPORT(KernelType::BF16Int4Groupwise);
SUPPORT(KernelType::FP16Int8PerChannel);
diff --git a/cpp/tensorrt_llm/layers/externalDraftTokensLayer.cpp b/cpp/tensorrt_llm/layers/externalDraftTokensLayer.cpp
index 5f29a9a13..097fe116e 100644
--- a/cpp/tensorrt_llm/layers/externalDraftTokensLayer.cpp
+++ b/cpp/tensorrt_llm/layers/externalDraftTokensLayer.cpp
@@ -431,7 +431,7 @@ void ExternalDraftTokensLayer::getAllTopKs(std::shared_ptrprobsComputed;
@@ -475,7 +475,7 @@ void ExternalDraftTokensLayer::getAllTopPs(std::shared_ptr(params, getStream());
diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp
index 32b812967..414572322 100644
--- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp
+++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp
@@ -76,6 +76,8 @@ LookaheadDecodingLayer::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD
ITensor::makeShape({maxTokensPerStep, maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32);
mPathsOffsets
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
+ mPathsOffsetsBatch
+ = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
mNumNewTokens = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mNumNewTokensCumSum = BufferManager::cpu(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32);
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
@@ -220,7 +222,7 @@ void LookaheadDecodingLayer::forwardAsync(std::shared_ptr::forwardSyncCPU(
BufferRange nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
BufferRange sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
BufferLocation pathsOffsetLocation(*mCpuAlgo->mPathsOffsets);
+ BufferLocation pathsOffsetBatchLocation(*mCpuAlgo->mPathsOffsetsBatch);
BufferLocation outputIdsLocation(*mCpuAlgo->mOutputIds);
mBufferManager->setZero(*mCpuAlgo->mPathsOffsets);
@@ -394,20 +397,22 @@ void LookaheadDecodingLayer::forwardSyncCPU(
D(accepted).values().c_str(), D(draft).values().c_str());
}
- numNewTokensCumSumRange[0] = 0;
SizeType32 pi = 0;
- for (SizeType32 bi = 0; bi < numNewTokensRange.size(); bi++)
+ numNewTokensCumSumRange[0] = 0;
+ for (SizeType32 bi = 0; bi < batchSize; bi++)
{
- SizeType32 acceptedDraftLen = numNewTokensRange[bi] <= 1 ? 0 : (numNewTokensRange[bi] - 1);
+ SizeType32 gbi = batchSlotsRange[bi];
+ SizeType32 acceptedDraftLen = numNewTokensRange[gbi] <= 1 ? 0 : (numNewTokensRange[gbi] - 1);
numNewTokensCumSumRange[bi + 1] = numNewTokensCumSumRange[bi] + acceptedDraftLen;
for (SizeType32 tj = 0; tj < acceptedDraftLen; tj++)
{
- pathsOffsetLocation[pi++] = pathsOffsetLocation.at(bi, tj);
+ pathsOffsetBatchLocation[pi++] = pathsOffsetLocation.at(gbi, tj);
}
}
- for (; pi < pathsOffsetLocation.size(); pi++)
+
+ for (; pi < pathsOffsetBatchLocation.size(); pi++)
{
- pathsOffsetLocation[pi++] = 0;
+ pathsOffsetBatchLocation[pi++] = 0;
}
TLLM_CHECK(outputs->numNewTokens);
@@ -415,8 +420,8 @@ void LookaheadDecodingLayer::forwardSyncCPU(
mBufferManager->copy(*mCpuAlgo->mSequenceLengths, *outputs->sequenceLength.value());
mBufferManager->copy(*mCpuAlgo->mNewTokens, *outputs->newTokens);
- mBufferManager->copy(*mCpuAlgo->mPathsOffsets, *outputs->pathsOffsets);
mBufferManager->copy(*mCpuAlgo->mNumNewTokens, *outputs->numNewTokens.value());
+ mBufferManager->copy(*mCpuAlgo->mPathsOffsetsBatch, *outputs->pathsOffsets);
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); //
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens);
diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h
index f2470a411..e20b59b22 100644
--- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h
+++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h
@@ -70,6 +70,7 @@ class LookaheadDecodingLayer : public BaseLayer
TensorPtr mOutputIds;
TensorPtr mPathsOffsets;
+ TensorPtr mPathsOffsetsBatch;
TensorPtr mNumNewTokens;
TensorPtr mNumNewTokensCumSum;
TensorPtr mNewTokens;
diff --git a/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp b/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp
index ac8f78ec1..69978863b 100644
--- a/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp
+++ b/cpp/tensorrt_llm/layers/medusaDecodingLayer.cpp
@@ -329,11 +329,33 @@ void MedusaDecodingLayer::acceptDraftTokens(SpeculativeDecodingOutputs const&
auto medusaInputLogitsPtrsPtr = reinterpret_cast(bufferCast(*mMedusaInputLogitsPtrs));
auto medusaSelectedLogitsPtrsDevicePtr
= const_cast(bufferCastOrNull(mMedusaSelectedLogitsPtrsDevice));
- acceptDraftTokensByIdsWithPaths(outputIds, draftIds, targetTokensDevicePtr, sequenceLengths, numNewTokens,
- finishedStatesPtr, workspace->getDeviceBatchSlotsPtr(), paths, endIds, medusaInputLogitsPtrsPtr,
- medusaSelectedLogitsPtrsDevicePtr, curTokensPerStepDevice, targetTokensPerStepDevice, bestPathIdsDevicePtr,
- batchSize, mDecoderDomain.getVocabSize(), mDecoderDomain.getBatchSize(), maxSeqLen, maxDraftPathLen,
- mDecoderDomain.getMaxDecodingTokens(), getStream());
+
+ AcceptDraftTokensByIdsWithPathsParams params;
+ params.outputIds = outputIds;
+ params.draftIds = draftIds;
+ params.targetIds = targetTokensDevicePtr;
+ params.sequenceLengths = sequenceLengths;
+ params.acceptedLengths = numNewTokens;
+ params.finishedFinal = finishedStatesPtr;
+ params.batchSlots = workspace->getDeviceBatchSlotsPtr();
+ params.paths = paths;
+ params.endIds = endIds;
+ params.medusaLogits = medusaInputLogitsPtrsPtr;
+ params.logitsPtrs = medusaSelectedLogitsPtrsDevicePtr;
+ params.curTokensPerStep = curTokensPerStepDevice;
+ params.targetTokensPerStep = targetTokensPerStepDevice;
+ params.bestPathIds = bestPathIdsDevicePtr;
+ params.batchSize = batchSize;
+ params.maxBatchSize = mDecoderDomain.getBatchSize();
+ params.vocabSize = mDecoderDomain.getVocabSize();
+ params.maxSeqLen = maxSeqLen;
+ params.maxDraftPathLen = maxDraftPathLen;
+ params.maxDecodingTokens = mDecoderDomain.getMaxDecodingTokens();
+ params.stream = getStream();
+
+ params.checkParams();
+
+ acceptDraftTokensByIdsWithPaths(params);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@@ -390,7 +412,7 @@ void MedusaDecodingLayer::sampleNewDraftTokens(SpeculativeDecodingOutputs con
params.maxBatchSize = maxBatchSizeHeadNums;
params.maxTokensPerStep = 1;
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
- params.returnAllTopK = true;
+ params.returnAllSelectedTokens = true;
invokeBatchTopKSampling(params, getStream());
diff --git a/cpp/tensorrt_llm/plugins/CMakeLists.txt b/cpp/tensorrt_llm/plugins/CMakeLists.txt
index 604c656e5..def5e55d9 100755
--- a/cpp/tensorrt_llm/plugins/CMakeLists.txt
+++ b/cpp/tensorrt_llm/plugins/CMakeLists.txt
@@ -54,7 +54,8 @@ set(PLUGIN_LISTS
mambaConv1dPlugin
lruPlugin
cumsumLastDimPlugin
- lowLatencyGemmPlugin)
+ lowLatencyGemmPlugin
+ eaglePlugin)
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
include_directories(${PLUGIN_ITER})
diff --git a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
index 8a9d6784d..2d4f94176 100644
--- a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp
@@ -39,6 +39,9 @@
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
#endif // ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
+#include "tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h"
+#include "tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h"
+#include "tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.h"
#include "tensorrt_llm/plugins/lowLatencyGemmPlugin/lowLatencyGemmPlugin.h"
#include "tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h"
#include "tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h"
@@ -201,6 +204,10 @@ extern "C"
static tensorrt_llm::plugins::lruPluginCreator lruPluginCreator;
static tensorrt_llm::plugins::CumsumLastDimPluginCreator cumsumLastDimPluginCreator;
static tensorrt_llm::plugins::LowLatencyGemmPluginCreator lowLatencyGemmPluginCreator;
+ static tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator eagleDecodeDraftTokensPluginCreator;
+ static tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator
+ eagleSampleAndAcceptDraftTokensPluginCreator;
+ static tensorrt_llm::plugins::EaglePrepareDrafterInputsPluginCreator eaglePrepareDrafterInputsPluginCreator;
static std::array pluginCreators
= { creatorPtr(identityPluginCreator),
@@ -231,6 +238,9 @@ extern "C"
creatorPtr(lruPluginCreator),
creatorPtr(cumsumLastDimPluginCreator),
creatorPtr(lowLatencyGemmPluginCreator),
+ creatorPtr(eagleDecodeDraftTokensPluginCreator),
+ creatorPtr(eagleSampleAndAcceptDraftTokensPluginCreator),
+ creatorPtr(eaglePrepareDrafterInputsPluginCreator),
};
nbCreators = pluginCreators.size();
return pluginCreators.data();
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/CMakeLists.txt b/cpp/tensorrt_llm/plugins/eaglePlugin/CMakeLists.txt
new file mode 100644
index 000000000..b6bd0439c
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/CMakeLists.txt
@@ -0,0 +1,21 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+file(GLOB SRCS *.cpp)
+set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
+set(PLUGIN_SOURCES
+ ${PLUGIN_SOURCES}
+ PARENT_SCOPE)
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.cpp b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.cpp
new file mode 100644
index 000000000..8dbb8f47d
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.cpp
@@ -0,0 +1,228 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "eagleDecodeDraftTokensPlugin.h"
+
+using namespace nvinfer1;
+using tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator;
+using tensorrt_llm::plugins::EagleDecodeDraftTokensPlugin;
+
+static char const* EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION{"1"};
+static char const* EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME{"EagleDecodeDraftTokens"};
+PluginFieldCollection EagleDecodeDraftTokensPluginCreator::mFC{};
+std::vector<nvinfer1::PluginField> EagleDecodeDraftTokensPluginCreator::mPluginAttributes;
+
+EagleDecodeDraftTokensPlugin::EagleDecodeDraftTokensPlugin(nvinfer1::DataType type, int32_t layerIdx)
+ : mDtype(type)
+ , mLayerIdx(layerIdx)
+{
+}
+
+// Parameterized constructor
+EagleDecodeDraftTokensPlugin::EagleDecodeDraftTokensPlugin(void const* data, size_t length)
+{
+    char const *d = reinterpret_cast<char const*>(data), *a = d;
+ read(d, mDtype);
+ read(d, mLayerIdx);
+ TLLM_CHECK_WITH_INFO(d == a + length,
+ "Expected length (%d) != real length (%d). This is often "
+ "caused by using different TensorRT-LLM version to build "
+ "engine and run engine.",
+        static_cast<int>(length), static_cast<int>(d - a));
+}
+
+// IPluginV2DynamicExt Methods
+nvinfer1::IPluginV2DynamicExt* EagleDecodeDraftTokensPlugin::clone() const noexcept
+{
+ auto* plugin = new EagleDecodeDraftTokensPlugin(*this);
+ plugin->setPluginNamespace(mNamespace.c_str());
+ return plugin;
+}
+
+nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
+ int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
+{
+ TLLM_CHECK(outputIndex < 2);
+ TLLM_CHECK(nbInputs == 5);
+ return inputs[outputIndex + 1];
+}
+
+bool EagleDecodeDraftTokensPlugin::supportsFormatCombination(
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
+{
+ if (pos == 0) // logits
+ {
+ return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+ else if (pos == 3) // rand_data_sample
+ {
+ return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+ else // next_draft_tokens, next_draft_lens, paths, tree_indices
+ {
+ return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+}
+
+void EagleDecodeDraftTokensPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
+{
+}
+
+size_t EagleDecodeDraftTokensPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
+{
+ return 0;
+}
+
+int EagleDecodeDraftTokensPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ // TODO fill me
+
+ return 0;
+}
+
+// IPluginV2Ext Methods
+nvinfer1::DataType EagleDecodeDraftTokensPlugin::getOutputDataType(
+ int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
+{
+ TLLM_CHECK(index < 2);
+ return inputTypes[index + 1];
+}
+
+// IPluginV2 Methods
+
+char const* EagleDecodeDraftTokensPlugin::getPluginType() const noexcept
+{
+ return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME;
+}
+
+char const* EagleDecodeDraftTokensPlugin::getPluginVersion() const noexcept
+{
+ return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION;
+}
+
+int EagleDecodeDraftTokensPlugin::getNbOutputs() const noexcept
+{
+ return 2;
+}
+
+int EagleDecodeDraftTokensPlugin::initialize() noexcept
+{
+ return 0;
+}
+
+void EagleDecodeDraftTokensPlugin::terminate() noexcept {}
+
+size_t EagleDecodeDraftTokensPlugin::getSerializationSize() const noexcept
+{
+ return sizeof(mDtype) + sizeof(mLayerIdx);
+}
+
+void EagleDecodeDraftTokensPlugin::serialize(void* buffer) const noexcept
+{
+    char *d = static_cast<char*>(buffer), *a = d;
+    write(d, mDtype);
+    write(d, mLayerIdx);
+ assert(d == a + getSerializationSize());
+}
+
+void EagleDecodeDraftTokensPlugin::destroy() noexcept
+{
+ // This gets called when the network containing plugin is destroyed
+ delete this;
+}
+
+///////////////
+
+EagleDecodeDraftTokensPluginCreator::EagleDecodeDraftTokensPluginCreator()
+{
+ // Fill PluginFieldCollection with PluginField arguments metadata
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
+ mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+char const* EagleDecodeDraftTokensPluginCreator::getPluginName() const noexcept
+{
+ return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME;
+}
+
+char const* EagleDecodeDraftTokensPluginCreator::getPluginVersion() const noexcept
+{
+ return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION;
+}
+
+PluginFieldCollection const* EagleDecodeDraftTokensPluginCreator::getFieldNames() noexcept
+{
+ return &mFC;
+}
+
+IPluginV2* EagleDecodeDraftTokensPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
+{
+ PluginField const* fields = fc->fields;
+ int32_t layerIdx;
+ nvinfer1::DataType type;
+ // Read configurations from each fields
+ for (int i = 0; i < fc->nbFields; ++i)
+ {
+ char const* attrName = fields[i].name;
+ if (!strcmp(attrName, "layer_idx"))
+ {
+ TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
+            layerIdx = *static_cast<int32_t const*>(fields[i].data);
+ }
+ else if (!strcmp(attrName, "type_id"))
+ {
+ TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
+            type = static_cast<nvinfer1::DataType>(*(static_cast<int32_t const*>(fields[i].data)));
+ }
+ }
+
+ try
+ {
+ auto* obj = new EagleDecodeDraftTokensPlugin(type, layerIdx);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* EagleDecodeDraftTokensPluginCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ // This object will be deleted when the network is destroyed, which will
+ // call EagleDecodeDraftTokensPlugin::destroy()
+ try
+ {
+ auto* obj = new EagleDecodeDraftTokensPlugin(serialData, serialLength);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h
new file mode 100644
index 000000000..3278d4c57
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h
@@ -0,0 +1,90 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "tensorrt_llm/plugins/common/plugin.h"
+#include <cassert>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace tensorrt_llm::plugins
+{
+
+class EagleDecodeDraftTokensPlugin : public BasePlugin
+{
+public:
+ EagleDecodeDraftTokensPlugin(nvinfer1::DataType type, int32_t layerIdx);
+
+ EagleDecodeDraftTokensPlugin(void const* data, size_t length);
+
+ ~EagleDecodeDraftTokensPlugin() override = default;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+ bool supportsFormatCombination(
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
+ void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
+ size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
+ int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(
+ int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
+
+ // IPluginV2 Methods
+ char const* getPluginType() const noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ int getNbOutputs() const noexcept override;
+ int initialize() noexcept override;
+ void terminate() noexcept override;
+ size_t getSerializationSize() const noexcept override;
+ void serialize(void* buffer) const noexcept override;
+ void destroy() noexcept override;
+
+private:
+ nvinfer1::DataType mDtype;
+ int32_t mLayerIdx;
+};
+
+class EagleDecodeDraftTokensPluginCreator : public BaseCreator
+{
+public:
+ EagleDecodeDraftTokensPluginCreator();
+
+ char const* getPluginName() const noexcept override;
+
+ char const* getPluginVersion() const noexcept override;
+
+ nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
+
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+
+private:
+ static nvinfer1::PluginFieldCollection mFC;
+    static std::vector<nvinfer1::PluginField> mPluginAttributes;
+};
+
+} // namespace tensorrt_llm::plugins
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.cpp b/cpp/tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.cpp
new file mode 100644
index 000000000..39eef83d5
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.cpp
@@ -0,0 +1,272 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "eaglePrepareDrafterInputsPlugin.h"
+
+using namespace nvinfer1;
+using tensorrt_llm::plugins::EaglePrepareDrafterInputsPluginCreator;
+using tensorrt_llm::plugins::EaglePrepareDrafterInputsPlugin;
+
+static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION{"1"};
+static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME{"EaglePrepareDrafterInputs"};
+PluginFieldCollection EaglePrepareDrafterInputsPluginCreator::mFC{};
+std::vector<nvinfer1::PluginField> EaglePrepareDrafterInputsPluginCreator::mPluginAttributes;
+
+EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(nvinfer1::DataType type, int32_t layerIdx)
+ : mDtype(type)
+ , mLayerIdx(layerIdx)
+{
+}
+
+// Parameterized constructor
+EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(void const* data, size_t length)
+{
+    char const *d = reinterpret_cast<char const*>(data), *a = d;
+ read(d, mDtype);
+ read(d, mLayerIdx);
+ TLLM_CHECK_WITH_INFO(d == a + length,
+ "Expected length (%d) != real length (%d). This is often "
+ "caused by using different TensorRT-LLM version to build "
+ "engine and run engine.",
+        static_cast<int>(length), static_cast<int>(d - a));
+}
+
+// IPluginV2DynamicExt Methods
+nvinfer1::IPluginV2DynamicExt* EaglePrepareDrafterInputsPlugin::clone() const noexcept
+{
+ auto* plugin = new EaglePrepareDrafterInputsPlugin(*this);
+ plugin->setPluginNamespace(mNamespace.c_str());
+ return plugin;
+}
+
+nvinfer1::DimsExprs EaglePrepareDrafterInputsPlugin::getOutputDimensions(
+ int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
+{
+    TLLM_CHECK(outputIndex < 9);
+ TLLM_CHECK(nbInputs == 7);
+ auto const batchSizeExpr = inputs[nbInputs - 2].d[0];
+ auto const maxDraftLenExpr = inputs[nbInputs - 2].d[1];
+
+ nvinfer1::DimsExprs ret;
+ switch (outputIndex)
+ {
+ case 0: // sequence_length
+ case 1: // host_request_types
+ case 2: // host_past_key_value_lengths
+ ret = inputs[outputIndex];
+ break;
+ case 3: // spec_decoding_generation_lengths
+ ret.nbDims = 1;
+ ret.d[0] = batchSizeExpr;
+ break;
+ case 4: // spec_decoding_position_offsets
+ case 5: // input_ids
+ case 6: // position_ids
+ // FIXME input_ids should have real value, not maxDraftLen
+ ret.nbDims = 1;
+ ret.d[0] = maxDraftLenExpr;
+ break;
+ case 7: // spec_decoding_packed_mask
+ // FIXME
+ ret.nbDims = 3;
+ ret.d[0] = batchSizeExpr;
+ ret.d[1] = maxDraftLenExpr;
+ ret.d[2] = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *maxDraftLenExpr, *exprBuilder.constant(32));
+ break;
+ case 8: // hidden_dim
+ ret.nbDims = 2;
+ // FIXME real dim instead of max draft len
+ ret.d[0] = maxDraftLenExpr;
+ ret.d[1] = inputs[4].d[1];
+ break;
+ }
+ return ret;
+}
+
+bool EaglePrepareDrafterInputsPlugin::supportsFormatCombination(
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
+{
+ if (pos == nbInputs - 1 || pos == nbInputs + nbOutputs - 1) // hidden_states
+ {
+ return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+ else if (pos == 3) // kv cache pool pointers
+ {
+ return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
+ }
+ else // all other tensors
+ {
+ return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+}
+
+void EaglePrepareDrafterInputsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
+{
+}
+
+size_t EaglePrepareDrafterInputsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
+{
+ return 0;
+}
+
+int EaglePrepareDrafterInputsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ // TODO fill me
+
+ return 0;
+}
+
+// IPluginV2Ext Methods
+nvinfer1::DataType EaglePrepareDrafterInputsPlugin::getOutputDataType(
+ int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
+{
+ TLLM_CHECK(index < 9);
+ if (index < 8)
+ {
+ return inputTypes[0]; // type of sequence_length
+ }
+ else // hidden_states
+ {
+ return inputTypes[nbInputs - 1]; // type of hidden_states
+ }
+}
+
+// IPluginV2 Methods
+
+char const* EaglePrepareDrafterInputsPlugin::getPluginType() const noexcept
+{
+ return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME;
+}
+
+char const* EaglePrepareDrafterInputsPlugin::getPluginVersion() const noexcept
+{
+ return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION;
+}
+
+int EaglePrepareDrafterInputsPlugin::getNbOutputs() const noexcept
+{
+ return 9;
+}
+
+int EaglePrepareDrafterInputsPlugin::initialize() noexcept
+{
+ return 0;
+}
+
+void EaglePrepareDrafterInputsPlugin::terminate() noexcept {}
+
+size_t EaglePrepareDrafterInputsPlugin::getSerializationSize() const noexcept
+{
+ return sizeof(mDtype) + sizeof(mLayerIdx);
+}
+
+void EaglePrepareDrafterInputsPlugin::serialize(void* buffer) const noexcept
+{
+    char *d = static_cast<char*>(buffer), *a = d;
+    write(d, mDtype);
+    write(d, mLayerIdx);
+ assert(d == a + getSerializationSize());
+}
+
+void EaglePrepareDrafterInputsPlugin::destroy() noexcept
+{
+ // This gets called when the network containing plugin is destroyed
+ delete this;
+}
+
+///////////////
+
+EaglePrepareDrafterInputsPluginCreator::EaglePrepareDrafterInputsPluginCreator()
+{
+ // Fill PluginFieldCollection with PluginField arguments metadata
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
+ mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+char const* EaglePrepareDrafterInputsPluginCreator::getPluginName() const noexcept
+{
+ return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME;
+}
+
+char const* EaglePrepareDrafterInputsPluginCreator::getPluginVersion() const noexcept
+{
+ return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION;
+}
+
+PluginFieldCollection const* EaglePrepareDrafterInputsPluginCreator::getFieldNames() noexcept
+{
+ return &mFC;
+}
+
+IPluginV2* EaglePrepareDrafterInputsPluginCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc) noexcept
+{
+ PluginField const* fields = fc->fields;
+ int32_t layerIdx;
+ nvinfer1::DataType type;
+ // Read configurations from each fields
+ for (int i = 0; i < fc->nbFields; ++i)
+ {
+ char const* attrName = fields[i].name;
+ if (!strcmp(attrName, "layer_idx"))
+ {
+ TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
+            layerIdx = *static_cast<int32_t const*>(fields[i].data);
+ }
+ else if (!strcmp(attrName, "type_id"))
+ {
+ TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
+            type = static_cast<nvinfer1::DataType>(*(static_cast<int32_t const*>(fields[i].data)));
+ }
+ }
+
+ try
+ {
+ auto* obj = new EaglePrepareDrafterInputsPlugin(type, layerIdx);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* EaglePrepareDrafterInputsPluginCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ // This object will be deleted when the network is destroyed, which will
+ // call EaglePrepareDrafterInputsPlugin::destroy()
+ try
+ {
+ auto* obj = new EaglePrepareDrafterInputsPlugin(serialData, serialLength);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h b/cpp/tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h
new file mode 100644
index 000000000..d88238ba8
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h
@@ -0,0 +1,90 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "tensorrt_llm/plugins/common/plugin.h"
+#include <cassert>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace tensorrt_llm::plugins
+{
+
+class EaglePrepareDrafterInputsPlugin : public BasePlugin
+{
+public:
+ EaglePrepareDrafterInputsPlugin(nvinfer1::DataType type, int32_t layerIdx);
+
+ EaglePrepareDrafterInputsPlugin(void const* data, size_t length);
+
+ ~EaglePrepareDrafterInputsPlugin() override = default;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+ bool supportsFormatCombination(
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
+ void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
+ size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
+ int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(
+ int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
+
+ // IPluginV2 Methods
+ char const* getPluginType() const noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ int getNbOutputs() const noexcept override;
+ int initialize() noexcept override;
+ void terminate() noexcept override;
+ size_t getSerializationSize() const noexcept override;
+ void serialize(void* buffer) const noexcept override;
+ void destroy() noexcept override;
+
+private:
+ nvinfer1::DataType mDtype;
+ int32_t mLayerIdx;
+};
+
+class EaglePrepareDrafterInputsPluginCreator : public BaseCreator
+{
+public:
+ EaglePrepareDrafterInputsPluginCreator();
+
+ char const* getPluginName() const noexcept override;
+
+ char const* getPluginVersion() const noexcept override;
+
+ nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
+
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+
+private:
+ static nvinfer1::PluginFieldCollection mFC;
+    static std::vector<nvinfer1::PluginField> mPluginAttributes;
+};
+
+} // namespace tensorrt_llm::plugins
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.cpp b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.cpp
new file mode 100644
index 000000000..42f03d93b
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.cpp
@@ -0,0 +1,515 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "eagleSampleAndAcceptDraftTokensPlugin.h"
+
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/dataType.h"
+#include "tensorrt_llm/common/memoryUtils.h"
+#include "tensorrt_llm/kernels/samplingTopKKernels.h"
+#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
+#include "tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h"
+#include "tensorrt_llm/runtime/common.h"
+
+using namespace nvinfer1;
+using tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator;
+using tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPlugin;
+using namespace tensorrt_llm::kernels;
+using namespace tensorrt_llm::kernels::speculative_decoding;
+using namespace tensorrt_llm::runtime;
+namespace tc = tensorrt_llm::common;
+
+static char const* EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION{"1"};
+static char const* EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME{"EagleSampleAndAcceptDraftTokens"};
+PluginFieldCollection EagleSampleAndAcceptDraftTokensPluginCreator::mFC{};
+std::vector<nvinfer1::PluginField> EagleSampleAndAcceptDraftTokensPluginCreator::mPluginAttributes;
+
+EagleSampleAndAcceptDraftTokensPlugin::EagleSampleAndAcceptDraftTokensPlugin(
+ nvinfer1::DataType type, bool greedySampling)
+ : mDtype(type)
+ , mGreedySampling(greedySampling)
+{
+ TLLM_CHECK_WITH_INFO(mGreedySampling, "Non-greedy sampling is not supported yet.");
+}
+
+// Parameterized constructor
+EagleSampleAndAcceptDraftTokensPlugin::EagleSampleAndAcceptDraftTokensPlugin(void const* data, size_t length)
+{
+    char const *d = reinterpret_cast<char const*>(data), *a = d;
+ read(d, mDtype);
+ read(d, mGreedySampling);
+ TLLM_CHECK_WITH_INFO(d == a + length,
+ "Expected length (%d) != real length (%d). This is often "
+ "caused by using different TensorRT-LLM version to build "
+ "engine and run engine.",
+ (int) length, (int) (d - a));
+}
+
+// IPluginV2DynamicExt Methods
+nvinfer1::IPluginV2DynamicExt* EagleSampleAndAcceptDraftTokensPlugin::clone() const noexcept
+{
+ auto* plugin = new EagleSampleAndAcceptDraftTokensPlugin(*this);
+ plugin->setPluginNamespace(mNamespace.c_str());
+ return plugin;
+}
+
+nvinfer1::DimsExprs EagleSampleAndAcceptDraftTokensPlugin::getOutputDimensions(
+ int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
+{
+ TLLM_CHECK(nbInputs == 6);
+ TLLM_CHECK(outputIndex < 7);
+ auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[0];
+ auto const maxDecodingDraftTokensExpr = inputs[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)].d[1];
+ auto const maxPathLenExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[2];
+
+ nvinfer1::DimsExprs ret;
+ switch (outputIndex)
+ {
+ case 0: // accepted_tokens
+ ret.nbDims = 2;
+ ret.d[0] = batchSizeExpr;
+ ret.d[1] = maxPathLenExpr;
+ break;
+ case 1: // num_accepted_tokens
+ ret.nbDims = 1;
+ ret.d[0] = batchSizeExpr;
+ break;
+ case 2: // accepted_paths
+ ret.nbDims = 1;
+ ret.d[0] = batchSizeExpr;
+ break;
+ case 3: // last_accepted_tokens
+ ret.nbDims = 1;
+ ret.d[0] = batchSizeExpr;
+ break;
+ case 4: // exclusive_sum_last_accepted_indices
+ ret.nbDims = 1;
+ ret.d[0] = batchSizeExpr;
+ break;
+ case 5: // next_draft_tokens
+ ret.nbDims = 2;
+ ret.d[0] = batchSizeExpr;
+ ret.d[1] = maxDecodingDraftTokensExpr;
+ break;
+ case 6: // next_draft_lens
+ ret.nbDims = 1;
+ ret.d[0] = batchSizeExpr;
+ break;
+ }
+ return ret;
+}
+
+bool EagleSampleAndAcceptDraftTokensPlugin::supportsFormatCombination(
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
+{
+ if (pos == getIdx(InputIdxEntry::LOGITS)) // logits
+ {
+ return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+ else if (pos == getIdx(InputIdxEntry::TEMPERATURE)
+ || pos == getIdx(InputIdxEntry::RAND_VALIDATION)) // temperature, rand_validation
+ {
+ return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+ else // everything else
+ {
+ return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
+ }
+}
+
+void EagleSampleAndAcceptDraftTokensPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
+{
+}
+
+template <typename T>
+size_t EagleSampleAndAcceptDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensorDesc const* inputs,
+ int nbInputs, nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
+{
+ size_t workspaceSize{0};
+
+ auto const vocabSizePadded = inputs[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
+ auto const batchSize = inputs[getIdx(InputIdxEntry::PATHS)].dims.d[0];
+ auto const maxDecodingTokens = inputs[getIdx(InputIdxEntry::PATHS)].dims.d[1];
+
+ // Greedy sampling
+ {
+ // Top1 sampling workspace
+ auto const primarySamplingWorkspaceSize
+            = getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
+
+ // Target output ids
+ auto const targetOutputIdsSize = batchSize * maxDecodingTokens * sizeof(TokenIdType);
+
+ // Logits ptrs
+ auto const logitsPtrsSize = batchSize * maxDecodingTokens * sizeof(T*);
+ SizeType32 constexpr NUM_BUFFERS{4};
+ size_t workspaces[NUM_BUFFERS];
+ workspaces[0] = primarySamplingWorkspaceSize;
+ workspaces[1] = targetOutputIdsSize;
+ workspaces[2] = logitsPtrsSize;
+ workspaces[3] = batchSize * sizeof(SizeType32);
+ workspaceSize = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
+ }
+
+ return workspaceSize;
+}
+
+size_t EagleSampleAndAcceptDraftTokensPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
+{
+ auto const logitsType = inputs[getIdx(InputIdxEntry::LOGITS)].type;
+ if (logitsType == nvinfer1::DataType::kFLOAT)
+ {
+        return getWorkspaceSizeType<float>(inputs, nbInputs, outputs, nbOutputs);
+ }
+ else if (logitsType == nvinfer1::DataType::kHALF)
+ {
+ return getWorkspaceSizeType<__half>(inputs, nbInputs, outputs, nbOutputs);
+ }
+ else
+ {
+ TLLM_CHECK_WITH_INFO(false, "Unsupported logits type");
+ }
+ return 0;
+}
+
+template <typename T>
+void EagleSampleAndAcceptDraftTokensPlugin::samplePrimeHeadTokens(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+ auto const maxNumTokens = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[0];
+ auto const vocabSizePadded = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
+ auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
+ auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
+
+    auto logits = static_cast<T const*>(inputs[getIdx(InputIdxEntry::LOGITS)]);
+    auto prevDraftLens = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::DRAFT_LENS)]);
+
+    int8_t* workspaceBytePtr = reinterpret_cast<int8_t*>(workspace);
+ size_t offset{0};
+
+ auto const samplingWorkspaceSize
+ = getTopKWorkspaceSize(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
+
+    void* workspaceSampling
+        = reinterpret_cast<void*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, samplingWorkspaceSize));
+    TokenIdType* outputIds = reinterpret_cast<TokenIdType*>(
+        tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(TokenIdType)));
+    T const** logitsPtrs = reinterpret_cast<T const**>(
+        tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(T*)));
+    SizeType32* decodingTokens
+        = reinterpret_cast<SizeType32*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * sizeof(SizeType32)));
+
+ // Assemble pointers to logits
+ invokeAssembleTargetLogitsOffsets(
+ logitsPtrs, decodingTokens, logits, prevDraftLens, batchSize, maxDecodingTokens, vocabSizePadded, stream);
+
+ sync_check_cuda_error();
+
+    TopKSamplingKernelParams<T> params;
+ params.logProbsPtrs = logitsPtrs;
+ params.outputIds = outputIds;
+ params.workspace = workspaceSampling;
+ params.maxTopK = 1;
+ params.batchSize = batchSize;
+ params.maxBatchSize = batchSize;
+ params.tokensPerStep = decodingTokens;
+ params.maxTokensPerStep = maxDecodingTokens;
+ params.maxSeqLen = maxDecodingTokens;
+ params.vocabSizePadded = vocabSizePadded;
+
+ invokeBatchTopKSampling(params, stream);
+
+ sync_check_cuda_error();
+
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+template <typename T>
+void EagleSampleAndAcceptDraftTokensPlugin::acceptDraftTokens(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+ auto const maxNumTokens = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[0];
+ auto const vocabSizePadded = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
+
+ auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
+ auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
+ auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[2];
+ auto const maxDraftPathLen = maxPathLen - 1;
+
+    int8_t* workspaceBytePtr = reinterpret_cast<int8_t*>(workspace);
+ size_t offset{0};
+
+ auto const samplingWorkspaceSize
+ = getTopKWorkspaceSize(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
+
+    void* workspaceSampling
+        = reinterpret_cast<void*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, samplingWorkspaceSize));
+    TokenIdType* outputIds = reinterpret_cast<TokenIdType*>(
+        tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(TokenIdType)));
+
+    AcceptDraftTokensByIdsWithPathsParams<T> params;
+    params.outputIds = reinterpret_cast<TokenIdType*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_TOKENS)]);
+    params.draftIds = reinterpret_cast<TokenIdType const*>(inputs[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)]);
+    params.targetIds = outputIds;
+    params.acceptedLengths = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_LEN)]);
+    params.paths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
+    params.bestPathIds = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::BEST_ACCEPTED_PATHS)]);
+ params.batchSize = batchSize;
+ params.maxBatchSize = batchSize;
+ params.vocabSize = vocabSizePadded;
+ params.maxSeqLen = maxPathLen;
+ params.maxDraftPathLen = maxDraftPathLen;
+ params.maxDecodingTokens = maxDecodingTokens;
+ params.stream = stream;
+
+ params.checkParams();
+
+ acceptDraftTokensByIdsWithPaths(params);
+
+ sync_check_cuda_error();
+
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+template <typename T>
+void EagleSampleAndAcceptDraftTokensPlugin::doGreedy(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+ // Sample all main head tokens with Top-1.
+    samplePrimeHeadTokens<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
+
+ // Greedy accept tokens based on token ids, write the best path and best token id.
+    acceptDraftTokens<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
+
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+void EagleSampleAndAcceptDraftTokensPlugin::selectLastAccTokenAndComputeIndicesCumSum(
+ nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
+{
+ TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+ auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
+ auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
+ auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[2];
+
+    auto lastAcceptedTokenIds
+        = reinterpret_cast<TokenIdType*>(outputs[getIdx(OutputIdxEntry::LAST_ACCEPTED_TOKEN_IDS)]);
+    auto exclusiveSumLastAcceptedIndices
+        = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::EXCLUSIVE_SUM_LAST_TOKEN_INDICES)]);
+    auto prevDraftLens = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::DRAFT_LENS)]);
+    auto acceptedTokenIds = reinterpret_cast<TokenIdType const*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_TOKENS)]);
+    auto acceptedLengths = reinterpret_cast<SizeType32 const*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_LEN)]);
+    auto bestPathIds = reinterpret_cast<SizeType32 const*>(outputs[getIdx(OutputIdxEntry::BEST_ACCEPTED_PATHS)]);
+    auto paths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
+
+ invokeSelectLastAccTokenAndComputeIndicesCumSum(lastAcceptedTokenIds, exclusiveSumLastAcceptedIndices,
+ prevDraftLens, acceptedTokenIds, acceptedLengths, bestPathIds, paths, batchSize, maxDecodingTokens, maxPathLen,
+ stream);
+
+ sync_check_cuda_error();
+
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+template <typename T>
+void EagleSampleAndAcceptDraftTokensPlugin::enqueueType(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+ // TODO split batch into greedy and non-greedy and execute both paths
+ if (mGreedySampling)
+ {
+        doGreedy<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
+ }
+ else
+ {
+ // TODO fill me
+ TLLM_CHECK_WITH_INFO(false, "Non-greedy sampling is not supported yet");
+ }
+
+ // Find last accepted tokens and do cumulative sum of accepted indices.
+ selectLastAccTokenAndComputeIndicesCumSum(inputDesc, outputDesc, inputs, outputs, workspace, stream);
+
+ TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+int EagleSampleAndAcceptDraftTokensPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
+{
+ auto const logitsType = inputDesc[getIdx(InputIdxEntry::LOGITS)].type;
+ if (logitsType == nvinfer1::DataType::kFLOAT)
+ {
+        enqueueType<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
+ }
+ else if (logitsType == nvinfer1::DataType::kHALF)
+ {
+ enqueueType<__half>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
+ }
+ else
+ {
+ TLLM_CHECK_WITH_INFO(false, "Unsupported logits type");
+ }
+
+ return 0;
+}
+
+// IPluginV2Ext Methods
+nvinfer1::DataType EagleSampleAndAcceptDraftTokensPlugin::getOutputDataType(
+ int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
+{
+ TLLM_CHECK(index < 7);
+ // input 1 is draft tokens now of int32 type. All outputs are int32_t as well.
+ return inputTypes[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)];
+}
+
+// IPluginV2 Methods
+
+char const* EagleSampleAndAcceptDraftTokensPlugin::getPluginType() const noexcept
+{
+ return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME;
+}
+
+char const* EagleSampleAndAcceptDraftTokensPlugin::getPluginVersion() const noexcept
+{
+ return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION;
+}
+
+int EagleSampleAndAcceptDraftTokensPlugin::getNbOutputs() const noexcept
+{
+ return 7;
+}
+
+int EagleSampleAndAcceptDraftTokensPlugin::initialize() noexcept
+{
+ return 0;
+}
+
+void EagleSampleAndAcceptDraftTokensPlugin::terminate() noexcept {}
+
+size_t EagleSampleAndAcceptDraftTokensPlugin::getSerializationSize() const noexcept
+{
+ return sizeof(mDtype) + sizeof(mGreedySampling);
+}
+
+void EagleSampleAndAcceptDraftTokensPlugin::serialize(void* buffer) const noexcept
+{
+    char *d = static_cast<char*>(buffer), *a = d;
+ write(d, mDtype);
+ write(d, mGreedySampling);
+ assert(d == a + getSerializationSize());
+}
+
+void EagleSampleAndAcceptDraftTokensPlugin::destroy() noexcept
+{
+ // This gets called when the network containing plugin is destroyed
+ delete this;
+}
+
+///////////////
+
+EagleSampleAndAcceptDraftTokensPluginCreator::EagleSampleAndAcceptDraftTokensPluginCreator()
+{
+ // Fill PluginFieldCollection with PluginField arguments metadata
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
+ mPluginAttributes.emplace_back(PluginField("greedy_sampling", nullptr, PluginFieldType::kINT32, 1));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+char const* EagleSampleAndAcceptDraftTokensPluginCreator::getPluginName() const noexcept
+{
+ return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME;
+}
+
+char const* EagleSampleAndAcceptDraftTokensPluginCreator::getPluginVersion() const noexcept
+{
+ return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION;
+}
+
+PluginFieldCollection const* EagleSampleAndAcceptDraftTokensPluginCreator::getFieldNames() noexcept
+{
+ return &mFC;
+}
+
+IPluginV2* EagleSampleAndAcceptDraftTokensPluginCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc) noexcept
+{
+ PluginField const* fields = fc->fields;
+ nvinfer1::DataType type;
+ bool greedySampling;
+ // Read configurations from each fields
+ for (int i = 0; i < fc->nbFields; ++i)
+ {
+ char const* attrName = fields[i].name;
+ if (!strcmp(attrName, "type_id"))
+ {
+ TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
+            type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
+ }
+ else if (!strcmp(attrName, "greedy_sampling"))
+ {
+ TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
+            greedySampling = static_cast<bool>(*static_cast<int32_t const*>(fields[i].data));
+ }
+ }
+
+ try
+ {
+ auto* obj = new EagleSampleAndAcceptDraftTokensPlugin(type, greedySampling);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* EagleSampleAndAcceptDraftTokensPluginCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ // This object will be deleted when the network is destroyed, which will
+ // call EagleSampleAndAcceptDraftTokensPlugin::destroy()
+ try
+ {
+ auto* obj = new EagleSampleAndAcceptDraftTokensPlugin(serialData, serialLength);
+ obj->setPluginNamespace(mNamespace.c_str());
+ return obj;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
diff --git a/cpp/tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.h b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.h
new file mode 100644
index 000000000..b2de11e9b
--- /dev/null
+++ b/cpp/tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.h
@@ -0,0 +1,163 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "tensorrt_llm/plugins/common/plugin.h"
+
+#include <cassert>
+#include <cstdint>
+#include <set>
+#include <string>
+#include <vector>
+
+namespace tensorrt_llm::plugins
+{
+
+class EagleSampleAndAcceptDraftTokensPlugin : public BasePlugin
+{
+public:
+ EagleSampleAndAcceptDraftTokensPlugin(nvinfer1::DataType type, bool greedySampling);
+
+ EagleSampleAndAcceptDraftTokensPlugin(void const* data, size_t length);
+
+ ~EagleSampleAndAcceptDraftTokensPlugin() override = default;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+ bool supportsFormatCombination(
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
+ void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
+ size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
+ int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(
+ int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
+
+ // IPluginV2 Methods
+ char const* getPluginType() const noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ int getNbOutputs() const noexcept override;
+ int initialize() noexcept override;
+ void terminate() noexcept override;
+ size_t getSerializationSize() const noexcept override;
+ void serialize(void* buffer) const noexcept override;
+ void destroy() noexcept override;
+
+private:
+ enum class InputIdxEntry : int32_t
+ {
+ //! [num_tokens, vocab_size_padded]
+ LOGITS = 0,
+ //! [batch_size, max_decoding_draft_tokens]
+ DRAFT_TOKEN_IDS,
+ //! [batch_size]
+ DRAFT_LENS,
+ //! [batch_size]
+ TEMPERATURE,
+ //! []?
+ RAND_VALIDATION,
+ //! [batch_size, max_decoding_tokens, max_path_len]
+ PATHS
+ };
+
+ enum class OutputIdxEntry : int32_t
+ {
+ //! [batch_size, max_draft_path_len]
+ ACCEPTED_TOKENS = 0,
+ //! [batch_size]
+ ACCEPTED_LEN,
+ //! [batch_size]
+ BEST_ACCEPTED_PATHS,
+ //! [batch_size]
+ LAST_ACCEPTED_TOKEN_IDS,
+ //! [batch_size]
+ EXCLUSIVE_SUM_LAST_TOKEN_INDICES,
+ //! [batch_size, max_decoding_draft_tokens]
+ NEXT_DRAFT_TOKEN_IDS,
+ //! [batch_size]
+ NEXT_DRAFT_LENS
+ };
+
+ int32_t getIdx(InputIdxEntry idx) const
+ {
+        return static_cast<int32_t>(idx);
+ }
+
+ int32_t getIdx(OutputIdxEntry idx) const
+ {
+        return static_cast<int32_t>(idx);
+ }
+
+private:
+    template <typename T>
+ size_t getWorkspaceSizeType(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept;
+
+    template <typename T>
+ void samplePrimeHeadTokens(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept;
+
+    template <typename T>
+ void acceptDraftTokens(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
+
+    template <typename T>
+ void doGreedy(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
+
+ void selectLastAccTokenAndComputeIndicesCumSum(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept;
+
+    template <typename T>
+ void enqueueType(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
+
+private:
+ nvinfer1::DataType mDtype;
+ bool mGreedySampling;
+};
+
+class EagleSampleAndAcceptDraftTokensPluginCreator : public BaseCreator
+{
+public:
+ EagleSampleAndAcceptDraftTokensPluginCreator();
+
+ char const* getPluginName() const noexcept override;
+
+ char const* getPluginVersion() const noexcept override;
+
+ nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
+
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+
+private:
+ static nvinfer1::PluginFieldCollection mFC;
+    static std::vector<nvinfer1::PluginField> mPluginAttributes;
+};
+
+} // namespace tensorrt_llm::plugins
diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp
index 40c8113ac..ff75e9f6a 100644
--- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp
@@ -24,14 +24,17 @@ using namespace tensorrt_llm::kernels::cutlass_kernels;
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPluginCreator;
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPlugin;
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantGemmPluginProfiler;
+using tensorrt_llm::plugins::WeightOnlyGemmRunnerPtr;
// Flags for indicating whether the corresponding inputs are applied in mQuantAlgo
-// mQuantAlgo = pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS
-// Here pre_quant_scale, zero and bias are boolean type
+// mQuantAlgo = int8_weight * INT8_WEIGHT + use_w4a8_awq * FP8_ALPHA + pre_quant_scale * PRE_QUANT_SCALE
+// + zero * ZERO + bias * BIAS
+// Here int8_weight, use_w4a8_awq, pre_quant_scale, zero and bias are boolean type
static constexpr int BIAS = int(1) << 0;
static constexpr int ZERO = int(1) << 1;
static constexpr int PRE_QUANT_SCALE = int(1) << 2;
static constexpr int FP8_ALPHA = int(1) << 3;
+static constexpr int INT8_WEIGHT = int(1) << 4;
using tensorrt_llm::plugins::read;
using tensorrt_llm::plugins::write;
@@ -43,11 +46,10 @@ std::vector WeightOnlyGroupwiseQuantMatmulPluginCreator::
void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k,
WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream)
{
- // Quantized weights are packed in FP16 format (INT4*4 -> FP16)
- int const originalN = n * FP16_INT4_RATIO;
+ // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
+ int const originalN = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
half* actPtr = reinterpret_cast(workspace);
- cutlass::uint4b_t* weightPtr = reinterpret_cast(
- nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half)));
+    void* weightPtr = nextWorkspacePtr(reinterpret_cast<int8_t*>(actPtr), m * k * sizeof(half));
half* inputScalesPtr
= reinterpret_cast(nextWorkspacePtr(reinterpret_cast(weightPtr), n * k * sizeof(float)));
half* zerosPtr = reinterpret_cast(
@@ -69,15 +71,22 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k,
}
int const wsSize = mRunner->getWorkspaceSize(m, originalN, k);
-
- mRunner->gemm(actPtr, weightPtr, inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m, originalN, k, mGroupSize,
- tactic, workspacePtr, wsSize, stream);
+ if (mQuantAlgo & INT8_WEIGHT)
+ {
+        mRunner->gemm(actPtr, reinterpret_cast<uint8_t*>(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m,
+ originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream);
+ }
+ else
+ {
+        mRunner->gemm(actPtr, reinterpret_cast<cutlass::uint4b_t*>(weightPtr), inputScalesPtr, zerosPtr, biasesPtr,
+ outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream);
+ }
}
void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k)
{
- // Quantized weights are packed in FP16 format (INT4*4 -> FP16)
- int const originalN = n * FP16_INT4_RATIO;
+ // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
+ int const originalN = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
std::vector workspaces = {
maxM * k * sizeof(half), // A
k * n * sizeof(float), // B
@@ -129,6 +138,38 @@ WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(
(int) length, (int) (d - a));
}
+template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType = ActivationType>
+using GemmRunner = tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, ScaleZeroType, ScaleZeroType>;
+
+template <typename ActivationType, typename WeightType, typename ScaleZeroType = ActivationType>
+WeightOnlyGemmRunnerPtr selectGemmRunnerForZERO(int quant_algo)
+{
+ if (quant_algo & ZERO)
+ {
+        return std::make_shared<GemmRunner<ActivationType, WeightType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, ScaleZeroType>>();
+ }
+ else
+ {
+        return std::make_shared<
+            GemmRunner<ActivationType, WeightType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, ScaleZeroType>>();
+ }
+}
+
+template <typename ActivationType>
+WeightOnlyGemmRunnerPtr selectGemmRunnerForWeightType(int quant_algo)
+{
+ if (quant_algo & INT8_WEIGHT)
+ {
+        return selectGemmRunnerForZERO<ActivationType, uint8_t>(quant_algo);
+ }
+ else
+ {
+        return selectGemmRunnerForZERO<ActivationType, cutlass::uint4b_t>(quant_algo);
+ }
+}
+
void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int quant_algo, int group_size)
{
mArch = tensorrt_llm::common::getSMVersion();
@@ -136,7 +177,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
mQuantAlgo = quant_algo;
mGroupSize = group_size;
- // quant_algo = fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
+ // quant_algo = int8_weight * 16 + fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
mPreQuantScaleInputIdx = (quant_algo & PRE_QUANT_SCALE) ? 1 : 0;
mWeightInputIdx = mPreQuantScaleInputIdx + 1;
mScalesInputIdx = mWeightInputIdx + 1;
@@ -146,6 +187,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
if (mType == nvinfer1::DataType::kHALF)
{
+ // CUTLASS kernel selection
if (quant_algo & FP8_ALPHA)
{
// Ada & Hopper style kernels
@@ -153,45 +195,34 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
{
TLLM_THROW("W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!");
}
- if (quant_algo & ZERO)
- {
- // has zeros
- m_weightOnlyGroupwiseGemmRunner = std::make_shared<
- tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
- cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
- }
- else
- {
- // no zeros
- m_weightOnlyGroupwiseGemmRunner
- = std::make_shared>();
- }
+ assert(!(quant_algo & INT8_WEIGHT) && "W4A(fp)8 kernel requires INT4 weight!");
+ m_weightOnlyGroupwiseGemmRunner
+ = selectGemmRunnerForZERO<__nv_fp8_e4m3, cutlass::uint4b_t, half>(quant_algo);
}
else
{
- if (quant_algo & ZERO)
- {
- // has zeros
- m_weightOnlyGroupwiseGemmRunner
- = std::make_shared>();
- }
- else
- {
- // no zeros
- m_weightOnlyGroupwiseGemmRunner
- = std::make_shared>();
- }
+            m_weightOnlyGroupwiseGemmRunner = selectGemmRunnerForWeightType<half>(quant_algo);
+ }
+ // CUDA kernel selection
+ if (quant_algo & INT8_WEIGHT)
+ {
+ // INT8 weight
+ mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
+ mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int8Groupwise);
+ mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int8Groupwise;
+ }
+ else
+ {
+ // INT4 weight
+ mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
+ mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise);
+ mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise;
}
- mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
- mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise);
- mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise;
}
#if defined(ENABLE_BF16)
else if (mType == nvinfer1::DataType::kBF16)
{
+ // CUTLASS kernel selection
if (quant_algo & FP8_ALPHA)
{
// FP8 requires at least sm89 devices
@@ -203,24 +234,23 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
}
else
{
- if (quant_algo & ZERO)
- {
- // has zeros
- m_weightOnlyGroupwiseGemmRunner
- = std::make_shared>();
- }
- else
- {
- // no zeros
- m_weightOnlyGroupwiseGemmRunner
- = std::make_shared>();
- }
+ m_weightOnlyGroupwiseGemmRunner = selectGemmRunnerForWeightType<__nv_bfloat16>(quant_algo);
+ }
+ // CUDA kernel selection
+ if (quant_algo & INT8_WEIGHT)
+ {
+ // INT8 weight
+ mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
+ mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int8Groupwise);
+ mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int8Groupwise;
+ }
+ else
+ {
+ // INT4 weight
+ mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
+ mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
+ mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
}
- mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
- mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
- mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
}
#endif
else
@@ -273,8 +303,9 @@ nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions(
ret.d[ii] = inputs[0].d[ii];
}
- // int4 weight only quant
- ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[mWeightInputIdx].d[1]->getConstantValue() * FP16_INT4_RATIO);
+ // int4/int8 weight only quant (INT4*4 -> FP16, INT8*2 -> FP16)
+ int const weight_multiplier = mQuantAlgo & INT8_WEIGHT ? FP16_INT8_RATIO : FP16_INT4_RATIO;
+ ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[mWeightInputIdx].d[1]->getConstantValue() * weight_multiplier);
return ret;
}
@@ -320,11 +351,12 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(nvinfer1::DynamicPlug
int const maxK = in[0].max.d[in[0].max.nbDims - 1];
- // Quantized weights are packed in FP16 format (INT4*4 -> FP16)
- int const maxN = in[mWeightInputIdx].max.d[1] * FP16_INT4_RATIO;
+ // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
+ int const weight_multiplier = mQuantAlgo & INT8_WEIGHT ? FP16_INT8_RATIO : FP16_INT4_RATIO;
+ int const maxN = in[mWeightInputIdx].max.d[1] * weight_multiplier;
auto const K = maxK;
- auto const N = maxN / FP16_INT4_RATIO;
+ auto const N = maxN / weight_multiplier;
if (!mDims.isInitialized())
{
@@ -424,8 +456,9 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc con
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGropwiseQuantMatmul configuration");
#endif
- // Quantized weights are packed in FP16 format (INT4*4 -> FP16)
- int real_n = n * FP16_INT4_RATIO;
+ // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
+ int real_n = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
+
if (use_cuda_kernel)
{
void const* pre_quant_scale_ptr = nullptr;
diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h
index 7c65e6623..ed85d2098 100644
--- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h
+++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h
@@ -46,6 +46,7 @@ constexpr int32_t INT8_BITS = 8;
constexpr int32_t INT4_BITS = 4;
constexpr int32_t INT8_INT4_RATIO = INT8_BITS / INT4_BITS;
constexpr int32_t FP16_INT4_RATIO = FP16_BITS / INT4_BITS;
+constexpr int32_t FP16_INT8_RATIO = FP16_BITS / INT8_BITS;
inline int32_t getWeightTypeMultiplier(WeightTypeId weightTypeId)
{
diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp
index 0d8f5a2ff..8ff76615d 100644
--- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp
@@ -140,6 +140,7 @@ void InitBindings(pybind11::module_& m)
.def_readwrite("iter", &tle::IterationStats::iter)
.def_readwrite("iter_latency_ms", &tle::IterationStats::iterLatencyMS)
.def_readwrite("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS)
+ .def_readwrite("num_new_active_requests", &tle::IterationStats::numNewActiveRequests)
.def_readwrite("num_active_requests", &tle::IterationStats::numActiveRequests)
.def_readwrite("num_queued_requests", &tle::IterationStats::numQueuedRequests)
.def_readwrite("num_completed_requests", &tle::IterationStats::numCompletedRequests)
@@ -180,6 +181,9 @@ void InitBindings(pybind11::module_& m)
.def_readwrite("scheduled", &tle::RequestStats::scheduled)
.def_readwrite("paused", &tle::RequestStats::paused)
.def_readwrite("dis_serving_stats", &tle::RequestStats::disServingStats)
+ .def_readwrite("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest)
+ .def_readwrite("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest)
+ .def_readwrite("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest)
.def("to_json_str",
[](tle::RequestStats const& iterationStats) { return tle::JsonSerialization::toJsonStr(iterationStats); });
diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp
index da58300fa..9ad9e0bb9 100644
--- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp
+++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp
@@ -266,6 +266,7 @@ void parsePluginConfig(ModelConfig& modelConfig, Json const& pluginConfig)
auto const manageWeightsType = parseJsonFieldOr(pluginConfig, "manage_weights", false)
? ModelConfig::ManageWeightsType::kEnabled
: ModelConfig::ManageWeightsType::kDisabled;
+ auto const ppReduceScatter = parseJsonFieldOr(pluginConfig, "pp_reduce_scatter", false);
TLLM_CHECK_WITH_INFO(
!removeInputPadding || modelConfig.getMaxNumTokens(), "Padding removal requires max_num_tokens to be set.");
@@ -283,6 +284,7 @@ void parsePluginConfig(ModelConfig& modelConfig, Json const& pluginConfig)
modelConfig.setPagedContextFMHA(pagedContextFMHA);
modelConfig.useXQA(useXQA);
modelConfig.setManageWeightsType(manageWeightsType);
+ modelConfig.setPpReduceScatter(ppReduceScatter);
}
void parseLora(ModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone,
diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp
index c5bc84cf1..73df2cb3f 100644
--- a/cpp/tensorrt_llm/runtime/gptSession.cpp
+++ b/cpp/tensorrt_llm/runtime/gptSession.cpp
@@ -72,7 +72,6 @@ auto const kProfileMbIdxs = populateMicrobatchIndexes();
GptSession::Config setPath(GptSession::Config const& original, std::string const& path)
{
GptSession::Config config = original;
- config.enginePath = std::filesystem::path(path);
return config;
}
diff --git a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp
index a20046079..3cb9b05b6 100644
--- a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp
+++ b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp
@@ -408,9 +408,7 @@ void TllmRuntime::loadManagedWeights(RawEngine const& rawEngine, int localRank)
{
TLLM_LOG_DEBUG("Loading managed weight: %s", name.c_str());
auto iTensor = tensorrt_llm::executor::detail::toITensor(weight);
- auto weightsDevice = std::shared_ptr{
- manager.allocate(MemoryType::kGPU, iTensor->getShape(), iTensor->getDataType())};
- manager.copy(iTensor->data(), *weightsDevice, MemoryType::kCPU);
+ auto weightsDevice = std::shared_ptr{manager.copyFrom(*iTensor, MemoryType::kGPU)};
mManagedWeightsMap.insert(std::make_pair(name, weightsDevice));
}
}
diff --git a/cpp/tests/kernels/decodingKernelTest.cpp b/cpp/tests/kernels/decodingKernelTest.cpp
index 9b9a868b4..f820760e7 100644
--- a/cpp/tests/kernels/decodingKernelTest.cpp
+++ b/cpp/tests/kernels/decodingKernelTest.cpp
@@ -1326,16 +1326,34 @@ class DecodingKernelsTest : public testing::Test
void callAcceptByIdsWithPaths()
{
- tksp::acceptDraftTokensByIdsWithPaths(bufferCast(*mOutputTokens),
- bufferCast(*mDraftTokens), bufferCast(*mTargetTokens),
- bufferCast(*mSequenceLengths), bufferCast(*mAcceptedLengths),
- reinterpret_cast(bufferCast(*mFinishedFinal)),
- bufferCast(*mBatchSlots), bufferCast(*mPaths), bufferCast(*mEndIds),
- reinterpret_cast(bufferCast(*mMedusaInputLogitsPtrs)),
- reinterpret_cast(bufferCast(*mMedusaLogitsPtrs)),
- bufferCast(*mTokensPerStep), bufferCast(*mTokensPerStep),
- bufferCast(*mBestPaths), mBatchSize, mMaxBatchSize, mVocabSize, mMaxSeqLen, mMaxNumHeads,
- mMaxDraftSeqPerStep, mStream->get());
+ tksp::AcceptDraftTokensByIdsWithPathsParams params;
+
+ params.outputIds = bufferCast(*mOutputTokens);
+ params.draftIds = bufferCast(*mDraftTokens);
+ params.targetIds = bufferCast(*mTargetTokens);
+ params.sequenceLengths = bufferCast(*mSequenceLengths);
+ params.acceptedLengths = bufferCast(*mAcceptedLengths);
+ params.finishedFinal
+ = reinterpret_cast(bufferCast(*mFinishedFinal));
+ params.batchSlots = bufferCast(*mBatchSlots);
+ params.paths = bufferCast(*mPaths);
+ params.endIds = bufferCast(*mEndIds);
+ params.medusaLogits = reinterpret_cast(bufferCast(*mMedusaInputLogitsPtrs));
+ params.logitsPtrs = reinterpret_cast(bufferCast(*mMedusaLogitsPtrs));
+ params.curTokensPerStep = bufferCast(*mTokensPerStep);
+ params.targetTokensPerStep = bufferCast(*mTokensPerStep);
+ params.bestPathIds = bufferCast(*mBestPaths);
+ params.batchSize = mBatchSize;
+ params.maxBatchSize = mMaxBatchSize;
+ params.vocabSize = mVocabSize;
+ params.maxSeqLen = mMaxSeqLen;
+ params.maxDraftPathLen = mMaxNumHeads;
+ params.maxDecodingTokens = mMaxDraftSeqPerStep;
+ params.stream = mStream->get();
+
+ params.checkParams();
+
+ tksp::acceptDraftTokensByIdsWithPaths(params);
}
void callTestedKernel()
diff --git a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp
index 402eea153..e3c479ba4 100644
--- a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp
+++ b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp
@@ -91,54 +91,59 @@ TYPED_TEST_SUITE(AirTopPSamplingKernelTest, FloatAndHalfTypes);
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessSmallP)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f));
+ this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeP)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f));
+ this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessAncestral)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f));
+ this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabSmallP)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f));
+ this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabLargeP)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f));
+ this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessSmallP)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f).setDeterministicTopP(true));
+ this->runTest(
+ SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f).setDeterministicTopP(true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeP)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f).setDeterministicTopP(true));
+ this->runTest(
+ SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f).setDeterministicTopP(true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessAncestral)
{
- this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f).setDeterministicTopP(true));
+ this->runTest(
+ SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f).setDeterministicTopP(true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabSmallP)
{
this->runTest(
- SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f).setDeterministicTopP(true));
+ SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f).setDeterministicTopP(
+ true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabLargeP)
{
this->runTest(
- SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f).setDeterministicTopP(true));
+ SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f).setDeterministicTopP(
+ true));
};
class AirTopPSamplingKernelUtilsTest : public SamplingKernelTest