From 980dc2824ba473a65228609498162e3b829533d5 Mon Sep 17 00:00:00 2001 From: Kaiyu <26294424+kaiyux@users.noreply.github.com> Date: Wed, 21 Feb 2024 03:12:55 -0800 Subject: [PATCH 1/2] Update TensorRT-LLM --- benchmarks/cpp/README.md | 2 + benchmarks/cpp/gptManagerBenchmark.cpp | 393 +++++- cpp/CMakeLists.txt | 14 + .../batch_manager/kvCacheConfig.h | 8 + .../tensorrt_llm/batch_manager/llmRequest.h | 288 ++++- .../batch_manager/schedulerPolicy.h | 6 + .../batch_manager/trtGptModelOptionalParams.h | 13 +- cpp/include/tensorrt_llm/common/arrayView.h | 96 ++ cpp/include/tensorrt_llm/executor/executor.h | 416 +++++++ cpp/include/tensorrt_llm/executor/tensor.h | 272 +++++ cpp/include/tensorrt_llm/executor/types.h | 175 +++ .../tensorrt_llm/runtime/bufferManager.h | 4 +- cpp/include/tensorrt_llm/runtime/cudaStream.h | 6 + cpp/include/tensorrt_llm/runtime/iBuffer.h | 79 +- .../tensorrt_llm/runtime/samplingConfig.h | 38 + cpp/tensorrt_llm/CMakeLists.txt | 74 ++ .../libtensorrt_llm_batch_manager_static.a | 4 +- ...sorrt_llm_batch_manager_static.pre_cxx11.a | 4 +- .../x86_64-linux-gnu/version.txt | 4 +- .../libtensorrt_llm_executor_static.a | 3 + ...ibtensorrt_llm_executor_static.pre_cxx11.a | 3 + .../executor/x86_64-linux-gnu/version.txt | 2 + cpp/tensorrt_llm/kernels/decodingKernels.cu | 26 +- cpp/tensorrt_llm/kernels/decodingKernels.h | 3 +- .../onlineSoftmaxBeamsearchKernelsTemplate.h | 6 +- .../kernels/samplingAirTopPKernels.cu | 66 +- .../kernels/samplingTopKKernels.cu | 32 +- .../kernels/samplingTopKKernels.h | 5 +- .../kernels/samplingTopPKernels.cu | 41 +- .../kernels/samplingTopPKernels.h | 19 +- .../layers/dynamicDecodeLayer.cpp | 15 +- cpp/tensorrt_llm/layers/dynamicDecodeLayer.h | 11 +- cpp/tensorrt_llm/layers/topKSamplingLayer.cu | 6 +- cpp/tensorrt_llm/layers/topPSamplingLayer.cu | 15 +- .../bertAttentionPlugin.cpp | 41 +- cpp/tensorrt_llm/runtime/bufferManager.cpp | 12 +- cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp | 6 + cpp/tensorrt_llm/runtime/iBuffer.cpp | 7 +- cpp/tensorrt_llm/runtime/iTensor.cpp | 7 +- cpp/tensorrt_llm/runtime/tllmBuffers.h | 1 + cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp | 29 +- cpp/tensorrt_llm/thop/dynamicDecodeOp.h | 14 +- cpp/tests/CMakeLists.txt | 5 + .../kernels/sampling/samplingAirTopPTest.cpp | 8 +- cpp/tests/kernels/sampling/samplingTest.cpp | 2 +- .../kernels/sampling/samplingTopKTest.cpp | 6 +- .../kernels/sampling/samplingTopPTest.cpp | 11 +- cpp/tests/layers/dynamicDecodeLayerTest.cpp | 101 +- cpp/tests/layers/dynamicDecodeLayerTest.h | 17 +- cpp/tests/runtime/samplingConfigTest.cpp | 77 ++ examples/enc_dec/README.md | 4 + examples/enc_dec/build.py | 15 + examples/enc_dec/run.py | 38 +- examples/enc_dec/t5/weight.py | 5 +- examples/gemma/README.md | 734 +++++++++++ examples/gemma/convert_checkpoint.py | 856 +++++++++++++ examples/gemma/requirements.txt | 9 + examples/gemma/utils/__init__.py | 14 + examples/gemma/utils/layers.py | 39 + examples/gemma/utils/modules.py | 206 ++++ examples/gemma/utils/params.py | 73 ++ examples/gemma/utils/positional_embeddings.py | 92 ++ examples/gemma/utils/sampler.py | 190 +++ examples/gemma/utils/transformer.py | 113 ++ examples/llama/convert_checkpoint.py | 5 - examples/mixtral/README.md | 4 +- examples/mmlu.py | 2 +- examples/quantization/README.md | 30 +- examples/quantization/quantize.py | 16 +- examples/run.py | 2 +- examples/skywork/README.md | 20 +- examples/summarize.py | 5 +- examples/utils.py | 14 +- examples/whisper/README.md | 4 +- examples/whisper/build.py | 22 +- examples/whisper/requirements.txt | 1 
+ examples/whisper/run.py | 26 +- examples/whisper/weight.py | 110 +- tensorrt_llm/__init__.py | 6 +- tensorrt_llm/models/__init__.py | 3 + tensorrt_llm/models/enc_dec/model.py | 2 +- tensorrt_llm/models/gemma/__init__.py | 14 + tensorrt_llm/models/gemma/model.py | 456 +++++++ tensorrt_llm/models/gemma/smoothquant.py | 1072 +++++++++++++++++ tensorrt_llm/models/gemma/weight.py | 681 +++++++++++ tensorrt_llm/quantization/layers.py | 2 +- tensorrt_llm/runtime/generation.py | 13 +- tensorrt_llm/runtime/model_runner.py | 5 +- 88 files changed, 6938 insertions(+), 458 deletions(-) create mode 100644 cpp/include/tensorrt_llm/common/arrayView.h create mode 100644 cpp/include/tensorrt_llm/executor/executor.h create mode 100644 cpp/include/tensorrt_llm/executor/tensor.h create mode 100644 cpp/include/tensorrt_llm/executor/types.h create mode 100644 cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a create mode 100644 cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a create mode 100644 cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt create mode 100644 cpp/tests/runtime/samplingConfigTest.cpp create mode 100644 examples/gemma/README.md create mode 100644 examples/gemma/convert_checkpoint.py create mode 100644 examples/gemma/requirements.txt create mode 100644 examples/gemma/utils/__init__.py create mode 100644 examples/gemma/utils/layers.py create mode 100644 examples/gemma/utils/modules.py create mode 100644 examples/gemma/utils/params.py create mode 100644 examples/gemma/utils/positional_embeddings.py create mode 100644 examples/gemma/utils/sampler.py create mode 100644 examples/gemma/utils/transformer.py create mode 100644 tensorrt_llm/models/gemma/__init__.py create mode 100644 tensorrt_llm/models/gemma/model.py create mode 100644 tensorrt_llm/models/gemma/smoothquant.py create mode 100644 tensorrt_llm/models/gemma/weight.py diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index d6d8baa6e..506fdba8d 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -154,3 +154,5 @@ Take GPT-350M as an example for single GPU with static batching --static_emulated_timeout 100 \ --dataset ../../benchmarks/cpp/preprocessed_dataset.json ``` + +`gptManagerBenchmark` can also be used with the high-level C++ API defined by the `executor::Executor` class (see `cpp/include/tensorrt_llm/executor/executor.h`). This can be done by passing the argument `--api executor`. Note that the Executor class is still under development and currently does not support models with tp or pp > 1. 
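The README note above points at the new `executor::Executor` C++ API added in this patch. As a purely illustrative sketch (not part of the patch; the engine path and prompt token ids below are placeholders), a minimal decoder-only round trip through that API, based only on the declarations in `cpp/include/tensorrt_llm/executor/executor.h`, could look like:

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <filesystem>
#include <iostream>

namespace texec = tensorrt_llm::executor;

int main()
{
    // Default config per the header: beam width 1, guaranteed-no-evict scheduling, in-flight batching.
    texec::ExecutorConfig executorConfig;
    texec::Executor executor(
        std::filesystem::path{"/path/to/engine_dir"}, texec::ModelType::kDECODER_ONLY, executorConfig);

    // Enqueue one request: dummy prompt tokens, at most 16 generated tokens.
    auto const requestId = executor.enqueueRequest(texec::Request(texec::VecTokens{1, 2, 3, 4}, 16));

    // Wait for responses until the final one for this request arrives.
    for (bool done = false; !done;)
    {
        for (auto const& response : executor.awaitResponses(requestId))
        {
            if (response.hasError())
            {
                std::cout << "error: " << response.getErrorMsg() << std::endl;
                done = true;
            }
            else if (response.getResult().isFinal)
            {
                std::cout << "generated " << response.getResult().outputTokenIds.at(0).size()
                          << " tokens on beam 0" << std::endl;
                done = true;
            }
        }
    }
    executor.shutdown();
    return 0;
}
```

In the benchmark itself, these same calls (`enqueueRequests`, `awaitResponses`, `shutdown`) are wrapped by the new `ExecutorServer` class introduced further down in this patch.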
diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index 59ae8ee35..27e5d938d 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -22,6 +22,7 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/mpiUtils.h" #include "tensorrt_llm/common/stringUtils.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/runtime/tllmLogger.h" #include "tensorrt_llm/runtime/worldConfig.h" @@ -39,9 +40,24 @@ using namespace tensorrt_llm::batch_manager; using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; +namespace texec = tensorrt_llm::executor; namespace mpi = tensorrt_llm::mpi; namespace trt = nvinfer1; +namespace +{ + +struct BenchmarkParams +{ + std::optional maxTokensInPagedKvCache = std::nullopt; + std::optional freeGpuMemoryFraction = std::nullopt; + bool enableTrtOverlap = false; + bool enableBlockReuse = false; + bool enableChunkedContext = false; + bool streaming = false; +}; +} // namespace + // Class holding all infos regarding a single work item. // This includes the original request, associated response factor // and state. @@ -223,6 +239,12 @@ class Recorder mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start); } + void recordStart(SizeType inputLength, SizeType maxNewTokens, uint64_t requestId, + std::chrono::time_point const& start) + { + mRequestBenchInfos[requestId] = BenchInfo(inputLength, maxNewTokens, start); + } + void recordEnd(uint64_t requestId) { mRequestBenchInfos[requestId].end = std::chrono::steady_clock::now(); @@ -296,13 +318,107 @@ class Recorder std::string mOpCsvFile; }; // class Recorder +class ExecutorServer +{ +public: + ExecutorServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth, + batch_scheduler::SchedulerPolicy schedulerPolicy, BenchmarkParams const& benchmarkParams, + std::shared_ptr recorder, std::chrono::milliseconds waitSleep, + std::optional const staticEmulatedBatchSize, bool logIterationData) + : mRecorder(std::move(recorder)) + , mWaitSleep(waitSleep) + , mStaticEmulatedBatchSize(staticEmulatedBatchSize) + , mActiveCount(0) + { + + texec::SchedulerConfig schedulerConfig(batch_scheduler::batchManagerToExecSchedPolicy(schedulerPolicy)); + texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache, + std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction, false); + texec::ExecutorConfig executorConfig(maxBeamWidth, schedulerConfig, kvCacheConfig, + benchmarkParams.enableChunkedContext, true, benchmarkParams.enableTrtOverlap); + + mExecutor = std::make_shared(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig); + } + + ~ExecutorServer() {} + + void enqueue(std::vector requests, bool warmup = false) + { + try + { + std::vector inputLengths, maxNewTokens; + for (auto const& request : requests) + { + inputLengths.push_back(request.getInputTokenIds().size()); + maxNewTokens.push_back(request.getMaxNewTokens()); + } + auto const start = std::chrono::steady_clock::now(); + auto reqIds = mExecutor->enqueueRequests(std::move(requests)); + for (int req = 0; req < reqIds.size(); ++req) + { + if (!warmup) + { + mRecorder->recordStart(inputLengths.at(req), maxNewTokens.at(req), reqIds.at(req), start); + } + mActiveCount++; + } + } + catch (const std::exception& e) + { + TLLM_THROW("%s", e.what()); + } + return; + } + + void 
waitForResponses(std::optional numRequests, bool warmup = false) + { + SizeType numFinished = 0; + while (mActiveCount || (numRequests && numFinished < numRequests.value())) + { + auto responses = mExecutor->awaitResponses(std::nullopt, mWaitSleep); + for (auto const& response : responses) + { + if (response.hasError()) + { + // This request failed for some reason, get error msg + std::string errStr = "Request id " + std::to_string(response.getRequestId()) + " failed with err " + + response.getErrorMsg(); + TLLM_THROW(errStr); + } + else if (response.getResult().isFinal) + { + auto reqId = response.getRequestId(); + mActiveCount--; + numFinished++; + if (!warmup) + { + mRecorder->recordEnd(reqId); + } + } + } + } + } + + void shutdown() + { + mExecutor->shutdown(); + } + +private: + std::shared_ptr mExecutor; + std::shared_ptr mRecorder; + std::chrono::milliseconds mWaitSleep; + std::optional mStaticEmulatedBatchSize; + std::atomic mActiveCount; +}; // class ExecutorServer + class GptServer { public: GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth, batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams, std::shared_ptr recorder, std::optional terminateReqId, std::chrono::milliseconds waitSleep, - std::optional const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs) + std::optional const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, bool logIterationData) : mRecorder(std::move(recorder)) , mTerminateReqId(terminateReqId) , mWaitSleep(waitSleep) @@ -312,9 +428,9 @@ class GptServer , mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs) , mActiveCount(0) { - ReturnBatchManagerStatsCallback iterationDataCallback = [this, &optionalParams](std::string const& log) + ReturnBatchManagerStatsCallback iterationDataCallback = [this, &logIterationData](std::string const& log) { - if (optionalParams.logIterationData) + if (logIterationData) { TLLM_LOG_INFO(log); } @@ -396,12 +512,8 @@ class GptServer auto rank = comm.getRank(); if (rank == 0) { - auto numNewWorkItems = std::min(static_cast(mWorkItemsQueue.numPendingWorkItems()), + auto const numNewWorkItems = std::min(static_cast(mWorkItemsQueue.numPendingWorkItems()), static_cast(max_num_requests)); - if (world_size > 1) - { - comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0); - } bool readyForNextBatch = numNewWorkItems > 0; if (mStaticEmulatedBatchSize) @@ -446,18 +558,21 @@ class GptServer sendResponse(workItem->requestId(), {}, true, warnStr); } } - if (world_size > 1) + } + if (world_size > 1) + { + auto numNewWorkItems = static_cast(rval.size()); + comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0); + + std::vector packed; + for (auto const& ir : rval) { - std::vector packed; - for (auto const& ir : rval) - { - auto vpacked = ir->serialize(); - packed.push_back(static_cast(vpacked.size())); - packed.insert( - packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); - } - comm.bcast(packed, 0); + auto vpacked = ir->serialize(); + packed.push_back(static_cast(vpacked.size())); + packed.insert( + packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); } + comm.bcast(packed, 0); } } else @@ -581,15 +696,38 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& return request; } +texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWidth, + std::optional const& eosId, std::optional const& padId, bool streaming = false, + bool 
const& returnContextLogits = false, bool const& returnGenerationLogits = false) +{ + auto samplingConfig = texec::SamplingConfig{beamWidth}; + auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false}; + return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId); +} + void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType, std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp, - const std::optional& eosId, const std::optional& padId, - TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy, - std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits, - std::optional const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs) + std::optional const& eosId, std::optional const& padId, BenchmarkParams const& benchmarkParams, + batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits, + bool returnGenerationLogits, std::optional const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, + bool logIterationData) { auto const worldConfig = WorldConfig::mpi(); + TrtGptModelOptionalParams optionalParams; + + if (benchmarkParams.maxTokensInPagedKvCache) + { + optionalParams.kvCacheConfig.maxTokens = benchmarkParams.maxTokensInPagedKvCache; + } + if (benchmarkParams.freeGpuMemoryFraction) + { + optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction; + } + optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse; + optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext; + optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap; + BufferManager bufferManager{std::make_shared()}; // the stream is not used ITensor::SharedPtr beamWidthTensor{ @@ -603,7 +741,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType auto recorder = std::make_shared(opCsvFile); uint64_t terminateReqId = numSamples + 1; auto gptServer = std::make_shared(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, - recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs); + recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs, logIterationData); ITensor::SharedPtr eosIdTensor{ eosId ? 
bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr}; @@ -660,6 +798,109 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType gptServer->waitBatchManager(); } +void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType modelType, + std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp, + std::optional const& eosId, std::optional const& padId, BenchmarkParams const& benchmarkParams, + batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits, + bool returnGenerationLogits, std::optional const staticEmulatedBatchSize, bool logIterationData) +{ + // Check that mpi size is 1 for now + auto const worldConfig = WorldConfig::mpi(); + if (worldConfig.getSize() > 1) + { + TLLM_THROW("benchmarkExecutor does not yet support mpiSize > 1"); + } + + // Load dataset + const auto samples = parseWorkloadJson(datasetPath, maxNumSamples); + const auto numSamples = samples.size(); + + auto recorder = std::make_shared(opCsvFile); + + auto executorServer = std::make_shared(engineDir, modelType, beamWidth, schedulerPolicy, + benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData); + + if (worldConfig.getRank() == 0) + { + // Warm up + { + std::vector requests; + for (auto i = 0; i < warmUp; ++i) + { + requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId, + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits)); + } + executorServer->enqueue(std::move(requests), true); + executorServer->waitForResponses(warmUp, true); + } + + // Benchmark + { + // Create requests + recorder->initialize(); + std::vector requests; + std::vector delays; + for (std::size_t i = 0; i < numSamples; ++i) + { + requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId, + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits)); + delays.push_back(static_cast(samples[i].delay * 1000)); + } + + bool hasDelay = std::any_of(delays.begin(), delays.end(), [](const auto& delay) { return delay > 0; }); + if (hasDelay && staticEmulatedBatchSize) + { + TLLM_THROW("Executor benchmark doesn't support delays with emulated static batch sizes"); + } + + if (!hasDelay) + { + if (!staticEmulatedBatchSize) + { + executorServer->enqueue(std::move(requests)); + executorServer->waitForResponses(numSamples); + } + else + { + SizeType numRequests = requests.size(); + SizeType maxBatchSize = staticEmulatedBatchSize.value(); + for (SizeType req = 0; req < numRequests; req += maxBatchSize) + { + auto batchSize = std::min(maxBatchSize, numRequests - req); + + std::vector requestsBatch(std::make_move_iterator(requests.begin() + req), + std::make_move_iterator(requests.begin() + req + batchSize)); + // Enqueue in batches + executorServer->enqueue(std::move(requestsBatch)); + // Wait for current batch to be done + executorServer->waitForResponses(batchSize); + } + } + } + else + { + // Launch a thread that will wait for responses + std::thread waitThread( + [numSamples, executorServer]() { executorServer->waitForResponses(numSamples); }); + // Enqueue requests one by one + for (std::size_t i = 0; i < numSamples; ++i) + { + executorServer->enqueue({std::move(requests.at(i))}); + std::this_thread::sleep_for(std::chrono::milliseconds(delays.at(i))); + } + waitThread.join(); + } + } + recorder->finalize(); + recorder->calculateMetrics(); + recorder->report(); + 
recorder->writeOpMetricsToCsv(); + // Send terminateReqId to terminate servers on all ranks + // Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases + // gptServer->enqueue(std::make_shared(terminateReqId)); + } +} + } // namespace int main(int argc, char* argv[]) @@ -671,6 +912,8 @@ int main(int argc, char* argv[]) options.add_options()( "m,model", "Model name specified for engines.", cxxopts::value()->default_value("gpt_350m")); options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value()); + options.add_options()( + "api", "API type: gptManager or executor.", cxxopts::value()->default_value("gptManager")); options.add_options()( "type", "Batching type: IFB or V1(non-IFB) batching.", cxxopts::value()->default_value("IFB")); options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.", @@ -689,14 +932,17 @@ int main(int argc, char* argv[]) options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value()); options.add_options()( "kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value()); + options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution", + cxxopts::value()->default_value("false")); + options.add_options()("streaming", "Operate in streaming mode", cxxopts::value()->default_value("false")); options.add_options()( - "enable_trt_overlap", "Overlap TRT context preparation and execution", cxxopts::value()); - options.add_options()("enable_kv_cache_reuse", "Enables the KV cache reuse.", cxxopts::value()); - options.add_options()("enable_chunked_context", "Whether to enable context chunking.", cxxopts::value()); - options.add_options()( - "return_context_logits", "Whether to return context logits.", cxxopts::value()->default_value("0")); + "enable_kv_cache_reuse", "Enables the KV cache reuse.", cxxopts::value()->default_value("false")); + options.add_options()("enable_chunked_context", "Whether to enable context chunking.", + cxxopts::value()->default_value("false")); options.add_options()( - "return_generation_logits", "Whether to return generation logits.", cxxopts::value()->default_value("0")); + "return_context_logits", "Whether to return context logits.", cxxopts::value()->default_value("false")); + options.add_options()("return_generation_logits", "Whether to return generation logits.", + cxxopts::value()->default_value("false")); options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.", cxxopts::value()->default_value("guaranteed_no_evict")); @@ -708,8 +954,8 @@ int main(int argc, char* argv[]) cxxopts::value()->default_value("500")); options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.", cxxopts::value()->default_value("error")); - options.add_options()( - "log_iteration_data", "On each decoder iteration, print batch state metadata.", cxxopts::value()); + options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.", + cxxopts::value()->default_value("false")); options.add_options()("wait_sleep", "Specify how many milliseconds to sleep each iteration of waitForEmpty loop.", cxxopts::value()->default_value("25")); @@ -729,6 +975,9 @@ int main(int argc, char* argv[]) return 1; } + // Argument: API + auto const api = result["api"].as(); + // Argument: Batching Type auto const type = result["type"].as(); TrtGptModelType modelType{TrtGptModelType::V1}; 
@@ -758,50 +1007,38 @@ int main(int argc, char* argv[]) // Argument: wait_sleep auto const waitSleep = std::chrono::milliseconds(result["wait_sleep"].as()); + BenchmarkParams benchmarkParams; - TrtGptModelOptionalParams optionalParams; // Argument: Max tokens in paged K-V Cache if (result.count("max_tokens_in_paged_kvcache")) { - optionalParams.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as(); + benchmarkParams.maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as(); } // Argument: K-V Cache Free Gpu Mem Fraction if (result.count("kv_cache_free_gpu_mem_fraction")) { - optionalParams.kvCacheConfig.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as(); + benchmarkParams.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as(); } // Argument: Enable TRT overlap - if (result.count("enable_trt_overlap")) - { - optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as(); - } + benchmarkParams.enableTrtOverlap = result["enable_trt_overlap"].as(); + // Argument: Enable KV cache reuse - if (result.count("enable_kv_cache_reuse")) - { - optionalParams.kvCacheConfig.enableBlockReuse = result["enable_kv_cache_reuse"].as(); - } + benchmarkParams.enableBlockReuse = result["enable_kv_cache_reuse"].as(); + + // Argument: streaming + benchmarkParams.streaming = result["streaming"].as(); + // Argument: Enable batch stats output - if (result.count("log_iteration_data")) - { - optionalParams.logIterationData = result["log_iteration_data"].as(); - } + bool logIterationData = result["log_iteration_data"].as(); + // Argument: Enable chunked context - if (result.count("enable_chunked_context")) - { - optionalParams.enableChunkedContext = result["enable_chunked_context"].as(); - } + benchmarkParams.enableChunkedContext = result["enable_chunked_context"].as(); + // Argument: Enable return context logits - bool returnContextLogits = false; - if (result.count("return_context_logits")) - { - returnContextLogits = result["return_context_logits"].as(); - } + bool returnContextLogits = result["return_context_logits"].as(); + // Argument: Enable return context logits - bool returnGenerationLogits = false; - if (result.count("return_generation_logits")) - { - returnGenerationLogits = result["return_generation_logits"].as(); - } + bool returnGenerationLogits = result["return_generation_logits"].as(); std::optional padId; // Argument: Padding token id @@ -873,16 +1110,40 @@ int main(int argc, char* argv[]) initTrtLlmPlugins(logger.get()); - try + if (api == "gptManager") { - benchmarkGptManager(result["engine_dir"].as(), modelType, datasetPath, opCsvFile, maxNumSamples, - beamWidth, result["warm_up"].as(), eosId, padId, optionalParams, schedulerPolicy, waitSleep, - returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout); + try + { + benchmarkGptManager(result["engine_dir"].as(), modelType, datasetPath, opCsvFile, + maxNumSamples, beamWidth, result["warm_up"].as(), eosId, padId, benchmarkParams, schedulerPolicy, + waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout, + logIterationData); + } + catch (const std::exception& e) + { + TLLM_LOG_ERROR(e.what()); + return 1; + } } - catch (const std::exception& e) + else if (api == "executor") { - TLLM_LOG_ERROR(e.what()); + try + { + benchmarkExecutor(result["engine_dir"].as(), modelType, datasetPath, opCsvFile, maxNumSamples, + beamWidth, result["warm_up"].as(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep, + 
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData); + } + catch (const std::exception& e) + { + TLLM_LOG_ERROR(e.what()); + return 1; + } + } + else + { + TLLM_LOG_ERROR("api parameter must be gptManager or executor"); return 1; } + return 0; } diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index dc7585465..56e61de3a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -60,6 +60,20 @@ else() message(STATUS "Importing batch manager") endif() +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/tensorrt_llm/executor/CMakeLists.txt") + set(BUILD_EXECUTOR_DEFAULT ON) +else() + set(BUILD_EXECUTOR_DEFAULT OFF) +endif() + +option(BUILD_EXECUTOR "Build executor from source" ${BUILD_EXECUTOR_DEFAULT}) + +if(BUILD_EXECUTOR) + message(STATUS "Building executor") +else() + message(STATUS "Importing executor") +endif() + if(BUILD_PYT) message(STATUS "Building PyTorch") else() diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h index 5496f9d9f..fe2515c70 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h @@ -17,6 +17,7 @@ #pragma once +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include @@ -42,6 +43,13 @@ class KvCacheConfig { } + explicit KvCacheConfig(executor::KvCacheConfig const& kvCacheConfig) + : KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindow(), + kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(), + kvCacheConfig.getEnableBlockReuse(), kvCacheConfig.getUseUvm()) + { + } + std::optional maxTokens; std::optional maxAttentionWindow; std::optional sinkTokenLength; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index c0d69dbff..2b8713a0f 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -17,6 +17,7 @@ #pragma once #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/samplingConfig.h" @@ -58,7 +59,7 @@ class GenericLlmRequest std::optional loraConfig = std::nullopt, bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional> draftTokens = std::nullopt, - std::optional draftLogits = std::nullopt) + std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -68,7 +69,8 @@ class GenericLlmRequest , mEndId(endId) , mPadId(padId) , mSeqSlot(-1) - , mOrigPromptLen(inputTokens->size()) + , mOrigPromptLen(mPromptLen) + , mMaxSentTokenPos(mPromptLen - 1) , mEmbeddingBias(embeddingBias) , mBadWordsList(badWordsList) , mStopWordsList(stopWordsList) @@ -85,27 +87,112 @@ class GenericLlmRequest , mDraftLogits(draftLogits) , mReturnContextLogits(returnContextLogits) , mReturnGenerationLogits(returnGenerationLogits) + , mExcludeInputFromOutput(excludeInputFromOutput) { - mMaxSentTokenPos = mPromptLen - 1; - // Scatter the input tokens to other beam - mTokens = BeamTokens(mSamplingConfig.beamWidth, *inputTokens); + initialize(*inputTokens); + } - if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value()) - || (!mPromptEmbeddingTable.has_value() && 
mPromptVocabSize.has_value())) + GenericLlmRequest(RequestIdType requestId, executor::Request const& req) + : mRequestId(requestId) + , mPromptLen(req.getInputTokenIds().size()) + , mMaxNewTokens(req.getMaxNewTokens()) + , mSamplingConfig(req.getSamplingConfig(), req.getSpeculativeDecodingConfig()) + , mState(REQUEST_STATE_CONTEXT_INIT) + , mIsStreaming(req.getStreaming()) + , mEndId(req.getEndId()) + , mPadId(req.getPadId()) + , mSeqSlot(-1) + , mOrigPromptLen(mPromptLen) + , mMaxSentTokenPos(mPromptLen - 1) + , mReturnLogProbs(req.getOutputConfig().returnLogProbs) + , mContextChunkSize(std::nullopt) + , mContextCurrentPosition(0) + , mLogProbs(mSamplingConfig.beamWidth) + , mCumLogProbs(mSamplingConfig.beamWidth) + , mDraftTokens(std::make_shared()) + , mReturnContextLogits(req.getOutputConfig().returnContextLogits) + , mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits) + , mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput) + { + if (req.getEmbeddingBias()) { - std::string errStr - = "Prompt embedding table and prompt vocab size tensors must both be provided for requests with prompt " - "tuning enabled."; - TLLM_LOG_ERROR(errStr); - throw std::runtime_error(errStr); + mEmbeddingBias = executor::detail::toITensor(*(req.getEmbeddingBias().value())); + // Add leading 1 dimension since that's what IFB code expects + mEmbeddingBias.value()->unsqueeze(0); + } + if (req.getBadWords()) + { + mBadWordsList = createListTensor(req.getBadWords().value()); + } + if (req.getStopWords()) + { + mStopWordsList = createListTensor(req.getStopWords().value()); } - if (draftLogits.has_value() && !draftTokens.has_value()) + auto pTuningConfig = req.getPromptTuningConfig(); + if (pTuningConfig) { - std::string errStr = "Draft tokens must be specified when draft logits are given."; - TLLM_LOG_ERROR(errStr); - throw std::runtime_error(errStr); + mPromptEmbeddingTable = executor::detail::toITensor(*pTuningConfig.value().getEmbeddingTable()); + TLLM_CHECK(mPromptEmbeddingTable.value()->getShape().nbDims == 2); + mPromptVocabSize = mPromptEmbeddingTable.value()->getShape().d[0]; + mPromptEmbeddingTable.value()->unsqueeze(0); } + + auto loraConfig = req.getLoraConfig(); + if (loraConfig) + { + mLoraWeights = executor::detail::toITensor(*loraConfig.value().getWeights()); + mLoraWeights.value()->unsqueeze(0); + + mLoraConfig = executor::detail::toITensor(*loraConfig.value().getConfig()); + mLoraConfig.value()->unsqueeze(0); + } + + auto speculativeDecodingConfig = req.getSpeculativeDecodingConfig(); + if (speculativeDecodingConfig) + { + mDraftTokens = std::make_shared(speculativeDecodingConfig.value().getTokens()); + + if (speculativeDecodingConfig.value().getLogits()) + { + mDraftLogits = executor::detail::toITensor(*speculativeDecodingConfig.value().getLogits().value()); + } + + // NOTE: Draft acceptance threshold is stored in mSamplingConfig + } + + initialize(req.getInputTokenIds()); + } + + void validate(SizeType maxInputLen, SizeType maxSequenceLen) + { + if (mPromptLen > maxInputLen) + { + TLLM_THROW("Prompt length (%d) exceeds maximum input length (%d).", mPromptLen, maxInputLen); + } + + if (mPromptLen + mMaxNewTokens > maxSequenceLen) + { + auto const maxNewTokens = maxSequenceLen - mPromptLen; + TLLM_LOG_WARNING( + "Number of requested output tokens (%d) exceeds maximum sequence length (%d). 
" + "Number of requested output tokens is changed to (%d).", + mMaxNewTokens, maxSequenceLen, maxNewTokens); + mMaxNewTokens = maxNewTokens; + } + + if (mSamplingConfig.beamWidth <= 0) + { + TLLM_THROW( + "Requested value: %d for beamWidth is invalid. To de-activate beam searching " + "set beamWidth to 1 instead.", + mSamplingConfig.beamWidth); + } + } + + void setExcludeInputFromOutput(bool exclude) + { + mExcludeInputFromOutput = exclude; } /// @brief Get total number of tokens for this req (prompt + generated) @@ -236,7 +323,6 @@ class GenericLlmRequest else { SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens()); - TLLM_LOG_DEBUG("pause: id %lu, mPromptLen %d, newPromptLen %d", mRequestId, mPromptLen, newPromptLen); for (std::size_t beam = 0; beam < mTokens.size(); ++beam) { auto& beamTokens = mTokens.at(beam); @@ -288,11 +374,31 @@ class GenericLlmRequest return mLoraWeights; } + void setLoraWeights(TensorPtr weights) + { + mLoraWeights = weights; + } + + void clearLoraWeights() + { + mLoraWeights = std::nullopt; + } + std::optional getLoraConfig() const { return mLoraConfig; } + void setLoraConfig(TensorPtr config) + { + mLoraConfig = config; + } + + void clearLoraConfig() + { + mLoraConfig = std::nullopt; + } + std::optional getEmbeddingBias() const { return mEmbeddingBias; @@ -389,6 +495,12 @@ class GenericLlmRequest mContextLogitsHost = std::move(contextLogitsHost); } + void allocContextLogitsHost(SizeType vocabSizePadded, nvinfer1::DataType logitsDataType) + { + mContextLogitsHost = runtime::BufferManager::pinned( + runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType); + } + TensorPtr const& getGenerationLogitsHost() const { return mGenerationLogitsHost; @@ -399,6 +511,12 @@ class GenericLlmRequest mGenerationLogitsHost = std::move(generationLogitsHost); } + void allocGenerationLogitsHost(SizeType vocabSizePadded, nvinfer1::DataType logitsDataType) + { + mGenerationLogitsHost = runtime::BufferManager::pinned( + runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType); + } + std::vector const& getGenerationLogitsFragments() const { return mGenerationLogitsFragments; @@ -498,6 +616,84 @@ class GenericLlmRequest } } + /// @brief Create a Response from the current state of the request + /// @return An optional Response + std::optional createResponse() + { + if (mState == batch_manager::REQUEST_STATE_GENERATION_COMPLETE + || (mIsStreaming && mState == batch_manager::REQUEST_STATE_GENERATION_IN_PROGRESS)) + { + executor::Result result; + result.isFinal = mState == batch_manager::REQUEST_STATE_GENERATION_COMPLETE ? true : false; + + auto nbBeams = mSamplingConfig.beamWidth; + auto maxNbTokens = getMaxBeamNumTokens(); + // FIXME(nkorobov): For streaming we do not allow beam search and + // streaming index calculation here applies only for sampling + int nbTokensOut = mIsStreaming ? 1 : maxNbTokens; + if (mExcludeInputFromOutput && !mIsStreaming) + { + nbTokensOut -= getOrigPromptLen(); + } + + result.outputTokenIds.resize(nbBeams); + SizeType tokenPos = maxNbTokens - nbTokensOut; + + bool shouldSendResponse = (mState == batch_manager::REQUEST_STATE_GENERATION_COMPLETE) + || (mIsStreaming && tokenPos > getMaxSentTokenPos()); + + if (!shouldSendResponse) + { + return std::nullopt; + } + else + { + for (SizeType beam = 0; beam < nbBeams; ++beam) + { + auto tokens = getTokens(beam); + auto nbTokens = mIsStreaming ? 
(tokenPos - getMaxSentTokenPos()) : tokens.size(); + if (mExcludeInputFromOutput && !mIsStreaming) + { + nbTokens -= getOrigPromptLen(); + } + if (nbTokens > 0) + { + result.outputTokenIds.at(beam).assign( + tokens.data() + tokenPos, tokens.data() + tokenPos + nbTokens); + } + } + + if (returnLogProbs()) + { + result.cumLogProbs = getCumLogProbs(); + result.logProbs = getLogProbs(); + } + + if (getReturnContextLogits()) + { + result.contextLogits + = std::make_shared(executor::detail::ofITensor(getContextLogitsHost())); + } + + if (getReturnGenerationLogits()) + { + result.generationLogits + = std::make_shared(executor::detail::ofITensor(getGenerationLogitsHost())); + } + + // Update position of last sent response + mMaxSentTokenPos = tokenPos; + + auto response = executor::Response(mRequestId, std::move(result)); + return response; + } + } + else + { + return std::nullopt; + } + } + RequestIdType mRequestId; SizeType mPromptLen; SizeType mMaxNewTokens; @@ -545,6 +741,55 @@ class GenericLlmRequest TensorPtr mGenerationLogits; // [beam_size, mMaxNewTokens, vocab_size_padded] TensorPtr mGenerationLogitsHost; std::vector mGenerationLogitsFragments; + + bool mExcludeInputFromOutput; + +private: + void initialize(VecTokens const& inputTokens) + { + // Scatter the input tokens to other beam + mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens); + + if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value()) + || (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value())) + { + std::string errStr + = "Prompt embedding table and prompt vocab size tensors must both be provided for requests with " + "prompt " + "tuning enabled."; + TLLM_THROW(errStr); + } + + if (mDraftLogits.has_value() && mDraftTokens->empty()) + { + TLLM_THROW("Draft tokens must be specified when draft logits are given."); + } + } + + TensorPtr createListTensor(std::list const& wordsList) + { + std::vector offsets; + VecTokens words; + SizeType offsetCnt = 0; + for (auto const& tokens : wordsList) + { + offsetCnt += tokens.size(); + offsets.push_back(offsetCnt); + words.insert(words.end(), tokens.begin(), tokens.end()); + } + offsets.resize(words.size(), -1); + + SizeType numWords = static_cast(words.size()); + auto shape = runtime::ITensor::makeShape({2, numWords}); + auto tensor = runtime::BufferManager::pinnedPool(shape, nvinfer1::DataType::kINT32); + auto data = runtime::bufferCast(*tensor); + std::memcpy(data, words.data(), numWords * sizeof(int32_t)); + std::memcpy(data + numWords, offsets.data(), numWords * sizeof(int32_t)); + // Add leading dim of 1 + tensor->unsqueeze(0); + + return tensor; + } }; class LlmRequest : public GenericLlmRequest @@ -568,10 +813,15 @@ class LlmRequest : public GenericLlmRequest std::optional loraConfig = std::nullopt, bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional> draftTokens = std::nullopt, - std::optional draftLogits = std::nullopt) + std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false) : Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, loraWeights, loraConfig, returnLogProbs, - returnContextLogits, returnGenerationLogits, draftTokens, draftLogits) + returnContextLogits, returnGenerationLogits, draftTokens, draftLogits, excludeInputFromOutput) + { + } + + LlmRequest(RequestIdType requestId, executor::Request const& Request) + : Base(requestId, 
Request) { } diff --git a/cpp/include/tensorrt_llm/batch_manager/schedulerPolicy.h b/cpp/include/tensorrt_llm/batch_manager/schedulerPolicy.h index 2ec1f4356..8910e5a0c 100644 --- a/cpp/include/tensorrt_llm/batch_manager/schedulerPolicy.h +++ b/cpp/include/tensorrt_llm/batch_manager/schedulerPolicy.h @@ -16,6 +16,8 @@ #pragma once +#include "tensorrt_llm/executor/types.h" + namespace tensorrt_llm::batch_manager::batch_scheduler { @@ -25,4 +27,8 @@ enum class SchedulerPolicy GUARANTEED_NO_EVICT, }; +SchedulerPolicy execToBatchManagerSchedPolicy(executor::SchedulerPolicy policy); + +executor::SchedulerPolicy batchManagerToExecSchedPolicy(SchedulerPolicy policy); + } // namespace tensorrt_llm::batch_manager::batch_scheduler diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h index b8e87e3d7..751f57d81 100644 --- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h +++ b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h @@ -18,6 +18,7 @@ #pragma once #include "tensorrt_llm/batch_manager/kvCacheConfig.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/decodingMode.h" @@ -36,23 +37,29 @@ class TrtGptModelOptionalParams explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{}, bool enableTrtOverlap = false, std::optional> const& deviceIds = std::nullopt, - bool normalizeLogProbs = true, bool logIterationData = false, bool enableChunkedContext = false, + bool normalizeLogProbs = true, bool enableChunkedContext = false, std::optional const& decodingMode = std::nullopt) : kvCacheConfig{kvCacheConfig} , enableTrtOverlap{enableTrtOverlap} , deviceIds(deviceIds) , normalizeLogProbs{normalizeLogProbs} - , logIterationData{logIterationData} , enableChunkedContext{enableChunkedContext} , decodingMode{decodingMode} { } + explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig) + : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), + executorConfig.getEnableTrtOverlap(), executorConfig.getDeviceIds(), executorConfig.getNormalizeLogProbs(), + executorConfig.getEnableChunkedContext()) + { + } + KvCacheConfig kvCacheConfig; + bool enableTrtOverlap; std::optional> deviceIds; bool normalizeLogProbs; - bool logIterationData; bool enableChunkedContext; std::optional decodingMode; }; diff --git a/cpp/include/tensorrt_llm/common/arrayView.h b/cpp/include/tensorrt_llm/common/arrayView.h new file mode 100644 index 000000000..cd409e684 --- /dev/null +++ b/cpp/include/tensorrt_llm/common/arrayView.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm::common +{ + +//! +//! \brief A very rudimentary implementation of std::span. +//! 
+template +class ArrayView +{ +public: + using value_type = T; + using size_type = std::size_t; + using reference = value_type&; + using const_reference = value_type const&; + using pointer = T*; + using const_pointer = T const*; + using iterator = pointer; + using const_iterator = const_pointer; + + ArrayView(T* data, size_type size) + : mData{data} + , mSize{size} + { + } + + [[nodiscard]] iterator begin() + { + return mData; + } + + [[nodiscard]] iterator end() + { + return mData + mSize; + } + + [[nodiscard]] const_iterator begin() const + { + return mData; + } + + [[nodiscard]] const_iterator end() const + { + return mData + mSize; + } + + [[nodiscard]] const_iterator cbegin() const + { + return mData; + } + + [[nodiscard]] const_iterator cend() const + { + return mData + mSize; + } + + [[nodiscard]] size_type size() const + { + return mSize; + } + + [[nodiscard]] reference operator[](size_type index) + { + return mData[index]; + } + + [[nodiscard]] const_reference operator[](size_type index) const + { + return mData[index]; + } + +private: + T* mData; + size_type mSize; +}; + +} // namespace tensorrt_llm::common diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h new file mode 100644 index 000000000..b65c1f461 --- /dev/null +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -0,0 +1,416 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/executor/types.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::executor +{ + +/// @brief Sampling configuration +class SamplingConfig +{ +public: + SamplingConfig(SizeType beamWidth = 1, std::optional topK = std::nullopt, + std::optional topP = std::nullopt, std::optional topPMin = std::nullopt, + std::optional topPResetIds = std::nullopt, std::optional topPDecay = std::nullopt, + std::optional randomSeed = std::nullopt, std::optional temperature = std::nullopt, + std::optional minLength = std::nullopt, + std::optional beamSearchDiversityRate = std::nullopt, + std::optional repetitionPenalty = std::nullopt, + std::optional presencePenalty = std::nullopt, + std::optional frequencyPenalty = std::nullopt, + std::optional lengthPenalty = std::nullopt); + + ~SamplingConfig(); + + [[nodiscard]] SizeType getBeamWidth() const; + [[nodiscard]] std::optional getTopK() const; + [[nodiscard]] std::optional getTopP() const; + [[nodiscard]] std::optional getTopPMin() const; + [[nodiscard]] std::optional getTopPResetIds() const; + [[nodiscard]] std::optional getTopPDecay() const; + [[nodiscard]] std::optional getRandomSeed() const; + [[nodiscard]] std::optional getTemperature() const; + [[nodiscard]] std::optional getMinLength() const; + [[nodiscard]] std::optional getBeamSearchDiversityRate() const; + [[nodiscard]] std::optional getRepetitionPenalty() const; + [[nodiscard]] std::optional getPresencePenalty() const; + [[nodiscard]] std::optional getFrequencyPenalty() const; + [[nodiscard]] std::optional getLengthPenalty() const; + +private: + SizeType mBeamWidth; + std::optional mTopK; + std::optional mTopP; + std::optional mTopPMin; + std::optional mTopPResetIds; + std::optional mTopPDecay; + std::optional mRandomSeed; + std::optional mTemperature; + std::optional mMinLength; + std::optional mBeamSearchDiversityRate; + std::optional mRepetitionPenalty; + std::optional mPresencePenalty; + std::optional mFrequencyPenalty; + std::optional mLengthPenalty; +}; + +/// @brief Configuration that controls the outputs of a Result +struct OutputConfig +{ + bool returnLogProbs{false}; + bool returnContextLogits{false}; + bool returnGenerationLogits{false}; + bool excludeInputFromOutput{false}; +}; + +/// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance +/// threshold +class SpeculativeDecodingConfig +{ +public: + explicit SpeculativeDecodingConfig(VecTokens tokens, std::optional logits = std::nullopt, + std::optional acceptanceThreshold = std::nullopt); + + ~SpeculativeDecodingConfig(); + + [[nodiscard]] VecTokens getTokens() const; + [[nodiscard]] std::optional getLogits() const; + [[nodiscard]] std::optional getAcceptanceThreshold() const; + +private: + VecTokens mTokens; + std::optional mLogits; + std::optional mAcceptanceThreshold; +}; + +/// @brief Configuration for prompt tuning +class PromptTuningConfig +{ +public: + /// @brief + /// @param embeddingTable The prompt embedding table. Data type must match model weights. 
Shape [vocabSize, + /// hiddenSize] + /// @param vocabSize + PromptTuningConfig(TensorPtr embeddingTable); + ~PromptTuningConfig(); + + [[nodiscard]] TensorPtr getEmbeddingTable() const; + +private: + TensorPtr mEmbeddingTable; +}; + +/// @brief Configuration for LoRA +class LoraConfig +{ +public: + LoraConfig(TensorPtr weights, TensorPtr config); + ~LoraConfig(); + + [[nodiscard]] TensorPtr getWeights() const; + [[nodiscard]] TensorPtr getConfig() const; + +private: + TensorPtr mWeights; + TensorPtr mConfig; +}; + +/// @brief A class that holds information about the request +class Request +{ +public: + /// @brief + /// @param inputTokenIds The input token ids + /// @param maxNewTokens The maximum number of tokens to generate + /// @param streaming // Indicates if the responses should be streamed or not + /// @param samplingConfig // The sampling configuration + /// @param outputConfig // The output configuration + /// @param endId // The end token id + /// @param padId // The pad token id + /// @param badWords // A list of bad words tokens. Each "word" can be composed of multiple tokens + /// @param stopWords // A list of stop words tokens. Each "word" can be composed of multiple tokens + /// @param embeddingBias // The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size] + /// @param speculativeDecodingConfig // The speculative decoding configuration + /// @param pTuningConfig // The prompt tuning configuration + /// @param loraConfig // The LoRA configuration + Request(VecTokens inputTokenIds, SizeType maxNewTokens, bool streaming = false, + SamplingConfig samplingConfig = SamplingConfig(), OutputConfig outputConfig = OutputConfig(), + std::optional endId = std::nullopt, std::optional padId = std::nullopt, + std::optional> badWords = std::nullopt, + std::optional> stopWords = std::nullopt, + std::optional embeddingBias = std::nullopt, + std::optional speculativeDecodingConfig = std::nullopt, + std::optional pTuningConfig = std::nullopt, + std::optional loraConfig = std::nullopt); + + Request(Request const& other); + Request(Request&& other) noexcept; + Request& operator=(Request const& other); + Request& operator=(Request&& other) noexcept; + ~Request(); + + [[nodiscard]] VecTokens getInputTokenIds() const; + [[nodiscard]] SizeType getMaxNewTokens() const; + [[nodiscard]] bool getStreaming() const; + [[nodiscard]] SamplingConfig getSamplingConfig() const; + [[nodiscard]] OutputConfig getOutputConfig() const; + [[nodiscard]] std::optional getEndId() const; + [[nodiscard]] std::optional getPadId() const; + [[nodiscard]] std::optional> getBadWords() const; + [[nodiscard]] std::optional> getStopWords() const; + [[nodiscard]] std::optional getEmbeddingBias() const; + [[nodiscard]] std::optional getSpeculativeDecodingConfig() const; + [[nodiscard]] std::optional getPromptTuningConfig() const; + [[nodiscard]] std::optional getLoraConfig() const; + + void setStreaming(bool streaming); + void setSamplingConfig(SamplingConfig config); + void setOutputConfig(OutputConfig outputConfig); + void setEndId(SizeType endId); + void setPadId(SizeType padId); + void setBadWords(std::list badWords); + void setStopWords(std::list stopWords); + void setEmbeddingBias(TensorPtr); + void setSpeculativeDecodingConfig(SpeculativeDecodingConfig specDecodingConfig); + void setPromptTuningConfig(PromptTuningConfig pTuningConfig); + void setLoraConfig(LoraConfig loraConfig); + +private: + class Impl; + std::unique_ptr mImpl; +}; + +/// @brief Struct that holds the generation result +struct Result +{ 
+ // Indicates if this is the final result for the request + bool isFinal; + + /// @brief The output tokens for each beam + BeamTokens outputTokenIds; + + std::optional cumLogProbs; // [beamSize] + std::optional> logProbs; // [beamSize, seqLen] + std::optional contextLogits; // [promptLen, vocab_size_padded] + std::optional generationLogits; // [beam_size, mMaxNewTokens, vocab_size_padded] +}; + +/// @brief Class that holds either an error or a result +class Response +{ +public: + Response(IdType requestId, std::string errorMsg); + Response(IdType requestId, Result Result); + + ~Response(); + Response(Response const& other); + Response(Response&& other) noexcept; + Response& operator=(Response const& other); + Response& operator=(Response&& other) noexcept; + + // Get the id of the request for which this response was generated + IdType getRequestId() const; + + // Indicates if this response has an error or not + bool hasError() const; + + // Get the error msg for this response + // Will throw an exception if hasError is false + std::string getErrorMsg() const; + + // Get the result for this response + // Will throw an exception if hasResult is true + Result getResult() const; + +private: + class Impl; + std::unique_ptr mImpl; +}; + +/// @brief Configuration class for the scheduler +class SchedulerConfig +{ +public: + explicit SchedulerConfig(SchedulerPolicy policy = SchedulerPolicy::kGUARANTEED_NO_EVICT); + ~SchedulerConfig(); + + [[nodiscard]] SchedulerPolicy getPolicy() const; + +private: + SchedulerPolicy mPolicy; +}; + +/// @brief Configuration class for the KV cache +class KvCacheConfig +{ +public: + KvCacheConfig(bool enableBlockReuse = false, std::optional maxTokens = std::nullopt, + std::optional maxAttentionWindow = std::nullopt, + std::optional sinkTokenLength = std::nullopt, + std::optional freeGpuMemoryFraction = std::nullopt, bool useUvm = false); + + [[nodiscard]] bool getEnableBlockReuse() const; + [[nodiscard]] std::optional getMaxTokens() const; + [[nodiscard]] std::optional getMaxAttentionWindow() const; + [[nodiscard]] std::optional getSinkTokenLength() const; + [[nodiscard]] std::optional getFreeGpuMemoryFraction() const; + [[nodiscard]] bool getUseUvm() const; + +private: + bool mEnableBlockReuse; + std::optional mMaxTokens; + std::optional mMaxAttentionWindow; + std::optional mSinkTokenLength; + std::optional mFreeGpuMemoryFraction; + bool mUseUvm; +}; + +SizeType const kDefaultIterStatsMaxIterations = 1000; + +/// @brief Configuration class for the model executor +class ExecutorConfig +{ +public: + ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(), + KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true, + bool enableTrtOverlap = false, std::optional> deviceIds = std::nullopt, + SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations, + BatchingType batchingType = BatchingType::kINFLIGHT); + + [[nodiscard]] SizeType getMaxBeamWidth() const; + [[nodiscard]] SchedulerConfig getSchedulerConfig() const; + [[nodiscard]] KvCacheConfig getKvCacheConfig() const; + [[nodiscard]] bool getEnableChunkedContext() const; + [[nodiscard]] bool getNormalizeLogProbs() const; + [[nodiscard]] bool getEnableTrtOverlap() const; + [[nodiscard]] std::optional> getDeviceIds() const; + [[nodiscard]] SizeType getIterStatsMaxIterations() const; + [[nodiscard]] BatchingType getBatchingType() const; + + void setMaxBeamWidth(SizeType maxBeamWidth); + void 
setSchedulerConfig(SchedulerConfig schedulerConfig); + void setKvCacheConfig(KvCacheConfig kvCacheConfig); + void setEnableChunkedContext(bool enableChunkedContext); + void setNormalizeLogProbs(bool normalizeLogProbs); + void setEnableTrtOverlap(bool enableTrtOverlap); + void setDeviceIds(std::optional> deviceIds); + void setIterStatsMaxIterations(SizeType iterStatsMaxIterations); + void setBatchingType(BatchingType batchingType); + +private: + SizeType mMaxBeamWidth; + SchedulerConfig mSchedulerConfig; + KvCacheConfig mKvCacheConfig; + bool mEnableChunkedContext; + bool mNormalizeLogProbs; + bool mEnableTrtOverlap; + std::optional> mDeviceIds; + SizeType mIterStatsMaxIterations; + BatchingType mBatchingType; +}; + +/// TODO: +/// @brief A class to identify processes involved in the execution of a model +/// Currently only supports MPI communication +class Communicator +{ +public: + Communicator(CommunicatorType commType, CommMode mode, SizeType currentId, std::vector const& commIds, + std::optional orchestratorId){}; + ~Communicator() = default; +}; + +class Model; + +/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference +class Executor +{ + using RequestPtr = std::shared_ptr; + +public: + /// @brief + /// @param modelPath Path to the folder that defines the model to run + /// @param modelType The type of model + /// @param executorConfig The configuration for the executor + /// @param comm An optional inter-process communicator configuration + Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig, + std::optional comm = std::nullopt); + + Executor(std::vector const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType, + ExecutorConfig executorConfig, std::optional comm = std::nullopt); + + Executor( + std::shared_ptr model, ExecutorConfig executorConfig, std::optional comm = std::nullopt); + + ~Executor(); + + /// @brief Enqueue a new request + /// @param request The LLM request which contains input tokens and request parameters + /// @return A unique id that identifies the request + IdType enqueueRequest(Request request); + + /// @brief Enqueue a batch of request + std::vector enqueueRequests(std::vector requests); + + /// @brief Await for ready responses + /// @param id An optional request id. If not specified, responses for any request can be returned + /// @param timeout The maximum time to wait for new responses + /// @return A vector of responses + std::vector awaitResponses( + std::optional id = std::nullopt, std::optional timeout = std::nullopt); + + /// @brief Get the number of ready responses + /// @param id The request id + /// @return The number of ready responses + SizeType getNumResponsesReady(std::optional id = std::nullopt); + + /// @brief Cancel the request with provided request id + /// @param id The request id for which to cancel the response + void cancelRequest(IdType id); + + /// @brief Signals the server to shutdown + /// This call is blocking. Only returns when all requests have terminated or timeout has been reached + void shutdown(); + + /// @brief Returns the per-iterations statistics computed since last call to getLatestIterationStats + /// Contains at most iterStatsMaxIterations iterations + /// Will block until stats for at least one iteration are available + /// TODO: Should we use a class for iterationStats, i.e. 
std::deque + /// @return + std::deque getLatestIterationStats(); + +private: + class Impl; + std::unique_ptr mImpl; +}; + +} // namespace tensorrt_llm::executor diff --git a/cpp/include/tensorrt_llm/executor/tensor.h b/cpp/include/tensorrt_llm/executor/tensor.h new file mode 100644 index 000000000..8bf2851d6 --- /dev/null +++ b/cpp/include/tensorrt_llm/executor/tensor.h @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/executor/types.h" + +#include "tensorrt_llm/common/arrayView.h" +#include "tensorrt_llm/common/assert.h" + +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::runtime +{ +class ITensor; +class CudaStream; +} // namespace tensorrt_llm::runtime + +namespace tensorrt_llm::executor +{ + +class Tensor; + +namespace detail +{ +std::shared_ptr const& toITensor(Tensor const& tensor); +Tensor ofITensor(std::shared_ptr tensor); +} // namespace detail + +// A thin wrapper around span that supports constructions with an initializer list. +class Shape : public tensorrt_llm::common::ArrayView +{ +public: + using Base = tensorrt_llm::common::ArrayView; + using DimType = typename std::remove_cv_t; + + Shape() + : Base{nullptr, 0} {}; + + Shape(DimType const* data, Base::size_type size) + : Base{data, size} + { + } + + Shape(std::initializer_list dims) // NOLINT(*-explicit-constructor) + : Base{dims.begin(), dims.size()} + { + } +}; + +class Tensor +{ +public: + using CudaStreamPtr = std::shared_ptr; + + //! Allocate a cpu tensor with the given shape and data type. + //! + //! \param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + static Tensor cpu(DataType dataType, Shape shape = {}); + + template + static Tensor cpu(Shape shape = {}) + { + return Tensor::cpu(getRuntimeType(), shape); + } + + [[nodiscard]] Tensor copyToCpu(Tensor::CudaStreamPtr stream = nullptr) const; + + //! Allocate a cpu tensor in pinned memory with the given shape and data type. + //! + //! \param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + static Tensor pinned(DataType dataType, Shape shape = {}); + + template + static Tensor pinned(Shape shape = {}) + { + return Tensor::pinned(getRuntimeType(), shape); + } + + [[nodiscard]] Tensor copyToPinned(Tensor::CudaStreamPtr stream = nullptr) const; + + //! Allocate a cpu tensor in pooled pinned memory with the given shape and data type. + //! + //! \param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + static Tensor pooledPinned(DataType dataType, Shape shape = {}); + + template + static Tensor pooledPinned(Shape shape = {}) + { + return Tensor::pooledPinned(getRuntimeType(), shape); + } + + [[nodiscard]] Tensor copyToPooledPinned(Tensor::CudaStreamPtr stream = nullptr) const; + + //! Allocate a tensor in managed memory (UVM) with the given shape and data type. + //! + //! 
\param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + static Tensor managed(DataType dataType, Shape shape = {}); + + template + static Tensor managed(Shape shape = {}) + { + return Tensor::managed(getRuntimeType(), shape); + } + + [[nodiscard]] Tensor copyToManaged(Tensor::CudaStreamPtr stream = nullptr) const; + + //! Allocate a gpu tensor with the given shape and data type on a particular cuda stream. + //! + //! \param shape The shape of the tensor. + //! \param stream Specifies the CUDA stream on which to allocate the tensor for GPU memory. + //! \param dataType The data type of the tensor. + static Tensor gpu(DataType dataType, CudaStreamPtr stream, Shape shape = {}); + + template + static Tensor gpu(CudaStreamPtr stream, Shape shape = {}) + { + return Tensor::gpu(getRuntimeType(), std::move(stream), shape); + } + + [[nodiscard]] Tensor copyToGpu(Tensor::CudaStreamPtr stream) const; + + //! Wrap a data pointer into a tensor without taking ownership. + //! + //! \param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + //! \param stream Specifies the CUDA stream on which to allocate the tensor for GPU memory. + static Tensor of(DataType dataType, void* data, Shape shape); + + //! Wrap a data pointer into a tensor without taking ownership. + //! + //! \param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + //! \param stream Specifies the CUDA stream on which to allocate the tensor for GPU memory. + template + static Tensor of(T* data, Shape shape) + { + return of(getRuntimeType(), static_cast(data), shape); + } + + //! Wrap any container into a tensor without taking ownership. + //! + //! \param shape The shape of the tensor. + //! \param dataType The data type of the tensor. + //! \param stream Specifies the CUDA stream on which to allocate the tensor for GPU memory. + template + static Tensor of(T& data) + { + using DimType = Shape::DimType; + if constexpr (!std::is_same_v) + { + TLLM_CHECK(data.size() <= std::numeric_limits::max()); + } + return of(data.data(), {static_cast(data.size())}); + } + + Tensor() noexcept = default; + + ~Tensor() = default; + + Tensor(const Tensor& other) noexcept = default; + + Tensor(Tensor&& other) noexcept = default; + + Tensor& operator=(const Tensor& other) noexcept = default; + + Tensor& operator=(Tensor&& other) noexcept = default; + + //! + //! \brief Returns a pointer to underlying array. + //! + [[nodiscard]] void* getData(); + + //! + //! \brief Returns a pointer to underlying array. + //! + [[nodiscard]] void const* getData() const; + + //! + //! \brief Returns the data type of the buffer. + //! + [[nodiscard]] DataType getDataType() const; + + //! + //! \brief Returns the memory type of the buffer. + //! + [[nodiscard]] MemoryType getMemoryType() const; + + //! + //! \brief Returns the tensor dimensions. + //! + [[nodiscard]] Shape getShape() const; + + //! + //! \brief Returns the number of elements in the tensor. + //! + [[nodiscard]] std::size_t getSize() const; + + //! + //! \brief Returns the size of the tensor in bytes. + //! + [[nodiscard]] std::size_t getSizeInBytes() const; + + //! + //! \brief Set the entire memory to zero. + //! + //! \param stream Must be a valid CUDA stream if the memory type is GPU. + void setZero(CudaStreamPtr stream = nullptr); + + //! + //! \brief Copy the data and shape from another tensor. + //! + //! \param other A tensor to copy from. + //! 
\param stream Must be a valid CUDA stream if the memory type is GPU. + void setFrom(Tensor const& other, CudaStreamPtr stream = nullptr); + + explicit operator bool() const + { + return static_cast(mTensor); + } + + bool operator==(Tensor const& rhs) const + { + return mTensor == rhs.mTensor; + } + + bool operator!=(Tensor const& rhs) const + { + return !(rhs == *this); + } + +private: + using Impl = runtime::ITensor; + explicit Tensor(std::shared_ptr tensor); + + template + static DataType getRuntimeType() + { + return TypeTraits>::value; + } + + [[nodiscard]] Tensor copyTo(std::shared_ptr tensor, CudaStreamPtr stream) const; + + std::shared_ptr mTensor; + + friend std::shared_ptr const& detail::toITensor(Tensor const& tensor); + friend Tensor detail::ofITensor(std::shared_ptr tensor); +}; + +} // namespace tensorrt_llm::executor diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h new file mode 100644 index 000000000..34a7c6dc9 --- /dev/null +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#ifdef ENABLE_FP8 +#include +#endif +#ifdef ENABLE_BF16 +#include +#endif + +namespace tensorrt_llm::executor +{ + +class Request; +class Tensor; + +using TensorPtr = std::shared_ptr; +using SizeType = std::int32_t; +using FloatType = float; +using TokenIdType = std::int32_t; +using VecTokens = std::vector; +using BeamTokens = std::vector; +using IdType = std::uint64_t; +using RandomSeedType = std::uint64_t; +using VecLogProbs = std::vector; + +enum class DataType +{ + kBOOL, + kUINT8, + kINT8, + kINT32, + kINT64, + kBF16, + kFP8, + kFP16, + kFP32, + kUNKNOWN +}; + +//! \brief For converting a C++ data type to a `TrtLmmDataType`. +template +struct TypeTraits +{ +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kFP32; +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kFP16; +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kINT8; +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kINT32; +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kINT64; +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kBOOL; +}; + +template <> +struct TypeTraits +{ + static constexpr auto value = DataType::kUINT8; +}; + +#ifdef ENABLE_BF16 +template <> +struct TypeTraits<__nv_bfloat16> +{ + static constexpr auto value = DataType::kBF16; +}; +#endif + +#ifdef ENABLE_FP8 +template <> +struct TypeTraits<__nv_fp8_e4m3> +{ + static constexpr auto value = DataType::kFP8; +}; +#endif + +template +struct TypeTraits +{ + // Pointers are stored as int64_t. 
+ static constexpr auto value = DataType::kINT64; +}; + +enum class MemoryType +{ + kCPU, + kCPU_PINNED, + kGPU, + kUVM, + kUNKNOWN +}; + +enum class ModelType +{ + kDECODER_ONLY = 0, +}; + +enum class BatchingType +{ + kSTATIC = 0, + kINFLIGHT = 1, + kINFLIGHT_UNFUSED = 2, +}; + +enum class SchedulerPolicy +{ + kMAX_UTILIZATION = 0, + kGUARANTEED_NO_EVICT = 1, +}; + +enum class CommunicatorType +{ + kMPI = 0 +}; + +enum class CommMode +{ + kLEADER, // With the leader mode, only the leader will be returning from the executor constructor and + // therefore only the leader can enqueue requests and get responses + kORCHESTRATOR, // With the orchestrator mode, only the orchestrator will be returning from the executor constructor + // and therefore only the leader can enqueue requests and get responses The orchestrator doesn't + // participate in the computations + kALL, // With the ALL mode, all participants are expected to make the same calls to the executor API + // So they all need to send the same requests + // Responses will be the same for all participants +}; + +} // namespace tensorrt_llm::executor diff --git a/cpp/include/tensorrt_llm/runtime/bufferManager.h b/cpp/include/tensorrt_llm/runtime/bufferManager.h index 607bfe6e2..01e4b8525 100644 --- a/cpp/include/tensorrt_llm/runtime/bufferManager.h +++ b/cpp/include/tensorrt_llm/runtime/bufferManager.h @@ -73,10 +73,10 @@ class BufferManager [[nodiscard]] static ITensorPtr pinnedPool(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE); //! \brief Allocates an `IBuffer` of the given size in UVM. - [[nodiscard]] IBufferPtr managed(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE) const; + [[nodiscard]] static IBufferPtr managed(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE); //! \brief Allocates an `ITensor` of the given dimensions in UVM. - [[nodiscard]] ITensorPtr managed(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE) const; + [[nodiscard]] static ITensorPtr managed(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE); //! \brief Allocates an `IBuffer` of the given size and memory type. [[nodiscard]] IBufferPtr allocate( diff --git a/cpp/include/tensorrt_llm/runtime/cudaStream.h b/cpp/include/tensorrt_llm/runtime/cudaStream.h index b112aa4e9..508708eab 100644 --- a/cpp/include/tensorrt_llm/runtime/cudaStream.h +++ b/cpp/include/tensorrt_llm/runtime/cudaStream.h @@ -60,6 +60,12 @@ class CudaStream mStream = StreamPtr{stream, Deleter{ownsStream}}; } + //! Construct with an existing cuda stream or the default stream by passing nullptr. + explicit CudaStream(cudaStream_t stream) + : CudaStream{stream, tensorrt_llm::common::getDevice(), false} + { + } + //! Returns the device on which the stream was created. 
[[nodiscard]] int getDevice() const { diff --git a/cpp/include/tensorrt_llm/runtime/iBuffer.h b/cpp/include/tensorrt_llm/runtime/iBuffer.h index 5aa5ba667..d02bec3a1 100644 --- a/cpp/include/tensorrt_llm/runtime/iBuffer.h +++ b/cpp/include/tensorrt_llm/runtime/iBuffer.h @@ -16,6 +16,9 @@ #pragma once +#include "tensorrt_llm/common/arrayView.h" +#include "tensorrt_llm/common/dataType.h" + #include #include @@ -33,8 +36,6 @@ #include #include -#include "tensorrt_llm/common/dataType.h" - namespace tensorrt_llm::runtime { @@ -561,21 +562,14 @@ T* bufferCast(IBuffer& buffer) } template -class BufferRange +class BufferRange : public tensorrt_llm::common::ArrayView { public: - using value_type = T; - using size_type = std::size_t; - using reference = value_type&; - using const_reference = value_type const&; - using pointer = T*; - using const_pointer = T const*; - using iterator = pointer; - using const_iterator = const_pointer; + using Base = tensorrt_llm::common::ArrayView; + using typename Base::size_type; BufferRange(T* data, size_type size) - : mData{data} - , mSize{size} + : Base{data, size} { } @@ -583,65 +577,6 @@ class BufferRange : BufferRange(bufferCast(buffer), buffer.getSize()) { } - - iterator begin() - { - return mData; - } - - iterator end() - { - return mData + mSize; - } - - const_iterator begin() const - { - return mData; - } - - const_iterator end() const - { - return mData + mSize; - } - - const_iterator cbegin() - { - return mData; - } - - const_iterator cend() - { - return mData + mSize; - } - - const_iterator cbegin() const - { - return mData; - } - - const_iterator cend() const - { - return mData + mSize; - } - - [[nodiscard]] size_type size() const - { - return mSize; - } - - reference operator[](size_type index) - { - return mData[index]; - } - - const_reference operator[](size_type index) const - { - return mData[index]; - } - -private: - T* mData; - size_type mSize; }; //! \brief Utility function to print a buffer. 
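Editor's note: the hunk above replaces BufferRange's hand-written iterator boilerplate with the new tensorrt_llm::common::ArrayView base class added elsewhere in this patch (cpp/include/tensorrt_llm/common/arrayView.h). A minimal usage sketch follows, assuming ArrayView keeps the begin()/end()/size()/operator[] surface that the removed members used to provide; the sumBuffer helper is invented here for illustration and is not part of the patch.

#include "tensorrt_llm/runtime/iBuffer.h"

#include <cstdint>
#include <numeric>

namespace tlr = tensorrt_llm::runtime;

// Sum the elements of a host-visible INT32 buffer through the ArrayView-backed range.
std::int64_t sumBuffer(tlr::IBuffer& buffer)
{
    // BufferRange now inherits its iteration interface from common::ArrayView<std::int32_t>;
    // the buffer constructor wraps bufferCast<std::int32_t>(buffer) and buffer.getSize().
    tlr::BufferRange<std::int32_t> range{buffer};
    return std::accumulate(range.begin(), range.end(), std::int64_t{0});
}
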
diff --git a/cpp/include/tensorrt_llm/runtime/samplingConfig.h b/cpp/include/tensorrt_llm/runtime/samplingConfig.h index 0ea5e430e..d7a49b713 100644 --- a/cpp/include/tensorrt_llm/runtime/samplingConfig.h +++ b/cpp/include/tensorrt_llm/runtime/samplingConfig.h @@ -16,8 +16,10 @@ #pragma once +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" +#include #include #include @@ -57,6 +59,9 @@ class SamplingConfig return std::make_optional>(values); } + template + using Vec = std::vector; + public: explicit SamplingConfig(SizeType beamWidth = 1) : beamWidth{beamWidth} @@ -86,6 +91,39 @@ class SamplingConfig = fuseValues(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; }); } + explicit SamplingConfig(executor::SamplingConfig const& samplingConfig, + std::optional const& specDecodingConfig) + : beamWidth{samplingConfig.getBeamWidth()} + { + + if (specDecodingConfig && specDecodingConfig.value().getAcceptanceThreshold()) + { + draftAcceptanceThreshold = Vec{specDecodingConfig.value().getAcceptanceThreshold().value()}; + } + +#define SET_FROM_OPTIONAL(varName, VarName, VarType) \ + \ + if (samplingConfig.get##VarName()) \ + { \ + varName = Vec{samplingConfig.get##VarName().value()}; \ + } + + SET_FROM_OPTIONAL(topK, TopK, SizeType) + SET_FROM_OPTIONAL(topP, TopP, FloatType) + SET_FROM_OPTIONAL(topPMin, TopPMin, FloatType) + SET_FROM_OPTIONAL(topPResetIds, TopPResetIds, SizeType) + SET_FROM_OPTIONAL(topPDecay, TopPDecay, FloatType) + SET_FROM_OPTIONAL(randomSeed, RandomSeed, uint64_t) + SET_FROM_OPTIONAL(temperature, Temperature, FloatType) + SET_FROM_OPTIONAL(minLength, MinLength, SizeType) + SET_FROM_OPTIONAL(beamSearchDiversityRate, BeamSearchDiversityRate, FloatType) + SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType) + SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType) + SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType) + SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType) +#undef SET_FROM_OPTIONAL + } + public: SizeType beamWidth; diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index 195880f50..4ed133467 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -34,6 +34,9 @@ add_subdirectory(runtime) set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static) set(BATCH_MANAGER_TARGET_ARCH "unknown") +set(EXECUTOR_TARGET tensorrt_llm_executor_static) +set(EXECUTOR_TARGET_ARCH "unknown") + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") if(NOT WIN32) # Linux execute_process( @@ -52,8 +55,10 @@ if(NOT WIN32) # Linux if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") set(BATCH_MANAGER_TARGET_ARCH "x86_64-linux-gnu") + set(EXECUTOR_TARGET_ARCH "x86_64-linux-gnu") elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") set(BATCH_MANAGER_TARGET_ARCH "aarch64-linux-gnu") + set(EXECUTOR_TARGET_ARCH "aarch64-linux-gnu") if(NOT ${OS_ID} MATCHES "ubuntu" OR ${OS_VERSION_ID} VERSION_LESS 22.04) message( FATAL_ERROR @@ -68,6 +73,7 @@ else() # Windows # AMD64, IA64, ARM64, EM64T, X86 if(CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64") set(BATCH_MANAGER_TARGET_ARCH "x86_64-windows-msvc") + set(EXECUTOR_TARGET_ARCH "x86_64-windows-msvc") else() message( FATAL_ERROR @@ -105,8 +111,39 @@ else() endif() endif() +if(BUILD_EXECUTOR) + add_subdirectory(executor) +else() + add_library(${EXECUTOR_TARGET} STATIC IMPORTED) + if(NOT WIN32) # Linux + if(USE_CXX11_ABI) + set(EXECUTOR_LIB_LOC + 
"${CMAKE_CURRENT_SOURCE_DIR}/executor/${EXECUTOR_TARGET_ARCH}/libtensorrt_llm_executor_static.a" + ) + else() + set(EXECUTOR_LIB_LOC + "${CMAKE_CURRENT_SOURCE_DIR}/executor/${EXECUTOR_TARGET_ARCH}/libtensorrt_llm_executor_static.pre_cxx11.a" + ) + endif() + else() # Windows + set(EXECUTOR_LIB_LOC + "${CMAKE_CURRENT_SOURCE_DIR}/executor/${EXECUTOR_TARGET_ARCH}/tensorrt_llm_executor_static.lib" + ) + endif() + set_property(TARGET ${EXECUTOR_TARGET} PROPERTY IMPORTED_LOCATION + ${EXECUTOR_LIB_LOC}) + file(SIZE ${EXECUTOR_LIB_LOC} EXECUTOR_LIB_SIZE) + if(EXECUTOR_LIB_SIZE LESS 1024) + message( + FATAL_ERROR + "The executor library is truncated or incomplete. This is usually caused by using Git LFS (Large File Storage) incorrectly. Please try running command `git lfs install && git lfs pull`." + ) + endif() +endif() + find_package(Threads REQUIRED) target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE Threads::Threads) +target_link_libraries(${EXECUTOR_TARGET} INTERFACE Threads::Threads) if(NOT WIN32) if(USE_CXX11_ABI) @@ -128,6 +165,26 @@ else() add_custom_target(check_symbol) endif() +if(NOT WIN32) + if(USE_CXX11_ABI) + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol_executor" + COMMAND nm -C $ | grep -q 'std::__cxx11::' + DEPENDS ${EXECUTOR_TARGET}) + else() + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol_executor" + COMMAND nm -C $ | grep -qv + 'std::__cxx11::' + DEPENDS ${EXECUTOR_TARGET}) + endif() + add_custom_target( + check_symbol_executor + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol_executor") +else() + add_custom_target(check_symbol_executor) +endif() + set(TRTLLM_LINK_LIBS ${CUBLAS_LIB} ${CUBLASLT_LIB} @@ -175,11 +232,28 @@ else() "-Wl,--no-whole-archive") endif() +if(WIN32) + target_link_libraries(${SHARED_TARGET} + PUBLIC $) + set_target_properties( + ${SHARED_TARGET} PROPERTIES LINK_FLAGS "/WHOLEARCHIVE:${EXECUTOR_TARGET}") +else() + # Assume everything else is like gcc + target_link_libraries( + ${SHARED_TARGET} + PRIVATE "-Wl,--whole-archive" $ + "-Wl,--no-whole-archive") +endif() + add_dependencies(${SHARED_TARGET} check_symbol) +add_dependencies(${SHARED_TARGET} check_symbol_executor) # Cyclic dependency of batch manager on TRT-LLM target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET}) +# Cyclic dependency of executor on TRT-LLM +target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET}) + if(BUILD_PYT) add_subdirectory(thop) endif() diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index dd2328b47..d356989f9 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0268f64b0c2540e07bf05ad458f7aa33c9d6e65fc4f5c85cd8d0946d658ffeb8 -size 2092012 +oid sha256:39835ca321e9c45d3b554ebceb1734966b75f83dbe8c550cc44846fb4fae8f72 +size 2110728 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 92708de1a..b686e104b 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ 
b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:89ae0be676e7aa9b562f6745636f7d77198f87b83ec6295aff74273767e4fca7 -size 2071180 +oid sha256:789c2eba349161e84a76b95b23f8294cf3bdcf855871672d76722c4ae858d81b +size 2091842 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt index d48810daa..542aa9382 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -63c3f64faa14f9d5d66b7e186a6cc80b libtensorrt_llm_batch_manager_static.a -dbcc1bbe80d977c1655d32ef69b36578 libtensorrt_llm_batch_manager_static.pre_cxx11.a +30a6c963121b3cfda21dc0117b7984e1 libtensorrt_llm_batch_manager_static.a +0d2d2e3157201f6336d749b3e6f994bc libtensorrt_llm_batch_manager_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a new file mode 100644 index 000000000..7df8edefe --- /dev/null +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:643e546711fd33a85073560e3428c6a2f60525f7592aa3328043dfad61631c30 +size 586532 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a new file mode 100644 index 000000000..9b593b09f --- /dev/null +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28059131a9325c88bd362cb12c57a2b2e47d3e0aac140e5d1cf9a7020a81999e +size 570860 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt new file mode 100644 index 000000000..6856a9ebe --- /dev/null +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt @@ -0,0 +1,2 @@ +fa89714705a1915f052c635a07dc4c73 libtensorrt_llm_executor_static.a +83cbfaf10bedd7d8edeab33552dcf3df libtensorrt_llm_executor_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.cu b/cpp/tensorrt_llm/kernels/decodingKernels.cu index cd0290cef..0da2e5d8f 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.cu +++ b/cpp/tensorrt_llm/kernels/decodingKernels.cu @@ -472,31 +472,35 @@ void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequ } __global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, const int* sequenceLengths, - const int* batchSlots, int batchSize, int beamWidth, int maxSeqLen) + const int* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen) { int index = blockIdx.x * blockDim.x + threadIdx.x; - const int batchIdx = index / (beamWidth * maxSeqLen); - auto const batchSlot = batchSlots != nullptr ? 
batchSlots[batchIdx] : batchIdx; - const int tmpIdx = index % (beamWidth * maxSeqLen); - const int beamIdx = tmpIdx / maxSeqLen; - const int pos = tmpIdx % maxSeqLen; + int const batchIdx = index / (beamWidth * maxSeqLen); + int const tmpIdx = index % (beamWidth * maxSeqLen); + int const beamIdx = tmpIdx / maxSeqLen; + int const pos = tmpIdx % maxSeqLen; + if (batchIdx >= batchSize) + { + return; + } - if (batchIdx < batchSize && pos < sequenceLengths[batchSlot]) + auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx; + if (pos < sequenceLengths[batchSlot]) { auto const batchBeamIdx = batchSlot * beamWidth * maxSeqLen + beamIdx * maxSeqLen + pos; outputLogProbs[batchBeamIdx] - = outputLogProbsTiled[pos * batchSize * beamWidth + batchSlot * beamWidth + beamIdx]; + = outputLogProbsTiled[pos * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx]; } } void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, const int* sequenceLengths, - const int* batchSlots, int batchSize, int beamWidth, int maxSeqLen, cudaStream_t stream) + const int* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen, cudaStream_t stream) { dim3 block(256); dim3 grid(divUp(batchSize * beamWidth * maxSeqLen, block.x)); - transposeLogProbs<<>>( - outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots, batchSize, beamWidth, maxSeqLen); + transposeLogProbs<<>>(outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots, + batchSize, maxBatchSize, beamWidth, maxSeqLen); } __global__ void acceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths, diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.h b/cpp/tensorrt_llm/kernels/decodingKernels.h index c62b1282c..0ca020a73 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.h +++ b/cpp/tensorrt_llm/kernels/decodingKernels.h @@ -129,7 +129,8 @@ void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream); void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, int32_t const* sequence_lengths, - int32_t const* batchSlots, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, cudaStream_t stream); + int32_t const* batchSlots, int32_t batch_size, int32_t max_batch_size, int32_t beam_width, int32_t max_seq_len, + cudaStream_t stream); void invokeAcceptTokens(int32_t const* draft_tokens, int32_t const* target_tokens, int32_t const* context_lengths, int32_t const* nums_draft_tokens, int32_t* sequence_lengths, bool const* finished, bool* finished_final, diff --git a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h index 63e13cf8a..940d53078 100644 --- a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h +++ b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h @@ -138,7 +138,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ int selected_beams; __shared__ float old_cum_log_probs[MAX_K2]; - __shared__ cub_kvp cta_topk[MAX_K2]; + __shared__ char cta_topk_store[MAX_K2 * sizeof(cub_kvp)]; + auto* cta_topk = reinterpret_cast(cta_topk_store); if (thread_id == 0) { @@ -687,7 +688,8 @@ 
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta MD partial_md{-MAX_T_VAL, 0.0f}; cub_kvp total_topk{V - 1, -MAX_T_VAL}; - __shared__ cub_kvp buf_smem_kv[MAX_K2]; + __shared__ char buf_smem_kv_store[MAX_K2 * sizeof(cub_kvp)]; + auto* buf_smem_kv = reinterpret_cast(buf_smem_kv_store); // load and unpack into registers through smem for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) diff --git a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu index 2a23ecb18..39b899024 100644 --- a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu @@ -308,7 +308,7 @@ template __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* inIdxBuf, T* outBuf, IdxT* outIdxBuf, int previousLen, Counter* counter, AccT* histogram, IdxT* countHistogram, int pass, float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, IdxT* sequenceLengths, - FinishedState* finishedOutput, int const batchId, bool earlyStop) + FinishedState* finishedOutput, int const batchId, int maxBatchSize, bool earlyStop) { static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or float"); static_assert(std::is_same_v, "AccT needs to be float"); @@ -359,7 +359,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i // See the remark above on the distributed execution of `f` using // vectorizedProcess. auto f = [inIdxBuf, outBuf, outIdxBuf, selectMin, startBit, mask, previousStartBit, kthValueBits, pFilterCnt, - outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchId, + outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchId, maxBatchSize, earlyStop](T value, IdxT i) { auto const previousBits = (twiddleIn(value, selectMin) >> previousStartBit) << previousStartBit; @@ -370,8 +370,8 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i int const currentStep = sequenceLengths[batchId]; IdxT index = inIdxBuf ? 
inIdxBuf[i] : i; ids[batchId][currentStep] = index; - epilogue( - value, index, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId); + epilogue(value, index, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, + batchId, maxBatchSize); } if (outBuf) { @@ -506,14 +506,14 @@ __device__ void chooseBucket( */ template __device__ void epilogue(T const value, IdxT const index, float* outputLogProbs, float* cumLogProbs, IdxT const* endIds, - IdxT* sequenceLengths, FinishedState* finishedOutput, int const batchId) + IdxT* sequenceLengths, FinishedState* finishedOutput, int const batchId, int maxBatchSize) { if (outputLogProbs != nullptr || cumLogProbs != nullptr) { float res = logf(value); if (outputLogProbs) { - outputLogProbs[batchId] = res; + outputLogProbs[sequenceLengths[batchId] * maxBatchSize + batchId] = res; } if (cumLogProbs) { @@ -542,7 +542,7 @@ __device__ void epilogue(T const value, IdxT const index, float* outputLogProbs, template __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen, Counter* counter, float* outputLogProbs, float* cumLogProbs, IdxT** ids, IdxT const* endIds, IdxT* sequenceLengths, - FinishedState* finishedOutput, int const batchId) + FinishedState* finishedOutput, int const batchId, int maxBatchSize) { auto const kthValueBits = counter->kthValueBits; auto const equalValue = twiddleOut(kthValueBits, false); @@ -565,7 +565,8 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen if (threadIdx.x == 0) { - epilogue(equalValue, *outIdx, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId); + epilogue(equalValue, *outIdx, outputLogProbs, cumLogProbs, endIds, sequenceLengths, finishedOutput, batchId, + maxBatchSize); } } @@ -609,7 +610,7 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen template __global__ void airTopPSampling(Counter* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, IdxT const* endIds, int const batchSize, bool const* skipDecode, int const pass, T* buf1, + float* outputLogProbs, IdxT const* endIds, int const maxBatchSize, bool const* skipDecode, int const pass, T* buf1, IdxT* idxBuf1, T* buf2, IdxT* idxBuf2, int32_t const* batchSlots) { assert(sequenceLengths != nullptr); @@ -698,7 +699,7 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra filterAndHistogram(inBuf, inIdxBuf, outBuf, outIdxBuf, previousLen, counter, histogram, countHistogram, pass, outputLogProbs, cumLogProbs, ids, endIds, sequenceLengths, finishedOutput, batchSlot, - earlyStop); + maxBatchSize, earlyStop); __syncthreads(); __threadfence(); @@ -779,7 +780,7 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra { lastFilter(outBuf ? outBuf : inBuf, outIdxBuf ? outIdxBuf : inIdxBuf, outBuf ? 
currentLen : counter->oriLen, counter, outputLogProbs, cumLogProbs, ids, endIds, - sequenceLengths, finishedOutput, batchSlot); + sequenceLengths, finishedOutput, batchSlot, maxBatchSize); __syncthreads(); } @@ -891,9 +892,9 @@ unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt) template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, - float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode, - int32_t const* batchSlots) + T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, + int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, + bool const* skipDecode, int32_t const* batchSlots) { using IdxT = int; using AccT = float; @@ -953,46 +954,47 @@ void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** ou } kernel<<>>(counters, histograms, countHistograms, outputIds, - sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, endIds, batchSize, skipDecode, - pass, buf1, idxBuf1, buf2, idxBuf2, batchSlots); + sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, endIds, maxBatchSize, + skipDecode, pass, buf1, idxBuf1, buf2, idxBuf2, batchSlots); sync_check_cuda_error(); } } template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - float const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, - int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, - bool const* skipDecode, int32_t const* batchSlots); + float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, + int blockNum, bool const* skipDecode, int32_t const* batchSlots); template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - half const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, - int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, - bool const* skipDecode, int32_t const* batchSlots); + half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, + int blockNum, bool const* skipDecode, int32_t const* batchSlots); template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* 
batchSlots) + T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, + int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, + int32_t const* batchSlots) { invokeBatchAirTopPSampling(workspace, workspaceSize, outputIds, sequenceLength, finishedInput, finishedOutput, - cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, vocabSizePadded, endIds, topP, nullptr, stream, - blockNum, skipDecode, batchSlots); + cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, vocabSizePadded, endIds, topP, + nullptr, stream, blockNum, skipDecode, batchSlots); } template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - float const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, - int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, - int32_t const* batchSlots); + float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum, + bool const* skipDecode, int32_t const* batchSlots); template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - half const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, - int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, - int32_t const* batchSlots); + half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum, + bool const* skipDecode, int32_t const* batchSlots); template unsigned calcAirTopPBlockNum(int batchSize, int len, int smCnt); template unsigned calcAirTopPBlockNum(int batchSize, int len, int smCnt); diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu index 805c44c30..1c701cdc5 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu @@ -123,7 +123,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, const int maxTopK, const int* topKs, const float topP, const float* topPs, curandState_t* curandstate, const int* endIds, const int vocabSize, const bool* skipDecode, const int* batchSlots, - const bool normalizeLogProbs, const bool logitHasProbs) + int maxBatchSize, const bool normalizeLogProbs, const bool logitHasProbs) { bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; @@ -210,7 +210,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm // If s_id is -1 here we force output token to the last from vocabulary to get vivid indicator of smth // going wrong for the debug auto outputId = idx != -1 ? 
topKTmpIdBuf[batchIdx * stride + idx] % vocabSize : vocabSize - 1; - ids[batchSlot][sequenceLengths[batchSlot]] = outputId; + auto const curSeqLen = sequenceLengths[batchSlot]; + ids[batchSlot][curSeqLen] = outputId; if (cumLogProbs != nullptr || outputLogProbs != nullptr) { float logProb = logf(expLogit); @@ -225,7 +226,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm // log_prob = log P(i | i is in top-k) = log(expLogit) // normalized: // log_prob = log P(i | i is in top-k) = log(expLogit / sum) - outputLogProbs[batchSlot] = normalizeLogProbs ? logProb - logf(s_sum) : logProb; + outputLogProbs[curSeqLen * maxBatchSize + batchSlot] + = normalizeLogProbs ? logProb - logf(s_sum) : logProb; } } break; @@ -256,8 +258,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm topKStage2Sampling \ <<>>(topKTmpIdBuf, \ topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \ - topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, normalizeLogProbs, \ - logitsHasProbs); \ + topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, maxBatchSize, \ + normalizeLogProbs, logitsHasProbs); \ break; template @@ -265,7 +267,7 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, - const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs) + int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -330,37 +332,39 @@ template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, co int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, - const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); + const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, + const bool logitsHasProbs); template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids, int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, - const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); + const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, + const bool logitsHasProbs); template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const 
int vocabSizePadded, const int* endIds, - const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode, + const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs) { invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, batchSlots, - stream, batchSize, skipDecode, normalizeLogProbs, logitsHasProbs); + stream, batchSize, maxBatchSize, skipDecode, normalizeLogProbs, logitsHasProbs); } template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids, int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, - const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode, - const bool normalizeLogProbs, const bool logitsHasProbs); + const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, + const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids, int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, - const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode, - const bool normalizeLogProbs, const bool logitsHasProbs); + const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, + const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h index 8f4324815..e53d6d2e8 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h @@ -59,6 +59,7 @@ namespace kernels //! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool //! \param stream cuda stream //! \param batchSize batch size +//! \param maxBatchSize maximum batch size //! \param skipDecode input buffer [maxBatchSize]. Flags whether to skip decoding per request //! \param normalizeLogProbs when set to True outputLogProbs are normalized to TopK //! \param logitsHasProbs flag to highlight that logProbs contains probabilities @@ -68,14 +69,14 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, - const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); + int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); //! 
\brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** outputIds, int* sequenceLength, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds, - const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode, + const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu index 901ccff8d..b03e12e44 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu @@ -164,7 +164,8 @@ struct BlockPrefixCallbackOp template __device__ void epilogue(int batchId, int currentStep, int offset, int** ids, int* sortedIdVals, T* sortedLogProbs, - float* cumLogProbs, float* outputLogProbs, int const* endIds, int* sequenceLengths, FinishedState* finishedOutput) + float* cumLogProbs, float* outputLogProbs, int const* endIds, int* sequenceLengths, FinishedState* finishedOutput, + int maxBatchSize) { ids[batchId][currentStep] = sortedIdVals[offset]; @@ -177,7 +178,7 @@ __device__ void epilogue(int batchId, int currentStep, int offset, int** ids, in } if (outputLogProbs != nullptr) { - outputLogProbs[batchId] = lprob; + outputLogProbs[sequenceLengths[batchId] * maxBatchSize + batchId] = lprob; } } if (sequenceLengths != nullptr && finishedOutput != nullptr) @@ -199,7 +200,7 @@ template __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, int const* beginOffsetBuf, int const* offsetBuf, int const vocabSize, curandState_t* curandstate, float const topP, - float const* topPs, int const* endIds, int const batchSize, bool const* skipDecode, int const* batchSlots) + float const* topPs, int const* endIds, int maxBatchSize, bool const* skipDecode, int const* batchSlots) { /** * Each block processes one request row sorted in descending order by probabilities. 
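Editor's note: across this file and the top-k / air-top-p kernels above, per-token log-prob writes move from a flat outputLogProbs[batchSlot] layout to a step-major tiled layout of extent [maxSeqLen, maxBatchSize] (times beamWidth where applicable), which invokeTransposeLogProbs later rearranges into per-request order. A small sketch of the two indexing conventions; the helper names are invented here for illustration.

// Index written by the sampling kernels for (step, batchSlot, beam). The top-k/top-p
// sampling paths run with beamWidth == 1, so the write collapses to
// step * maxBatchSize + batchSlot, matching the epilogue changes in this patch.
inline int tiledLogProbIndex(int step, int batchSlot, int beam, int maxBatchSize, int beamWidth)
{
    return (step * maxBatchSize + batchSlot) * beamWidth + beam;
}

// Index in the per-request layout produced by transposeLogProbs, where the output
// buffer has extent [maxBatchSize, beamWidth, maxSeqLen].
inline int transposedLogProbIndex(int batchSlot, int beam, int pos, int beamWidth, int maxSeqLen)
{
    return (batchSlot * beamWidth + beam) * maxSeqLen + pos;
}
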
@@ -258,7 +259,7 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i { int offset = batchId * vocabSize; epilogue(batchSlot, currentStep, offset, ids, sortedIdVals, sortedLogProbs, cumLogProbs, outputLogProbs, - endIds, sequenceLength, finishedOutput); + endIds, sequenceLength, finishedOutput, maxBatchSize); } return; } @@ -299,7 +300,7 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1)) { epilogue(batchSlot, currentStep, offset + selectedTokenId, ids, sortedIdVals, sortedLogProbs, cumLogProbs, - outputLogProbs, endIds, sequenceLength, finishedOutput); + outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize); } } @@ -307,7 +308,7 @@ template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -354,7 +355,7 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub // Sample with Top P given sorted tokens topPSsampling<<>>(sortedLogProbs, sortedIdVals, outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, beginOffsetBuf, - offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, batchSize, skipDecode, batchSlots); + offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, maxBatchSize, skipDecode, batchSlots); sync_check_cuda_error(); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -363,40 +364,40 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, float const* logProbs, int const* idVals, int* offsetBuf, - int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, - int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, - int const* batchSlots); + int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, + bool const* skipDecode, int const* batchSlots); template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, half const* logProbs, int const* idVals, int* offsetBuf, - int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, - int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, - int const* batchSlots); + int* beginOffsetBuf, curandState_t* 
curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, + bool const* skipDecode, int const* batchSlots); template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topP, - cudaStream_t stream, bool const* skipDecode, int const* batchSlots) + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, + float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots) { invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate, - batchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode, batchSlots); + batchSize, maxBatchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode, batchSlots); } template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, float const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topP, - cudaStream_t stream, bool const* skipDecode, int const* batchSlots); + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, + float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, half const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topP, - cudaStream_t stream, bool const* skipDecode, int const* batchSlots); + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, + float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); __global__ void computeToppDecay(float* runtimeTopP, float const* runtimeInitialTopP, int const** outputIds, float const* topPDecay, float const* topPMin, int32_t const* topPResetIds, int const* sequenceLengths, diff --git a/cpp/tensorrt_llm/kernels/samplingTopPKernels.h b/cpp/tensorrt_llm/kernels/samplingTopPKernels.h index 311786f50..9a54d359d 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopPKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingTopPKernels.h @@ -63,6 +63,7 @@ void invokeTopPInitialize(int* topPIdValBuf, int* topPOffsetBuf, int* beginTopPO //! \param curandstate input buffer [maxBatchSize]. Curand states properly initialized using //! invokeCurandInitialize per request. //! \param batchSize batch size +//! 
\param maxBatchSize maximum batch size //! \param vocabSizePadded size of padded vocab //! \param endIds input buffer [maxBatchSize]. EOS token ids per request //! \param maxTopP maximum among all topPs P for topP sampling @@ -77,7 +78,7 @@ template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); //! \brief Specialization of invokeBatchTopPSampling with topPs=nullptr @@ -85,8 +86,8 @@ template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, float const topPp, - cudaStream_t stream, bool const* skipDecode, int const* batchSlots); + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, + float const topPp, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); //! \brief Given logProbs, performs top P sampling. //! Note different from invokeTopPSampling() and invokeBatchTopPSampling() there two functions invokeAirTopPSampling @@ -116,6 +117,7 @@ void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempS //! \param curandstate input buffer [batchSize]. Curand states properly initialized using invokeCurandInitialize per //! request. //! \param batchSize batch size +//! \param maxBatchSize max batch size //! \param vocabSizePadded size of padded vocab //! \param endIds input buffer [batchSize]. EOS token ids per request //! \param maxTopP maximum among all topPs P for topP sampling @@ -128,16 +130,17 @@ void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempS template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, - float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode, - int32_t const* batchSlots); + T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, + int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, + bool const* skipDecode, int32_t const* batchSlots); //! 
\brief Specialization of invokeBatchAirTopPSampling with topPs=nullptr template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots); + T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, + int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, + int32_t const* batchSlots); //! \brief Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize. //! \tparam T the data type of value diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp index 39495abe7..862a638e9 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp @@ -372,7 +372,8 @@ void DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& checkStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mStream); // Copy nextIds and transpose logits when needed - prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, batchSize, beamWidth, maxSeqLen, mCyclicStep, mStream); + prepareOutputData( + outputs, params, mIdsPtrHost, batchSlots, batchSize, mMaxBatchSize, beamWidth, maxSeqLen, mCyclicStep, mStream); mCyclicStep += 1; @@ -489,10 +490,8 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, } if (outputs.output_log_probs_tiled) { - TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < maxSeqLen); Tensor& output_log_probs = outputs.output_log_probs_tiled.value(); - size_t step_offset = mCyclicStep * batchSize * beamWidth; - decode_outputs.output_log_probs = output_log_probs.slice({1, localBatchSize * beamWidth}, step_offset); + decode_outputs.output_log_probs = output_log_probs.slice({1, localBatchSize * beamWidth}, 0); } // Run TopK + TopP decode layers. 
@@ -697,8 +696,8 @@ void DynamicDecodeLayer::prepareIdsPtrs( template void DynamicDecodeLayer::prepareOutputData(OutputParams& outputs, ForwardParams const& params, - runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, - size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream) + runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t maxBatchSize, + size_t beamWidth, size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1); @@ -713,8 +712,8 @@ void DynamicDecodeLayer::prepareOutputData(OutputParams& outputs, ForwardPara invokeTransposeLogProbs(outputs.output_log_probs.value().template getPtr(), outputs.output_log_probs_tiled.value().template getPtr(), - outputs.sequence_length->template getPtr(), batchSlots, batchSize, beamWidth, logProbsMaxSeqLen, - stream); + outputs.sequence_length->template getPtr(), batchSlots, batchSize, maxBatchSize, beamWidth, + logProbsMaxSeqLen, stream); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h index f544bf3a3..387f7d7fa 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h @@ -150,7 +150,7 @@ class DynamicDecodeLayer : public BaseLayer std::optional output_log_probs_tiled; // [request_output_length, batch_size, beam_width], must be float*, optional std::optional - output_log_probs; // [batchSize, beam_width, request_ouptut_length], must be float*, optional + output_log_probs; // [batch_size, beam_width, request_output_length], must be float*, optional std::optional tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len], the k/v cache index for beam search std::shared_ptr @@ -164,6 +164,11 @@ class DynamicDecodeLayer : public BaseLayer void allocateBuffer(); void freeBuffer(); + T* getRuntimeLogitsDevice() + { + return mRuntimeLogitsDevice; + } + private: void initialize(); void initializeLayers(); @@ -197,8 +202,8 @@ class DynamicDecodeLayer : public BaseLayer void prepareIdsPtrs( OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen); static void prepareOutputData(OutputParams& outputs, ForwardParams const& params, - runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, - size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream); + runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, + size_t maxBatchSize, size_t beamWidth, size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream); private: std::unique_ptr> mOnlineBeamSearchDecode; diff --git a/cpp/tensorrt_llm/layers/topKSamplingLayer.cu b/cpp/tensorrt_llm/layers/topKSamplingLayer.cu index d49d89e7f..029dff340 100644 --- a/cpp/tensorrt_llm/layers/topKSamplingLayer.cu +++ b/cpp/tensorrt_llm/layers/topKSamplingLayer.cu @@ -74,8 +74,8 @@ void TopKSamplingLayer::allocateBuffer(size_t const batchSize) { TLLM_LOG_TRACE(__PRETTY_FUNCTION__); invokeTopKSampling(nullptr, mSamplingWorkspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, nullptr, - mNormalizeLogProbs, false); + nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, 
mMaxBatchSize, + nullptr, mNormalizeLogProbs, false); std::array deviceBufferSizes; deviceBufferSizes[0] = sizeof(uint32_t) * batchSize; @@ -213,7 +213,7 @@ void TopKSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& invokeBatchTopKSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, logits, outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, curandStatesDevice, (int) mRuntimeMaxTopK, (int*) (mRuntimeTopKDevice), 1.0f, - mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mSkipDecodeDevice, + mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mMaxBatchSize, mSkipDecodeDevice, mNormalizeLogProbs, probsComputed); sync_check_cuda_error(); } diff --git a/cpp/tensorrt_llm/layers/topPSamplingLayer.cu b/cpp/tensorrt_llm/layers/topPSamplingLayer.cu index 849f1289b..135de4b4f 100644 --- a/cpp/tensorrt_llm/layers/topPSamplingLayer.cu +++ b/cpp/tensorrt_llm/layers/topPSamplingLayer.cu @@ -78,8 +78,8 @@ void TopPSamplingLayer::allocateBuffer(size_t batchSize) nullptr, // cum_log_probs nullptr, // output_log_probs nullptr, // log_probs - mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mVocabSizePadded, nullptr, - 0.f, mStream, nullptr, nullptr); + mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mMaxBatchSize, + mVocabSizePadded, nullptr, 0.f, mStream, nullptr, nullptr); } else { @@ -91,7 +91,8 @@ void TopPSamplingLayer::allocateBuffer(size_t batchSize) nullptr, // cum_log_probs nullptr, // output_log_probs nullptr, // log_probs) - nullptr, batchSize, mVocabSizePadded, nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr, nullptr); + nullptr, batchSize, mMaxBatchSize, mVocabSizePadded, nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr, + nullptr); } std::array deviceBufferSizes; @@ -315,8 +316,8 @@ void TopPSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& invokeBatchTopPSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, mCubTempStorageSize, outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, probs, mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice, - batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mSkipDecodeDevice, - batchSlots); + batchSize, mMaxBatchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, + mSkipDecodeDevice, batchSlots); sync_check_cuda_error(); invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice, outputs.output_ids_ptr.template getPtr(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice, @@ -327,8 +328,8 @@ void TopPSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& { invokeBatchAirTopPSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, - outputLogProbs, probs, curandStatesDevice, batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, - mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots); + outputLogProbs, probs, curandStatesDevice, batchSize, mMaxBatchSize, mVocabSizePadded, endIds, + mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots); sync_check_cuda_error(); } } diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp 
b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp index 04291698b..a476798a5 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp @@ -47,8 +47,27 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s , mRemovePadding(remove_padding) { // pre-check whether FMHA is supported in order to save memory allocation - mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM) - && !mRelativeAttention; + if (mEnableContextFMHA) + { + mEnableContextFMHA = false; + if (!(mType == DataType::kHALF || mType == DataType::kBF16)) + { + TLLM_LOG_WARNING("Fall back to unfused MHA because of unsupported data type."); + } + else if (!MHARunner::fmha_supported(mHeadSize, mSM)) + { + TLLM_LOG_WARNING( + "Fall back to unfused MHA because of unsupported head size %d in sm_{%d}.", mHeadSize, mSM); + } + else if (mRelativeAttention) + { + TLLM_LOG_WARNING("Fall back to unfused MHA because of relative position embedding."); + } + else + { + mEnableContextFMHA = true; + } + } } // Parameterized constructor @@ -450,9 +469,23 @@ int BertAttentionPlugin::initialize() noexcept mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr)); if (mEnableContextFMHA) { - mFMHARunner.reset(new FusedMHARunnerV2(DATA_TYPE_FP16, mNumHeads, mHeadSize, mQScaling)); + // Pre-checked during constructing. + Data_type data_type; + if (mType == DataType::kHALF) + { + data_type = DATA_TYPE_FP16; + } + else if (mType == DataType::kBF16) + { + data_type = DATA_TYPE_BF16; + } + else + { + TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type."); + } + mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, mHeadSize, mQScaling)); // set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads = num_heads - mFMHARunner->setup_flags(mFMHAForceFP32Acc, true, false, mNumHeads); + mFMHARunner->setup_flags(mFMHAForceFP32Acc, !mRemovePadding, false, mNumHeads); } return 0; diff --git a/cpp/tensorrt_llm/runtime/bufferManager.cpp b/cpp/tensorrt_llm/runtime/bufferManager.cpp index 450f32af9..fbdc40f3c 100644 --- a/cpp/tensorrt_llm/runtime/bufferManager.cpp +++ b/cpp/tensorrt_llm/runtime/bufferManager.cpp @@ -80,12 +80,12 @@ BufferManager::ITensorPtr BufferManager::pinnedPool(nvinfer1::Dims dims, nvinfer return std::make_unique(dims, type); } -BufferManager::IBufferPtr BufferManager::managed(std::size_t size, nvinfer1::DataType type) const +BufferManager::IBufferPtr BufferManager::managed(std::size_t size, nvinfer1::DataType type) { return std::make_unique(size, type); } -BufferManager::ITensorPtr BufferManager::managed(nvinfer1::Dims dims, nvinfer1::DataType type) const +BufferManager::ITensorPtr BufferManager::managed(nvinfer1::Dims dims, nvinfer1::DataType type) { return std::make_unique(dims, type); } @@ -149,8 +149,10 @@ BufferManager::IBufferPtr BufferManager::allocate( case MemoryType::kCPU: return cpu(size, type); case MemoryType::kGPU: return gpu(size, type); case MemoryType::kPINNED: return pinned(size, type); - default: TLLM_THROW("Unknown memory type"); + case MemoryType::kUVM: return managed(size, type); } + + TLLM_THROW("Unknown memory type"); } BufferManager::ITensorPtr BufferManager::allocate( @@ -161,8 +163,10 @@ BufferManager::ITensorPtr BufferManager::allocate( case MemoryType::kCPU: return cpu(dims, type); case MemoryType::kGPU: return gpu(dims, type); case 
MemoryType::kPINNED: return pinned(dims, type); - default: TLLM_THROW("Unknown memory type"); + case MemoryType::kUVM: return managed(dims, type); } + + TLLM_THROW("Unknown memory type"); } BufferManager::IBufferPtr BufferManager::copyFrom(IBuffer const& src, MemoryType memoryType) const diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp index ffb370135..765f580d7 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp @@ -297,6 +297,12 @@ void GptDecoderBatch::newRequest( = ITensor::slice(constPointerCast(dJointInput.embeddingBias), batchIdx, localBatchSize); if (request.embeddingBias) { + TLLM_CHECK(request.embeddingBias->getShape().nbDims == 2); + TLLM_CHECK(request.embeddingBias->getShape().d[0] == 1); + TLLM_CHECK_WITH_INFO(request.embeddingBias->getShape().d[1] == static_cast(mVocabSize), + "The embedding bias shape is not as expected. Expected last dimension to be same as vocab size: %lu.", + mVocabSize); + manager.copy(*request.embeddingBias, *embeddingBiasSlice); dInput->embeddingBias = embeddingBiasSlice; } diff --git a/cpp/tensorrt_llm/runtime/iBuffer.cpp b/cpp/tensorrt_llm/runtime/iBuffer.cpp index ea27a905e..fec030e6d 100644 --- a/cpp/tensorrt_llm/runtime/iBuffer.cpp +++ b/cpp/tensorrt_llm/runtime/iBuffer.cpp @@ -35,11 +35,12 @@ MemoryType IBuffer::memoryType(void const* data) switch (attributes.type) { case cudaMemoryTypeHost: return MemoryType::kPINNED; - case cudaMemoryTypeDevice: - case cudaMemoryTypeManaged: return MemoryType::kGPU; + case cudaMemoryTypeDevice: return MemoryType::kGPU; + case cudaMemoryTypeManaged: return MemoryType::kUVM; case cudaMemoryTypeUnregistered: return MemoryType::kCPU; - default: TLLM_THROW("Unsupported memory type"); } + + TLLM_THROW("Unsupported memory type"); } IBuffer::UniquePtr IBuffer::slice(IBuffer::SharedPtr buffer, std::size_t offset, std::size_t size) diff --git a/cpp/tensorrt_llm/runtime/iTensor.cpp b/cpp/tensorrt_llm/runtime/iTensor.cpp index f245863c0..05e3a1ef0 100644 --- a/cpp/tensorrt_llm/runtime/iTensor.cpp +++ b/cpp/tensorrt_llm/runtime/iTensor.cpp @@ -89,7 +89,12 @@ ITensor::UniquePtr ITensor::wrap(void* data, nvinfer1::DataType type, nvinfer1:: new GenericTensor( shape, capacity, type, GpuBorrowingAllocator(data, capacityInBytes))); break; - default: TLLM_THROW("Unknown memory type"); + case MemoryType::kUVM: + result.reset( // NOLINT(modernize-make-unique) + new GenericTensor( + shape, capacity, type, ManagedBorrowingAllocator(data, capacityInBytes))); + break; + default: TLLM_THROW("Invalid memory type."); break; } return result; } diff --git a/cpp/tensorrt_llm/runtime/tllmBuffers.h b/cpp/tensorrt_llm/runtime/tllmBuffers.h index e0fed5728..088a05eee 100644 --- a/cpp/tensorrt_llm/runtime/tllmBuffers.h +++ b/cpp/tensorrt_llm/runtime/tllmBuffers.h @@ -240,6 +240,7 @@ class BorrowingAllocator : public BaseAllocator, using CpuBorrowingAllocator = BorrowingAllocator; using GpuBorrowingAllocator = BorrowingAllocator; using PinnedBorrowingAllocator = BorrowingAllocator; +using ManagedBorrowingAllocator = BorrowingAllocator; // using UVMBorrowingAllocator = BorrowingAllocator; diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp index 91d8e5a49..9e454c57c 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp @@ -144,12 +144,13 @@ void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, th::Tensor& 
output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, - th::optional output_log_probs_opt, th::optional parent_ids_opt, - th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_tgt_opt, - th::optional beam_hyps_sequence_lengths_tgt_opt, th::optional beam_hyps_cum_log_probs_opt, - th::optional beam_hyps_normed_scores_opt, th::optional beam_hyps_log_probs_opt, - th::optional beam_hyps_min_normed_scores_opt, th::optional beam_hyps_num_beams_opt, - th::optional beam_hyps_is_done_opt, bool use_beam_hyps) + th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, + th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, + th::optional beam_hyps_output_ids_tgt_opt, th::optional beam_hyps_sequence_lengths_tgt_opt, + th::optional beam_hyps_cum_log_probs_opt, th::optional beam_hyps_normed_scores_opt, + th::optional beam_hyps_log_probs_opt, th::optional beam_hyps_min_normed_scores_opt, + th::optional beam_hyps_num_beams_opt, th::optional beam_hyps_is_done_opt, + bool use_beam_hyps) { auto const& logits_converted = convert_tensor(logits); @@ -190,6 +191,7 @@ void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, safeUpdate(parent_ids_opt, outputParams.parent_ids); safeUpdate(cum_log_probs_opt, outputParams.cum_log_probs); safeUpdate(output_log_probs_opt, outputParams.output_log_probs); + safeUpdate(output_log_probs_tiled_opt, outputParams.output_log_probs_tiled); safeUpdate(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection); if (use_beam_hyps) @@ -297,12 +299,12 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max th::optional finished_output, th::optional seuqence_lengths_opt, // length of the current sequences. 
th::optional cum_log_probs_opt, th::optional output_log_probs_opt, - th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, - th::optional beam_hyps_output_ids_tgt_opt, th::optional beam_hyps_sequence_lengths_tgt_opt, - th::optional beam_hyps_cum_log_probs_opt, th::optional beam_hyps_normed_scores_opt, - th::optional beam_hyps_log_probs_opt, th::optional beam_hyps_min_normed_scores_opt, - th::optional beam_hyps_num_beams_opt, th::optional beam_hyps_is_done_opt, - bool use_beam_hyps) + th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, + th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_tgt_opt, + th::optional beam_hyps_sequence_lengths_tgt_opt, th::optional beam_hyps_cum_log_probs_opt, + th::optional beam_hyps_normed_scores_opt, th::optional beam_hyps_log_probs_opt, + th::optional beam_hyps_min_normed_scores_opt, th::optional beam_hyps_num_beams_opt, + th::optional beam_hyps_is_done_opt, bool use_beam_hyps) { // Input Arguments: // logits: [batch_size, beam_width, vocab_size_padded], T @@ -349,6 +351,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32); CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32); CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32); + CHECK_OPTIONAL_INPUT(output_log_probs_tiled_opt, torch::kFloat32); CHECK_OPTIONAL_INPUT(parent_ids_opt, torch::kInt32); CHECK_OPTIONAL_INPUT(tgt_cache_indirection_opt, torch::kInt32); @@ -363,7 +366,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max static_cast(max_bad_words_len), no_repeat_ngram_size_opt, src_cache_indirection_opt, // Outputs output_token_ids, newTokens, should_stop, finished_input, finished_output, seuqence_lengths_opt, - cum_log_probs_opt, output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt, + cum_log_probs_opt, output_log_probs_opt, output_log_probs_tiled_opt, parent_ids_opt, tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt, beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt, beam_hyps_normed_scores_opt, beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt, beam_hyps_is_done_opt, use_beam_hyps); diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h index 410bf7d7c..0c21ffb2d 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h @@ -48,8 +48,9 @@ class IFtDynamicDecode th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, - th::optional output_log_probs_opt, th::optional parent_ids_opt, - th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_tgt_opt, + th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, + th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, + th::optional beam_hyps_output_ids_tgt_opt, th::optional beam_hyps_sequence_lengths_tgt_opt, th::optional beam_hyps_cum_log_probs_opt, th::optional beam_hyps_normed_scores_opt, th::optional beam_hyps_log_probs_opt, th::optional beam_hyps_min_normed_scores_opt, @@ -87,8 +88,9 @@ class FtDynamicDecode : public IFtDynamicDecode th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, th::optional sequence_lengths_opt, th::optional 
cum_log_probs_opt, - th::optional output_log_probs_opt, th::optional parent_ids_opt, - th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_tgt_opt, + th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, + th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, + th::optional beam_hyps_output_ids_tgt_opt, th::optional beam_hyps_sequence_lengths_tgt_opt, th::optional beam_hyps_cum_log_probs_opt, th::optional beam_hyps_normed_scores_opt, th::optional beam_hyps_log_probs_opt, th::optional beam_hyps_min_normed_scores_opt, @@ -134,8 +136,8 @@ class DynamicDecodeOp : public th::jit::CustomClassHolder th::optional finished_output, th::optional sequence_lengths_opt, // length of the current sequences. th::optional cum_log_probs_opt, th::optional output_log_probs_opt, - th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, - th::optional beam_hyps_output_ids_tgt_opt, + th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, + th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_tgt_opt, th::optional beam_hyps_sequence_lengths_tgt_opt, th::optional beam_hyps_cum_log_probs_opt, th::optional beam_hyps_normed_scores_opt, th::optional beam_hyps_log_probs_opt, th::optional beam_hyps_min_normed_scores_opt, diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 371f0858e..f08fa99f9 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -69,6 +69,7 @@ add_gtest(tllmBuffersTest runtime/tllmBuffersTest.cpp) add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp) add_gtest(runtimeKernelTest runtime/runtimeKernelTest.cpp) add_gtest(samplingTest runtime/samplingTest.cpp) +add_gtest(samplingConfigTest runtime/samplingConfigTest.cpp) add_gtest(iTensorTest runtime/iTensorTest.cpp) add_gtest(worldConfigTest runtime/worldConfigTest.cpp) add_gtest(medusaModuleTest runtime/medusaModuleTest.cpp) @@ -101,3 +102,7 @@ if(BUILD_BATCH_MANAGER) add_subdirectory(batch_manager) endif() endif() + +if(BUILD_EXECUTOR) + add_subdirectory(executor) +endif() diff --git a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp index 2250c7ddf..39a8db935 100644 --- a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp +++ b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp @@ -43,6 +43,7 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest private: size_t getWorkspaceSize(const SamplingKernelTestParam& params) override { + auto const maxBatchSize = 2 * params.batchSize; size_t sampling_workspace_size_; tk::invokeAirTopPSampling(nullptr, sampling_workspace_size_, nullptr, // output_ids @@ -52,7 +53,7 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest nullptr, // cum_log_probs nullptr, // output_log_probs nullptr, // log_probs) - this->mCurandStatesDevice, params.batchSize, params.vocabSize, nullptr, this->mMaxTopP, + this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, nullptr, this->mMaxTopP, this->mStream->get(), 0, nullptr, nullptr); return sampling_workspace_size_; } @@ -65,6 +66,7 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest int smCnt; TLLM_CUDA_CHECK(cudaGetDevice(&dev)); TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCnt, cudaDevAttrMultiProcessorCount, dev)); + auto const maxBatchSize = 2 * params.batchSize; int blockNum = tk::calcAirTopPBlockNum(params.batchSize, params.vocabSize, smCnt); // Perform batched TopP sampling @@ -79,8 +81,8 @@ class AirTopPSamplingKernelTest : public 
SamplingKernelTest // log-prob if cum_log_probs or output_log_probs are // provided. It's because the sampling layer already // preprocesses log_prob_buf when those are provided. - bufferCast(*this->mProbsDevice), this->mCurandStatesDevice, params.batchSize, params.vocabSize, - bufferCast(*this->mEndIdsDevice), this->mMaxTopP, + bufferCast(*this->mProbsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize, + params.vocabSize, bufferCast(*this->mEndIdsDevice), this->mMaxTopP, hasDiffRuntimeArgs ? bufferCast(*this->mTopPsDevice) : nullptr, this->mStream->get(), blockNum, bufferCast(*this->mSkipDecodeDevice), bufferCast(*this->mBatchSlots)); } diff --git a/cpp/tests/kernels/sampling/samplingTest.cpp b/cpp/tests/kernels/sampling/samplingTest.cpp index 248d5de59..281194d84 100644 --- a/cpp/tests/kernels/sampling/samplingTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTest.cpp @@ -64,7 +64,7 @@ void SamplingKernelTest::allocateBuffers( mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); mOutputLogProbsDevice - = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, outputLen}), nvinfer1::DataType::kFLOAT); + = mBufferManager->gpu(ITensor::makeShape({maxSeqLen, maxBatchSize}), nvinfer1::DataType::kFLOAT); mZeroParentIdsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, maxSeqLen}), nvinfer1::DataType::kINT32); diff --git a/cpp/tests/kernels/sampling/samplingTopKTest.cpp b/cpp/tests/kernels/sampling/samplingTopKTest.cpp index 00318b4f7..54fac9fb9 100644 --- a/cpp/tests/kernels/sampling/samplingTopKTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTopKTest.cpp @@ -42,16 +42,18 @@ class TopKSamplingKernelTest : public SamplingKernelTest size_t getWorkspaceSize(const SamplingKernelTestParam& params) override { + auto const maxBatchSize = 2 * params.batchSize; size_t workspaceSize; tk::invokeTopKSampling(nullptr, workspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, this->mMaxTopK, 1.0f, params.vocabSize, nullptr, nullptr, this->mStream->get(), params.batchSize, - nullptr, true, false); + maxBatchSize, nullptr, true, false); return workspaceSize; } void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { + auto const maxBatchSize = 2 * params.batchSize; // Perform batched TopK sampling tk::invokeBatchTopKSampling(workspaceDevice->data(), workspaceSize, // Note that the kernel needs vocab probs instead of @@ -69,7 +71,7 @@ class TopKSamplingKernelTest : public SamplingKernelTest hasDiffRuntimeArgs ? bufferCast(*this->mTopKsDevice) : nullptr, params.topP, hasDiffRuntimeArgs ? 
bufferCast(*this->mTopPsDevice) : nullptr, params.vocabSize, bufferCast(*this->mEndIdsDevice), bufferCast(*this->mBatchSlots), this->mStream->get(), - params.batchSize, bufferCast(*this->mSkipDecodeDevice), params.normalizeLogProbs, + params.batchSize, maxBatchSize, bufferCast(*this->mSkipDecodeDevice), params.normalizeLogProbs, params.logitsHasProbs); } }; diff --git a/cpp/tests/kernels/sampling/samplingTopPTest.cpp b/cpp/tests/kernels/sampling/samplingTopPTest.cpp index 6636e40fd..2e7b8c555 100644 --- a/cpp/tests/kernels/sampling/samplingTopPTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTopPTest.cpp @@ -43,6 +43,7 @@ class TopPSamplingKernelTest : public SamplingKernelTest private: size_t getWorkspaceSize(const SamplingKernelTestParam& params) override { + auto const maxBatchSize = 2 * params.batchSize; size_t workspaceSize; size_t cubTempStorageSize; tk::invokeBatchTopPSampling(nullptr, // workspace @@ -55,7 +56,7 @@ class TopPSamplingKernelTest : public SamplingKernelTest nullptr, // output_log_probs nullptr, // log_probs bufferCast(*this->mTopPIdValsDevice), bufferCast(*this->mEndOffsetsDevice), - bufferCast(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, + bufferCast(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, nullptr, this->mMaxTopP, bufferCast(*this->mTopPsDevice), this->mStream->get(), nullptr, nullptr); return workspaceSize; @@ -64,6 +65,7 @@ class TopPSamplingKernelTest : public SamplingKernelTest void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { + auto const maxBatchSize = 2 * params.batchSize; size_t cubTempStorageSize; tk::invokeBatchTopPSampling(nullptr, // workspace workspaceSize, cubTempStorageSize, @@ -75,7 +77,7 @@ class TopPSamplingKernelTest : public SamplingKernelTest nullptr, // output_log_probs nullptr, // log_probs bufferCast(*this->mTopPIdValsDevice), bufferCast(*this->mEndOffsetsDevice), - bufferCast(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, + bufferCast(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, nullptr, this->mMaxTopP, bufferCast(*this->mTopPsDevice), this->mStream->get(), nullptr, nullptr); @@ -98,8 +100,9 @@ class TopPSamplingKernelTest : public SamplingKernelTest // preprocesses log_prob_buf when those are provided. bufferCast(*this->mProbsDevice), bufferCast(*this->mTopPIdValsDevice), bufferCast(*this->mEndOffsetsDevice), bufferCast(*this->mBeginOffsetsDevice), - this->mCurandStatesDevice, params.batchSize, params.vocabSize, bufferCast(*this->mEndIdsDevice), - this->mMaxTopP, hasDiffRuntimeArgs ? bufferCast(*this->mTopPsDevice) : nullptr, this->mStream->get(), + this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, + bufferCast(*this->mEndIdsDevice), this->mMaxTopP, + hasDiffRuntimeArgs ? 
bufferCast(*this->mTopPsDevice) : nullptr, this->mStream->get(), bufferCast(*this->mSkipDecodeDevice), bufferCast(*this->mBatchSlots)); } }; diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.cpp b/cpp/tests/layers/dynamicDecodeLayerTest.cpp index 889fc1fbe..e06eb9601 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.cpp +++ b/cpp/tests/layers/dynamicDecodeLayerTest.cpp @@ -15,6 +15,7 @@ */ #include "tests/layers/dynamicDecodeLayerTest.h" +#include namespace tensorrt_llm::tests::layers::sampling { @@ -25,7 +26,7 @@ namespace tensorrt_llm::tests::layers::sampling // - finished sum // - max length // - repeat n grams -// - output logits +// - padded vocab // - beam search using namespace tensorrt_llm::runtime; @@ -129,17 +130,19 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param // clang-format off - // prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1) + // prob = (0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 0.2, 0.1, 0.0) mTestLogitsInit = { - -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, // step 0 - -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1 - -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, // step 2 - -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // step 3 + -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, // step 0 + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 1 + -FLT_MAX, -FLT_MAX, -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, // step 2 + -0.9163, -1.2040, -1.6094, -2.3026, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX // step 3 }; // clang-format on mLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType); + mRuntimeLogitsHost + = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mBeamWidth, mVocabSizePadded}), dataType); mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mContextLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); @@ -154,6 +157,13 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param mEmbeddingBiasHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); mEmbeddingBiasDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); + mRefLogProbsHost + = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT); + mOutputLogProbsDevice + = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kFLOAT); + mOutputLogProbsTiledDevice + = mBufferManager->gpu(ITensor::makeShape({mMaxSeqLen, mMaxBatchSize}), nvinfer1::DataType::kFLOAT); + mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kFLOAT); mMaxBadWordsLen = getMaxWordsLen(params.badWords); @@ -177,6 +187,9 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param trk::invokeFill(*mOutputIdsDevice, int32_t{0}, *mStream); trk::invokeFill(*mEmbeddingBiasDevice, T{0.0f}, *mStream); trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream); + trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream); + trk::invokeFill(*mOutputLogProbsTiledDevice, float{0.0f}, *mStream); + trk::invokeFill(*mRefLogProbsHost, float{0.0f}, *mStream); trk::invokeFill(*mEndIdsDevice, 
int32_t{mEndId}, *mStream); auto batchSlotsPtr = bufferCast(*mBatchSlots); @@ -229,6 +242,7 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, SamplingParams const& param = params.minTopP.size() ? std::make_optional>(params.minTopP) : std::nullopt; setupParams.top_p_reset_ids = params.topPResetIds.size() ? std::make_optional>(params.topPResetIds) : std::nullopt; + setupParams.normalize_log_probs = {false}; initXWordsTensors(batchSlotsPtr, bufferCast(*mBadWords), reinterpret_cast(bufferCast(*mBadWordsPtrs)), bufferCast(*mBadWordsLens), @@ -350,10 +364,12 @@ typename DynamicDecodeLayer::OutputParams DynamicDecodeLayerTest::createOu outputParams.newTokens = tcc::toTllmTensor(*mNewTokens); + outputParams.output_log_probs = tcc::toTllmTensor(*mOutputLogProbsDevice); + + outputParams.output_log_probs_tiled = tcc::toTllmTensor(*mOutputLogProbsTiledDevice); + // TODO(nkorobov): extend to // std::optional parent_ids; - // std::optional output_log_probs_tiled; - // std::optional output_log_probs; // std::optional tgt_cache_indirection; // std::shared_ptr beamHypotheses; @@ -375,7 +391,7 @@ void DynamicDecodeLayerTest::batchCopy(int32_t step) } template -bool DynamicDecodeLayerTest::checkResult(int32_t* outputIds, std::vector>& expectedIds, +bool DynamicDecodeLayerTest::checkResult(int32_t* outputIds, std::vector> const& expectedIds, int32_t* seqLens, int32_t leadingDim, int32_t stride, int32_t step) { assert(expectedIds.size() == leadingDim * stride); @@ -415,11 +431,35 @@ bool DynamicDecodeLayerTest::checkResult(int32_t* outputIds, std::vector +void DynamicDecodeLayerTest::fillRefLogits( + int32_t const* seqLenHost, std::vector> const& expectedOutputIds, SizeType step) +{ + auto const batchSlotsPtr = bufferCast(*mBatchSlots); + auto const runtimeLogitsHost = bufferCast(*mRuntimeLogitsHost); + for (SizeType bi = 0; bi < mBatchBeam; ++bi) + { + auto const batchSlot = batchSlotsPtr[bi]; + if (seqLenHost[batchSlot] <= step) + { + continue; + } + auto& expectedSet = expectedOutputIds[step * mBatchBeam + bi]; + TLLM_CHECK(expectedSet.size() == 1); + auto expectedToken = *expectedSet.begin(); + bufferCast(*mRefLogProbsHost)[batchSlot * mMaxSeqLen + step] + = logf(runtimeLogitsHost[bi * mVocabSizePadded + expectedToken]); + } +} + template void DynamicDecodeLayerTest::runTestImpl( - std::vector> expectedOutputIds, SamplingParams const& params, int32_t endId) + std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId) { - mEndId = endId; + mEndId = endId == -1 ? 
mVocabSize - 1 : endId; + + bool greedySearch + = std::all_of(expectedOutputIds.begin(), expectedOutputIds.end(), [](auto v) { return v.size() == 1; }); for (uint64_t seed = 0; seed < mMaxSeed; ++seed) { setup(seed, params); @@ -439,6 +479,14 @@ void DynamicDecodeLayerTest::runTestImpl( auto const seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU); auto const logitsHost = mBufferManager->copyFrom(*mLogitsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + mBufferManager->copy( + mDecodeLayer->getRuntimeLogitsDevice(), *mRuntimeLogitsHost, tensorrt_llm::runtime::MemoryType::kGPU); + mStream->synchronize(); + + if (greedySearch) + { + fillRefLogits(bufferCast(*seqLenHost), expectedOutputIds, step); + } { bool passed = checkResult(bufferCast(*newTokensHost), expectedOutputIds, @@ -462,24 +510,35 @@ void DynamicDecodeLayerTest::runTestImpl( mStream->synchronize(); - const auto outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU); - const auto seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + auto const outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + auto const seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + auto const logProbsHost + = mBufferManager->copyFrom(*mOutputLogProbsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + + { + bool passed = checkResult(bufferCast(*outputIdsHost), expectedOutputIds, + bufferCast(*seqLenHost), mMaxSeqLen, mBatchBeam, /* step */ 0); + EXPECT_TRUE(passed) << "Output Ids check failed at seed " << seed; + if (!passed) + { + std::stringstream ss; + ss << "Actual output ids:" << std::endl << *outputIdsHost; + TLLM_LOG_DEBUG(ss.str()); + } + } - bool passed = checkResult(bufferCast(*outputIdsHost), expectedOutputIds, - bufferCast(*seqLenHost), mMaxSeqLen, mBatchBeam, /* step */ 0); - EXPECT_TRUE(passed) << "Output Ids check failed at seed " << seed; - if (!passed) + if (greedySearch) { - std::stringstream ss; - ss << "Actual output ids:" << std::endl << *outputIdsHost; - TLLM_LOG_DEBUG(ss.str()); + bool passed = compareValues( + bufferCast(*logProbsHost), bufferCast(*mRefLogProbsHost), mMaxSeqLen * mMaxBatchSize); + EXPECT_TRUE(passed) << "Log probs check failed at seed " << seed; } } } template void DynamicDecodeLayerTest::runTest( - std::vector> expectedOutputIds, SamplingParams const& params, int32_t endId) + std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId) { TLLM_LOG_DEBUG("Run test with linear logits"); mUseLogitsVec = false; diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.h b/cpp/tests/layers/dynamicDecodeLayerTest.h index f62119617..a82e81f0f 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.h +++ b/cpp/tests/layers/dynamicDecodeLayerTest.h @@ -70,7 +70,7 @@ class DynamicDecodeLayerTest : public testing::Test int32_t const mMaxBatchSize = 2 * mBatchSize; int32_t const mBeamWidth = 1; int32_t const mBatchBeam = mBatchSize * mBeamWidth; - int32_t const mVocabSize = 8; + int32_t const mVocabSize = 9; int32_t const mVocabSizePadded = mVocabSize; int32_t const mMaxInputLen = 0; // has no effect. 
@@ -82,6 +82,7 @@ class DynamicDecodeLayerTest : public testing::Test bool mUseLogitsVec = false; TensorPtr mLogitsDevice; + TensorPtr mRuntimeLogitsHost; TensorPtr mLogitsRefHost; TensorPtr mContextLengthDevice; TensorPtr mSeqLengthsDevice; @@ -103,6 +104,10 @@ class DynamicDecodeLayerTest : public testing::Test TensorPtr mEmbeddingBiasHost; TensorPtr mEmbeddingBiasDevice; + TensorPtr mRefLogProbsHost; + TensorPtr mOutputLogProbsDevice; + TensorPtr mOutputLogProbsTiledDevice; + TensorPtr mCumLogProbsDevice; std::vector mLogitsVec; @@ -134,14 +139,18 @@ class DynamicDecodeLayerTest : public testing::Test typename tensorrt_llm::layers::DynamicDecodeLayer::OutputParams createOutputTensors(); void batchCopy(int32_t step); - bool checkResult(int32_t* outputIds, std::vector>& expectedIds, int32_t* seqLens, + bool checkResult(int32_t* outputIds, std::vector> const& expectedIds, int32_t* seqLens, int32_t leadingDim, int32_t stride, int32_t step); void runTestImpl( - std::vector> expectedOutputIds, SamplingParams const& params, int32_t endId = -1); + std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId = -1); + + void fillRefLogits( + int32_t const* seqLenHost, std::vector> const& expectedOutputIds, int32_t step); public: - void runTest(std::vector> expectedOutputIds, SamplingParams const& params, int32_t endId = -1); + void runTest( + std::vector> const& expectedOutputIds, SamplingParams const& params, int32_t endId = -1); }; typedef testing::Types FloatAndHalfTypes; diff --git a/cpp/tests/runtime/samplingConfigTest.cpp b/cpp/tests/runtime/samplingConfigTest.cpp new file mode 100644 index 000000000..e91293abc --- /dev/null +++ b/cpp/tests/runtime/samplingConfigTest.cpp @@ -0,0 +1,77 @@ +/* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, + * Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/runtime/samplingConfig.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/types.h" +#include +#include + +using ::testing::_; +using ::testing::Invoke; + +namespace tr = tensorrt_llm::runtime; +namespace tc = tensorrt_llm::common; +namespace texec = tensorrt_llm::executor; + +TEST(samplingConfigTest, validInputs) +{ + { + texec::SamplingConfig execSamplingCfg(1); + tr::SamplingConfig samplingCfg(execSamplingCfg, std::nullopt); + EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth()); + EXPECT_EQ(samplingCfg.draftAcceptanceThreshold, std::nullopt); + } + { + texec::SamplingConfig execSamplingCfg(1); + texec::SpeculativeDecodingConfig specCfg({1}, std::nullopt, 0.5); + tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg); + EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth()); + EXPECT_TRUE(samplingCfg.draftAcceptanceThreshold.has_value()); + EXPECT_THAT(samplingCfg.draftAcceptanceThreshold.value(), testing::ElementsAre(0.5f)); + } + { + texec::SizeType topK = 1; + texec::FloatType topP = 0.5; + texec::FloatType topPMin = 0.1; + texec::SizeType topPResetIds = 1; + texec::FloatType topPDecay = 0.6; + uint64_t randomSeed = 7777; + texec::FloatType temperature = 0.245; + texec::SizeType minLength = 1234; + texec::FloatType beamSearchDiversityRate = 0.9999; + texec::FloatType repetitionPenalty = 0.11; + texec::FloatType presencePenalty = 0.22; + texec::FloatType frequencyPenalty = 0.33; + texec::FloatType lengthPenalty = 0.44; + + texec::SamplingConfig execSamplingCfg(1, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, + minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty); + texec::SpeculativeDecodingConfig specCfg({1}, std::nullopt, 0.5); + tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg); + EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth()); + EXPECT_THAT(samplingCfg.draftAcceptanceThreshold.value(), testing::ElementsAre(0.5f)); + EXPECT_THAT(samplingCfg.temperature.value(), testing::ElementsAre(temperature)); + EXPECT_THAT(samplingCfg.minLength.value(), testing::ElementsAre(minLength)); + EXPECT_THAT(samplingCfg.repetitionPenalty.value(), testing::ElementsAre(repetitionPenalty)); + EXPECT_THAT(samplingCfg.presencePenalty.value(), testing::ElementsAre(presencePenalty)); + EXPECT_THAT(samplingCfg.frequencyPenalty.value(), testing::ElementsAre(frequencyPenalty)); + EXPECT_THAT(samplingCfg.topK.value(), testing::ElementsAre(topK)); + EXPECT_THAT(samplingCfg.topP.value(), testing::ElementsAre(topP)); + EXPECT_THAT(samplingCfg.randomSeed.value(), testing::ElementsAre(randomSeed)); + EXPECT_THAT(samplingCfg.topPMin.value(), testing::ElementsAre(topPMin)); + EXPECT_THAT(samplingCfg.topPResetIds.value(), testing::ElementsAre(topPResetIds)); + EXPECT_THAT(samplingCfg.beamSearchDiversityRate.value(), testing::ElementsAre(beamSearchDiversityRate)); + EXPECT_THAT(samplingCfg.lengthPenalty.value(), testing::ElementsAre(lengthPenalty)); + } +} diff --git a/examples/enc_dec/README.md b/examples/enc_dec/README.md index 2b9ba072c..55b87fc9a 100644 --- a/examples/enc_dec/README.md +++ b/examples/enc_dec/README.md @@ -113,6 +113,7 @@ python build.py --model_type t5 \ --max_beam_width 3 # Example 4: build bart-large-cnn using a single GPU, FP32, running greedy search +# Note: non-T5 models can enable FMHA for the encoder part, for FP16/BF16 python build.py --model_type bart \ --weight_dir tmp/trt_models/bart-large-cnn/tp1 \ -o 
tmp/trt_engines/bart-large-cnn/1-gpu \ @@ -120,6 +121,7 @@ python build.py --model_type bart \ --remove_input_padding \ --use_bert_attention_plugin \ --use_gpt_attention_plugin \ + --enable_context_fmha \ --use_gemm_plugin \ --dtype float32 \ --max_beam_width 1 @@ -237,12 +239,14 @@ pushd tmp && (git clone https://github.com/facebookresearch/fairseq.git || true) python nmt/convert.py -i tmp/fairseq_models/wmt14 -o tmp/trt_models/wmt14 --weight_data_type float32 --inference_tensor_para_size 1 # Build TensorRT engine(s) +# Note: non-T5 models can enable FMHA for the encoder part, although only FP16/BF16 precisions are valid python build.py --model_type nmt \ --weight_dir tmp/trt_models/wmt14/tp1/ \ -o tmp/trt_engines/wmt14/1-gpu \ --engine_name wmt14 \ --use_bert_attention_plugin \ --use_gpt_attention_plugin \ + --enable_context_fmha \ --dtype float32 \ --max_beam_width 1 diff --git a/examples/enc_dec/build.py b/examples/enc_dec/build.py index 497aced20..ff50c38f7 100644 --- a/examples/enc_dec/build.py +++ b/examples/enc_dec/build.py @@ -27,6 +27,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType from t5.weight import parse_t5_config, load_from_hf_t5, load_from_binary_t5 # isort:skip from bart.weight import parse_bart_config, load_from_binary_bart # isort:skip @@ -185,6 +186,12 @@ def parse_arguments(component): parser.add_argument('--enable_qk_half_accum', default=False, action='store_true') + parser.add_argument('--enable_context_fmha', + default=False, + action='store_true') + parser.add_argument('--enable_context_fmha_fp32_acc', + default=False, + action='store_true') parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument('--remove_input_padding', default=False, @@ -404,6 +411,14 @@ def build_rank_engine(builder: Builder, network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) if args.enable_qk_half_accum: network.plugin_config.enable_qk_half_accum() + assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) + if args.enable_context_fmha and not args.relative_attention: + logger.warning("Only non-T5 enc-dec models support FMHA") + network.plugin_config.set_context_fmha(ContextFMHAType.enabled) + if args.enable_context_fmha_fp32_acc and not args.relative_attention: + logger.warning("Only non-T5 enc-dec models support FMHA") + network.plugin_config.set_context_fmha( + ContextFMHAType.enabled_with_fp32_acc) if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() if args.use_lookup_plugin: diff --git a/examples/enc_dec/run.py b/examples/enc_dec/run.py index 91e487563..aa4a24394 100644 --- a/examples/enc_dec/run.py +++ b/examples/enc_dec/run.py @@ -417,6 +417,7 @@ def generate( prompt_tasks=None, prompt_vocab_size=None, attention_mask=None, + time_encoder=False, ): ## ensure all externally provided tensors are on the correct device. 
encoder_input_ids = encoder_input_ids.to(self.device) @@ -436,6 +437,8 @@ def generate( if not self.skip_encoder: logger.info(f"Rank {self.runtime_rank} Running encoder engine ...") + if time_encoder: + tik = time.time() encoder_output = self.encoder_run( encoder_input_ids, encoder_input_lengths, @@ -445,6 +448,9 @@ def generate( prompt_tasks=prompt_tasks, prompt_vocab_size=prompt_vocab_size, attention_mask=attention_mask) + if time_encoder: + tok = time.time() + print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms") else: encoder_output = prompt_embedding_table if encoder_input_ids.dim() > 1: @@ -472,7 +478,10 @@ def generate( sampling_config = SamplingConfig(end_id=eos_token_id, pad_id=pad_token_id, num_beams=num_beams, - min_length=1) + min_length=1, + return_dict=return_dict) + sampling_config.update(output_cum_log_probs=False, + output_log_probs=False) # decoder autoregressive generation self.decoder_session.setup( @@ -485,7 +494,7 @@ def generate( ) torch.cuda.synchronize() - output_ids = self.decoder_session.decode( + output = self.decoder_session.decode( decoder_input_ids, decoder_input_lengths, sampling_config, @@ -495,7 +504,7 @@ def generate( cross_attention_mask=cross_attention_mask) torch.cuda.synchronize() - return output_ids + return output def test_fairseq_models(args): @@ -545,8 +554,9 @@ def test_fairseq_models(args): inference_dtype = tllm_model.encoder_model_config.dtype + return_dict = False # when set return_dict=True, get outputs by key tik = time.time() - tllm_output_ids = tllm_model.generate( + tllm_output = tllm_model.generate( encoder_input_ids=input_ids, decoder_input_ids=decoder_input_ids, max_new_tokens=max_new_tokens, @@ -557,6 +567,12 @@ def test_fairseq_models(args): debug_mode=args.debug_mode, ) tok = time.time() + + if return_dict: + tllm_output_ids = tllm_output['output_ids'] + else: + tllm_output_ids = tllm_output + output_ids = tllm_output_ids[:, 0, :] output_ids = output_ids[output_ids != eos_token_id] fairseq_output_ids = fairseq_output_ids[fairseq_output_ids != eos_token_id] @@ -680,8 +696,10 @@ def test_fairseq_models(args): tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name, args.engine_dir, debug_mode=args.debug_mode) + + return_dict = False # when set return_dict=True, get outputs by key tik = time.time() - tllm_output_ids = tllm_model.generate( + tllm_output = tllm_model.generate( encoder_input_ids=input_ids, decoder_input_ids=decoder_input_ids, max_new_tokens=max_new_tokens, @@ -690,10 +708,16 @@ def test_fairseq_models(args): pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, debug_mode=args.debug_mode, - return_dict=False, # when set return_dict=True, get outputs by key - attention_mask=tokenized_inputs.attention_mask) + return_dict=return_dict, + attention_mask=tokenized_inputs.attention_mask, + time_encoder=True) tok = time.time() + if return_dict: + tllm_output_ids = tllm_output['output_ids'] + else: + tllm_output_ids = tllm_output + inference_dtype = tllm_model.encoder_model_config.dtype if tensorrt_llm.mpi_rank() == 0: diff --git a/examples/enc_dec/t5/weight.py b/examples/enc_dec/t5/weight.py index c0015d685..53a77cac3 100644 --- a/examples/enc_dec/t5/weight.py +++ b/examples/enc_dec/t5/weight.py @@ -72,9 +72,8 @@ def parse_t5_config(config, component, args): args.hidden_act = config.get(component, 'dense_act_fn') args.gated_act = config.getboolean(component, 'is_gated_act') args.mlp_type = mlp_type_map['GatedMLP' if args.gated_act else 'MLP'] - args.relative_attention = config.getboolean(component, - 
'relative_attention', - fallback=True) + args.relative_attention = config.get( + 'structure', 'position_embedding_type') == 'relative' args.num_buckets = config.getint(component, 'relative_attention_num_buckets') args.max_distance = config.getint(component, diff --git a/examples/gemma/README.md b/examples/gemma/README.md new file mode 100644 index 000000000..8db5c3d30 --- /dev/null +++ b/examples/gemma/README.md @@ -0,0 +1,734 @@ +# Run Gemma on TensorRT-LLM + +## Table Of Contents + +- [Run Gemma on TensorRT-LLM](#run-gemma-on-tensorrt-llm) + - [Table Of Contents](#table-of-contents) + - [Support Matrix](#support-matrix) + - [Common scripts](#common-scripts) + - [Convert checkpoint](#convert-checkpoint) + - [Build engine](#build-engine) + - [Run inference](#run-inference) + - [Specific commands](#specific-commands) + - [Run Gemma 2B](#run-gemma-2b) + - [Run inference under bfloat16 for keras checkpoint](#run-inference-under-bfloat16-for-keras-checkpoint) + - [Run inference under FP8 for keras checkpoint](#run-inference-under-fp8-for-keras-checkpoint) + - [Run inference under SmoothQuant for jax checkpoint](#run-2b-inference-under-smoothquant-for-jax-checkpoint) + - [Run inference under weight only for jax checkpoint](#run-inference-under-weight-only-for-jax-checkpoint) + - [Run inference under INT8 KV caches for jax checkpoint](#run-inference-under-int8-kv-caches-for-jax-checkpoint) + - [Run Gemma 7B](#run-gemma-7b) + - [Run inference under bfloat16 for torch checkpoint](#run-inference-under-bfloat16-for-torch-checkpoint) + - [Run inference under FP8 for jax checkpoint](#run-inference-under-fp8-for-jax-checkpoint) + - [Run inference under SmoothQuant for jax checkpoint](#run-7b-inference-under-smoothquant-for-jax-checkpoint) + - [Run inference under weight only for keras checkpoint](#run-inference-under-weight-only-for-keras-checkpoint) + - [Run inference under INT8 KV caches for keras checkpoint](#run-inference-under-int8-kv-caches-for-keras-checkpoint) + - [Run AMMO Quantization](#run-ammo-quantization) + - [Requirements](#requirements) + - [Quantize Checkpoints](#quantize-checkpoints) + - [Build Engines](#build-engines) + - [Accuracy Results on MMLU](#accuracy-results-on-mmlu) + +## Support Matrix + * FP32/FP16/BF16/INT8 Weight-Only/INT4 Weight-Only/SmoothQuant/FP8 + * For SmoothQuant, TRT-LLM only supports FP16 higher precision now. + * checkpoint type: Jax, Torch, Keras + * STRONGLY TYPED + * python runtime and triton backend + +## Common scripts + +### Convert checkpoint + +Users can use `convert_checkpoint.py` to convert the different source checkpoint to unified TensorRT-LLM checkpoint format. Users could set `--dtype` to determine the inference data type, and set the quantization options like `--enable_fp8`, `--fp8_kv_cache` `--use_smooth_quant`, `--calibrate_kv_cache` (for INT8 kv cache) and `--use-weight-only-with-precision` (weight only). Users could also control the source checkpoint type by `--ckpt-type`. Currently, supported checkpoint types are `jax`, `torch` and `keras`. + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/bf16/tp1/ + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --dtype bfloat16 \ + --world-size 1 \ + --output-model-dir ${UNIFIED_CKPT_PATH} +``` + +### Build engine + +After getting checkpoint, we can use `trtllm-build` command to build TensorRT-LLM engines from TensorRT-LLM checkpoints. 
+
+```bash
+ENGINE_PATH=/tmp/gemma/2B/bf16/1-gpu/
+trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
+             --gemm_plugin bfloat16 \
+             --gpt_attention_plugin bfloat16 \
+             --max_batch_size 8 \
+             --max_input_len 3000 \
+             --max_output_len 100 \
+             --context_fmha enable \
+             --output_dir ${ENGINE_PATH}
+```
+
+### Run inference
+
+We provide three examples for running inference: `run.py`, `summarize.py` and `mmlu.py`. `run.py` simply runs inference on `input_text` and prints the output.
+
+`summarize.py` runs summarization on the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset, evaluates the model with [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores, and uses the `ROUGE-1` score to validate the implementation.
+
+`mmlu.py` runs MMLU to evaluate the model's accuracy.
+
+Note that the MMLU dataset needs to be downloaded first, and the MMLU evaluation takes more time.
+
+* run.py
+
+```bash
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
+python3 ../run.py --engine_dir ${ENGINE_PATH} \
+                  --max_output_len 30 \
+                  --vocab_file ${VOCAB_FILE_PATH} \
+                  --no_add_special_tokens
+
+[TensorRT-LLM] TensorRT-LLM version: 0.9.0.dev2024020600Input [Text 0]: " Born in north-east France, Soyer trained as a"
+Output [Text 0 Beam 0]: "chef in the renowned kitchens of Lyon. After honing his skills in various Michelin-starred establishments, he embarked on a solo venture, establishing his own restaurant"
+```
+
+* summarize.py
+
+```bash
+python3 ../summarize.py --test_trt_llm \
+                        --engine_dir ${ENGINE_PATH} \
+                        --batch_size 8 \
+                        --max_ite 5 \
+                        --vocab_file ${VOCAB_FILE_PATH} \
+                        --no_add_special_tokens
+
+[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.2821836471557617 sec)
+[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1989)
+[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 605.9989975648089)
+[02/06/2024-10:08:54] [TRT-LLM] [I] TensorRT-LLM beam 0 result
+[02/06/2024-10:08:55] [TRT-LLM] [I] rouge1 : 26.376388677070615
+[02/06/2024-10:08:55] [TRT-LLM] [I] rouge2 : 7.468157586877296
+[02/06/2024-10:08:55] [TRT-LLM] [I] rougeL : 17.953060795106556
+[02/06/2024-10:08:55] [TRT-LLM] [I] rougeLsum : 22.410938121151652
+```
+
+* mmlu.py
+
+Download the dataset first:
+
+```bash
+mkdir data
+wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -O data/mmlu.tar
+tar -xf data/mmlu.tar -C data
+mv data/data data/mmlu
+```
+
+Evaluate on the MMLU dataset.
+
+```bash
+python3 ../mmlu.py --test_trt_llm \
+                   --vocab_file ${VOCAB_FILE_PATH} \
+                   --engine_dir ${ENGINE_PATH}
+
+Average accuracy 0.358 - social sciences
+Average accuracy 0.359 - other (business, health, misc.)
+Average accuracy: 0.329
+```
+
+## Specific commands
+
+In this section, we demonstrate how to convert checkpoints, build engines and run inference under different settings. There are too many combinations to cover them all, so we only show a few important cases.
+
+### Run Gemma 2B
+
+#### Run inference under bfloat16 for keras checkpoint
+
+```bash
+CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_2b_en/
+UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_en_tensorrt_llm/bf16/tp1/
+ENGINE_PATH=/tmp/gemma/2B/bf16/1-gpu/
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
+
+python3 ./convert_checkpoint.py \
+    --ckpt-type keras \
+    --model-dir ${CKPT_PATH} \
+    --dtype bfloat16 \
+    --world-size 1 \
+    --output-model-dir ${UNIFIED_CKPT_PATH}
+
+trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
+             --gemm_plugin bfloat16 \
+             --gpt_attention_plugin bfloat16 \
+             --max_batch_size 8 \
+             --max_input_len 3000 \
+             --max_output_len 100 \
+             --context_fmha enable \
+             --output_dir ${ENGINE_PATH}
+
+python3 ../summarize.py --test_trt_llm \
+                        --vocab_file ${VOCAB_FILE_PATH} \
+                        --engine_dir ${ENGINE_PATH} \
+                        --batch_size 8 \
+                        --max_ite 5 \
+                        --no_add_special_tokens
+
+[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.96612286567688 sec)
+[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2510)
+[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 632.8598697034137)
+[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM beam 0 result
+[02/08/2024-05:04:13] [TRT-LLM] [I] rouge1 : 20.40970022875146
+[02/08/2024-05:04:13] [TRT-LLM] [I] rouge2 : 5.512437888775742
+[02/08/2024-05:04:13] [TRT-LLM] [I] rougeL : 15.135998543979978
+[02/08/2024-05:04:13] [TRT-LLM] [I] rougeLsum : 17.250431908889873
+```
+
+#### Run inference under FP8 for keras checkpoint
+
+WARNING: Running FP8 this way will introduce a noticeable accuracy drop. To avoid that, use the AMMO quantization described later in this README.
+
+In this example, we demonstrate how to run FP8 inference on Gemma. Note that `convert_checkpoint.py` only uses identity activation scales, so the accuracy might be a little worse than with higher precision in some cases, but it is still very good considering that no calibration is done. This also shows the stability of FP8 compared to INT8.
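+
+To make "identity activation scales" concrete, the sketch below contrasts a calibrated activation scale with the identity one. This is only an illustration (the helper names are ours, not part of TensorRT-LLM); it assumes the FP8 E4M3 format, whose largest finite magnitude is 448.
+
+```python
+import numpy as np
+
+FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3
+
+
+def calibrated_activation_scale(activations: np.ndarray) -> float:
+    """Calibration fits the observed activation range into the FP8 range."""
+    return float(np.abs(activations).max()) / FP8_E4M3_MAX
+
+
+def identity_activation_scale(_: np.ndarray) -> float:
+    """What this example uses: no calibration, every scale is simply 1.0."""
+    return 1.0
+
+
+x = np.random.randn(4, 8).astype(np.float32)
+print(calibrated_activation_scale(x), identity_activation_scale(x))
+```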
+ +```bash +CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_2b_en/ +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_en_tensorrt_llm/fp8/tp1/ +ENGINE_PATH=/tmp/gemma/2B/fp8/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type keras \ + --model-dir ${CKPT_PATH} \ + --dtype bfloat16 \ + --world-size 1 \ + --enable_fp8 \ + --fp8_kv_cache \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 8 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.116227149963379 sec) +[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2419) +[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 776.259201781368) +[02/08/2024-10:37:15] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-10:37:15] [TRT-LLM] [I] rouge1 : 20.206082692133098 +[02/08/2024-10:37:15] [TRT-LLM] [I] rouge2 : 5.902141189518428 +[02/08/2024-10:37:15] [TRT-LLM] [I] rougeL : 15.403458457907643 +[02/08/2024-10:37:15] [TRT-LLM] [I] rougeLsum : 17.44535527417846 + +python3 ../mmlu.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} + +Average accuracy 0.390 - social sciences +Average accuracy 0.405 - other (business, health, misc.) +Average accuracy: 0.356 +``` + +#### Run 2B inference under SmoothQuant for jax checkpoint + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/sq/tp1 +ENGINE_PATH=/tmp/gemma/2B/int8_sq/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --dtype float16 \ + --use_smooth_quant_plugin 0.5 \ + --tokenizer_dir ${VOCAB_FILE_PATH} \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin float16 \ + --gpt_attention_plugin float16 \ + --max_batch_size 8 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.460859775543213 sec) +[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1786) +[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 516.0567361385428) +[02/08/2024-04:42:06] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-04:42:06] [TRT-LLM] [I] rouge1 : 22.534044843245525 +[02/08/2024-04:42:06] [TRT-LLM] [I] rouge2 : 5.940093176022924 +[02/08/2024-04:42:06] [TRT-LLM] [I] rougeL : 16.258991712579736 +[02/08/2024-04:42:06] [TRT-LLM] [I] rougeLsum : 19.60977626046262 +``` + +#### Run inference under weight only for jax checkpoint + +Available precisions: `int8` and `int4` + +* `int8` + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/w8_a16/tp1/ +ENGINE_PATH=/tmp/gemma/2B/w8_a16/1-gpu/ 
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --use-weight-only-with-precision int8 \ + --dtype bfloat16 \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 32 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.5987987518310547 sec) +[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1797) +[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 499.3332842203787) +[02/08/2024-04:44:54] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-04:44:54] [TRT-LLM] [I] rouge1 : 24.48521318679745 +[02/08/2024-04:44:54] [TRT-LLM] [I] rouge2 : 7.240543314565931 +[02/08/2024-04:44:54] [TRT-LLM] [I] rougeL : 17.857921729984078 +[02/08/2024-04:44:54] [TRT-LLM] [I] rougeLsum : 21.214162155642896 +``` + +* `int4` + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/w4_a16/tp1/ +ENGINE_PATH=/tmp/gemma/2B/w4_a16/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --use-weight-only-with-precision int4 \ + --dtype bfloat16 \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 32 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.1938045024871826 sec) +[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1462) +[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 457.7612683749003) +[02/08/2024-04:48:06] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-04:48:06] [TRT-LLM] [I] rouge1 : 25.19118129834017 +[02/08/2024-04:48:06] [TRT-LLM] [I] rouge2 : 6.284558232487986 +[02/08/2024-04:48:06] [TRT-LLM] [I] rougeL : 18.133244708843726 +[02/08/2024-04:48:06] [TRT-LLM] [I] rougeLsum : 20.562024727650662 +``` + +#### Run inference under INT8 KV caches for jax checkpoint + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/int8kv/tp1 +ENGINE_PATH=/tmp/gemma/2B/int8kv/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --world-size 1 \ + --dtype bfloat16 \ + --calibrate_kv_cache \ + --tokenizer_dir ${VOCAB_FILE_PATH} \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 32 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + 
--strongly_type \
+             --output_dir ${ENGINE_PATH}
+
+python3 ../summarize.py --test_trt_llm \
+                        --vocab_file ${VOCAB_FILE_PATH} \
+                        --engine_dir ${ENGINE_PATH} \
+                        --batch_size 8 \
+                        --max_ite 5 \
+                        --no_add_special_tokens
+
+[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.5348474979400635 sec)
+[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1819)
+[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 514.5907994786265)
+[02/08/2024-04:52:22] [TRT-LLM] [I] TensorRT-LLM beam 0 result
+[02/08/2024-04:52:22] [TRT-LLM] [I] rouge1 : 24.0397941580232
+[02/08/2024-04:52:22] [TRT-LLM] [I] rouge2 : 7.325311340360227
+[02/08/2024-04:52:22] [TRT-LLM] [I] rougeL : 17.54210044633271
+[02/08/2024-04:52:22] [TRT-LLM] [I] rougeLsum : 20.627861723682177
+```
+
+### Run Gemma 7B
+
+#### Run inference under bfloat16 for torch checkpoint
+
+Since the torch checkpoint does not come with a model config, we need to add one manually to `CKPT_PATH` as a file named `config.json`.
+
+```bash
+CKPT_PATH=/tmp/models/pytorch/ckpt/
+UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/bf16/tp1/
+ENGINE_PATH=/tmp/gemma/7B/bf16/1-gpu/
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
+
+python3 ./convert_checkpoint.py \
+    --ckpt-type torch \
+    --model-dir ${CKPT_PATH} \
+    --dtype bfloat16 \
+    --world-size 1 \
+    --output-model-dir ${UNIFIED_CKPT_PATH}
+
+trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
+             --gemm_plugin bfloat16 \
+             --gpt_attention_plugin bfloat16 \
+             --max_batch_size 8 \
+             --max_input_len 3000 \
+             --max_output_len 100 \
+             --context_fmha enable \
+             --output_dir ${ENGINE_PATH}
+
+python3 ../summarize.py --test_trt_llm \
+                        --vocab_file ${VOCAB_FILE_PATH} \
+                        --engine_dir ${ENGINE_PATH} \
+                        --batch_size 8 \
+                        --max_ite 5 \
+                        --no_add_special_tokens
+
+python3 ../mmlu.py --test_trt_llm \
+                   --vocab_file ${VOCAB_FILE_PATH} \
+                   --engine_dir ${ENGINE_PATH}
+
+Average accuracy 0.739 - social sciences
+Average accuracy 0.697 - other (business, health, misc.)
+Average accuracy: 0.630
+```
+
+#### Run inference under FP8 for jax checkpoint
+
+WARNING: Running FP8 this way will introduce a noticeable accuracy drop. To avoid that, use the AMMO quantization described later in this README.
+
+In this example, we demonstrate how to run FP8 inference on Gemma. Note that `convert_checkpoint.py` only uses identity activation scales, so the accuracy might be a little worse than with higher precision in some cases, but it is still very good considering that no calibration is done. This also shows the stability of FP8 compared to INT8.
+ +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/fp8/tp1/ +ENGINE_PATH=/tmp/gemma/7B/fp8/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --dtype bfloat16 \ + --world-size 1 \ + --enable_fp8 \ + --fp8_kv_cache \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 8 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM (total latency: 5.884302377700806 sec) +[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2694) +[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 457.8282737830064) +[02/08/2024-06:42:13] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-06:42:13] [TRT-LLM] [I] rouge1 : 27.18633861010837 +[02/08/2024-06:42:13] [TRT-LLM] [I] rouge2 : 7.734928823230158 +[02/08/2024-06:42:13] [TRT-LLM] [I] rougeL : 19.32537431798716 +[02/08/2024-06:42:13] [TRT-LLM] [I] rougeLsum : 22.82522575944535 +``` + +#### Run 7B inference under SmoothQuant for jax checkpoint + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/sq/tp1 +ENGINE_PATH=/tmp/gemma/7B/int8_sq/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --dtype float16 \ + --use_smooth_quant_plugin 0.5 \ + --tokenizer_dir ${VOCAB_FILE_PATH} \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin float16 \ + --gpt_attention_plugin float16 \ + --max_batch_size 8 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/19/2024-10:02:53] [TRT-LLM] [I] --------------------------------------------------------- +[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM (total latency: 13.65670919418335 sec) +[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 8351) +[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 611.494312521266) +[02/19/2024-10:03:09] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/19/2024-10:03:09] [TRT-LLM] [I] rouge1 : 28.8107815115074 +[02/19/2024-10:03:09] [TRT-LLM] [I] rouge2 : 8.623835512061866 +[02/19/2024-10:03:09] [TRT-LLM] [I] rougeL : 19.7277195532959 +[02/19/2024-10:03:09] [TRT-LLM] [I] rougeLsum : 23.434950511855114 +``` + +#### Run inference under weight only for keras checkpoint + +Available precisions: `int8` and `int4` + +* `int8` + +```bash +CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_7b_en/ +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/w8_a16/tp1/ +ENGINE_PATH=/tmp/gemma/7B/w8_a16/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type keras \ + --model-dir ${CKPT_PATH} \ + 
--use-weight-only-with-precision int8 \ + --dtype bfloat16 \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 32 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM (total latency: 8.49835753440857 sec) +[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2654) +[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 312.2956393931832) +[02/08/2024-07:38:15] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-07:38:16] [TRT-LLM] [I] rouge1 : 20.396209981234687 +[02/08/2024-07:38:16] [TRT-LLM] [I] rouge2 : 5.73302850102211 +[02/08/2024-07:38:16] [TRT-LLM] [I] rougeL : 16.001683776127507 +[02/08/2024-07:38:16] [TRT-LLM] [I] rougeLsum : 18.36957526315223 +``` + +* `int4` + +```bash +CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/w4_a16/tp1/ +ENGINE_PATH=/tmp/gemma/7B/w4_a16/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type jax \ + --model-dir ${CKPT_PATH} \ + --use-weight-only-with-precision int4 \ + --dtype bfloat16 \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 32 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir ${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM (total latency: 7.282559156417847 sec) +[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2253) +[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 309.3692686333369) +[02/08/2024-07:43:32] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-07:43:32] [TRT-LLM] [I] rouge1 : 27.22556858171486 +[02/08/2024-07:43:32] [TRT-LLM] [I] rouge2 : 6.889046653923549 +[02/08/2024-07:43:32] [TRT-LLM] [I] rougeL : 19.07040336076859 +[02/08/2024-07:43:32] [TRT-LLM] [I] rougeLsum : 22.840545705675858 +``` + +#### Run inference under INT8 KV caches for keras checkpoint + +```bash +CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_7b_en/ +UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/int8kv/tp1 +ENGINE_PATH=/tmp/gemma/7B/int8kv/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model + +python3 ./convert_checkpoint.py \ + --ckpt-type keras \ + --model-dir ${CKPT_PATH} \ + --world-size 1 \ + --dtype bfloat16 \ + --calibrate_kv_cache \ + --tokenizer_dir ${VOCAB_FILE_PATH} \ + --output-model-dir ${UNIFIED_CKPT_PATH} + +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --max_batch_size 32 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --strongly_type \ + --output_dir ${ENGINE_PATH} + +python3 ../summarize.py --test_trt_llm \ + --vocab_file ${VOCAB_FILE_PATH} \ + --engine_dir 
${ENGINE_PATH} \ + --batch_size 8 \ + --max_ite 5 \ + --no_add_special_tokens + +[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM (total latency: 8.73880124092102 sec) +[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2771) +[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 317.09154649544956) +[02/08/2024-07:51:11] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[02/08/2024-07:51:11] [TRT-LLM] [I] rouge1 : 20.934864626327627 +[02/08/2024-07:51:11] [TRT-LLM] [I] rouge2 : 4.954721611692932 +[02/08/2024-07:51:11] [TRT-LLM] [I] rougeL : 15.307592049634444 +[02/08/2024-07:51:11] [TRT-LLM] [I] rougeLsum : 17.94213019528988 +``` + +### Run AMMO Quantization + +#### Requirements + +AMMO toolkit provides quantization solutions with better accuracy. To enable it, have the latest ammo and transformers Python package installed to support Gemma. Then run the following commands. + +#### Quantize Checkpoints + +``` +python ../quantization/quantize.py --model_dir ${HF_GEMMA_PATH} \ + --dtype float16 \ + --qformat ${QUANT_TYPE} \ + --output_dir ${UNIFIED_CKPT_PATH} \ + --tp_size 1 +``` +HF_GEMMA_PATH can either be HF model card name or the downloaded model path. QUANT_TYPE can be chosen from fp8, int4_awq, and int8_sq. + +#### Build Engines + +For fp8, build engines with: +``` +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin float16 \ + --gpt_attention_plugin float16 \ + --max_batch_size 8 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --output_dir ${ENGINE_PATH} +``` + +For int4_awq and int8_sq, build engines with: + +``` +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --gemm_plugin float16 \ + --gpt_attention_plugin float16 \ + --max_batch_size 8 \ + --max_input_len 3000 \ + --max_output_len 100 \ + --context_fmha enable \ + --enable_xqa enable \ + --output_dir ${ENGINE_PATH} +``` + +#### Accuracy Results on MMLU + +| Model | fp8 | int4_awq | int8_sq | +|---------------|-------|----------|---------| +| 2B Pretrained | 0.407 | 0.378 | 0.328 | +| 7B Pretrained | 0.643 | 0.615 | 0.480 | diff --git a/examples/gemma/convert_checkpoint.py b/examples/gemma/convert_checkpoint.py new file mode 100644 index 000000000..7107c6ba3 --- /dev/null +++ b/examples/gemma/convert_checkpoint.py @@ -0,0 +1,856 @@ +#!/usr/bin/env python3 +import argparse +import json +import logging +import math +import pathlib +import re +import time +import typing + +import flax.traverse_util +import h5py +import numpy as np +import safetensors.numpy +import safetensors.torch +import sentencepiece as sp +import torch +import utils.params +import utils.transformer +from datasets import load_dataset +from easydict import EasyDict + +import tensorrt_llm +from tensorrt_llm._utils import torch_to_numpy +from tensorrt_llm.models.gemma.smoothquant import * +from tensorrt_llm.models.gemma.weight import (dummy_weights_awq, + load_from_fp8_llama, + quantize_fp8_weigths) + +LOGGER = logging.getLogger("convert_checkpoint") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-type", + type=str, + choices=["jax", "keras", "torch"]) + parser.add_argument("--model-dir", type=pathlib.Path, required=True) + parser.add_argument("--output-model-dir", type=pathlib.Path, required=True) + parser.add_argument("--world-size", + type=int, + default=1, + help="world size, only support tensor parallelism now") + parser.add_argument( + "--use-weight-only-with-precision", + choices=["int8", "int4", "w4a8_awq", "w4a16_awq"], + help= + 
"help='Quantize weights for the various GEMMs to INT4/INT8. Define the precision for the weights.", + ) + parser.add_argument("--dtype", + type=str, + choices=["float32", "bfloat16", "float16"]) + parser.add_argument( + "--enable_fp8", + action="store_true", + help="Use FP8 Linear layer for Attention QKV/Dense and MLP.") + parser.add_argument( + "--fp8_kv_cache", + action="store_true", + help= + "By default, we use dtype for KV cache. fp8_kv_cache chooses int8 quantization for KV", + ) + parser.add_argument( + "--ammo_quant_ckpt_path", + default=None, + help= + "Path of a directory to quantized model checkpoints in .safetensors format or \ + path of a quantized model checkpoint in .npz format") + parser.add_argument('--use_smooth_quant', + default=False, + action="store_true", + help="Use smooth quant.") + parser.add_argument( + "--calibrate_kv_cache", + "-kv", + action="store_true", + help= + "Generate scaling factors for KV cache. Used for storing KV cache in int8." + ) + parser.add_argument( + '--per_channel', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor for the GEMM\'s result. ' + 'per_channel instead uses a different static scaling factor for each channel. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--per_token', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + 'per_token chooses at run time, and for each token, a custom scaling factor. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + "--use_smooth_quant_plugin", + "-sq", + type=float, + default=None, + help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" + " to Smoothquant the model, and output int8 weights." + " A good first try is 0.5. 
Must be in [0, 1]") + parser.add_argument( + '--tokenizer_dir', + default=None, + help='tokenizer path; defaults to jax_model_dir if left unspecified') + + args = parser.parse_args() + + return args + + +class JAXParser: + + def load_parameters(self, checkpoint_path: pathlib.Path): + checkpoint_path = checkpoint_path.absolute() + return utils.params.nest_params( + utils.params.param_remapper( + utils.params.load_params(checkpoint_path))) + + def embedding_weights(self, ckpt_params): + return ckpt_params["transformer"]["embedder"]["input_embedding"] + + def get_config(self, checkpoint_path, ckpt_params, num_embed): + return utils.transformer.TransformerConfig.from_params( + ckpt_params, num_embed=num_embed) + + def rename_to_trt_llm(self, name: str): + """Rename a gemma parameter name by the corresponding TRT-LLM style name.""" + prefix, name = name.split(".", maxsplit=1) + assert prefix == "transformer" + sub_patterns = ( + (r"embedder.input_embedding", r"vocab_embedding.weight"), + (r"layer_(\d+).pre_attention_norm.scale", + r"layers.\1.input_layernorm.weight"), + (r"layer_(\d+).attn.q_einsum.w", r"layers.\1.attention.qkv.weight"), + (r"layer_(\d+).attn.kv_einsum.w", + None), # drop as kv will be concatenated with q + (r"layer_(\d+).attn.qkv_einsum.w", + r"layers.\1.attention.qkv.weight"), + (r"layer_(\d+).attn.attn_vec_einsum.w", + r"layers.\1.attention.dense.weight"), + (r"layer_(\d+).mlp.gating_einsum", r"layers.\1.mlp.fc.weight"), + (r"layer_(\d+).mlp.linear", r"layers.\1.mlp.proj.weight"), + (r"layer_(\d+).pre_ffw_norm.scale", + r"layers.\1.post_layernorm.weight"), + (r"final_norm.scale", r"ln_f.weight"), + ) + + for source, target in sub_patterns: + if re.match(source, name): + if target is None: + return target + else: + name = re.sub(source, target, name) + return ".".join((prefix, name)) + else: + raise ValueError(f"Don't know how to rename {prefix}.{name}") + + def flatten_params(self, params): + return flax.traverse_util.flatten_dict(params, sep=".") + + +class KerasParser: + + def load_parameters(self, checkpoint_path: pathlib.Path): + checkpoint_path = checkpoint_path.absolute() + config_file = "config.json" + weights_file = json.load(open(checkpoint_path / config_file))["weights"] + h5_path = checkpoint_path / weights_file + return h5py.File(h5_path, "r+") + + def embedding_weights(self, ckpt_params): + return np.array(ckpt_params["layers/reversible_embedding/vars/0"]) + + def get_config(self, checkpoint_path, ckpt_params, num_embed): + checkpoint_path = checkpoint_path.absolute() + config_file = "config.json" + config_old = json.load(open(checkpoint_path / config_file))["config"] + config_new = {} + config_new["num_layers"] = config_old["num_layers"] + config_new["num_embed"] = config_old["vocabulary_size"] + config_new["embed_dim"] = config_old["hidden_dim"] + config_new["hidden_dim"] = config_old["intermediate_dim"] // 2 + config_new["num_heads"] = config_old["num_query_heads"] + config_new["head_dim"] = config_old["head_dim"] + config_new["num_kv_heads"] = config_old["num_key_value_heads"] + return EasyDict(config_new) + + def rename_to_trt_llm(self, name: str): + """Rename a gemma parameter name by the corresponding TRT-LLM style name.""" + prefix = "transformer" + name = name.replace("/gemma_decoder_block/", "/gemma_decoder_block_0/") + sub_patterns = ( + (r"layers/reversible_embedding/vars/0", r"vocab_embedding.weight"), + (r"layers/gemma_decoder_block_(\d+)/pre_attention_norm/vars/0", + r"layers.\1.input_layernorm.weight"), + 
(r"layers/gemma_decoder_block_(\d+)/attention/query_dense/vars/0", + r"layers.\1.attention.qkv.weight"), + (r"layers/gemma_decoder_block_(\d+)/attention/key_dense/vars/0", + None), # drop as k will be concatenated with q + (r"layers/gemma_decoder_block_(\d+)/attention/value_dense/vars/0", + None), # drop as v will be concatenated with q + (r"layers/gemma_decoder_block_(\d+)/attention/output_dense/vars/0", + r"layers.\1.attention.dense.weight"), + (r"layers/gemma_decoder_block_(\d+)/gating_ffw/vars/0", + r"layers.\1.mlp.fc.weight"), + (r"layers/gemma_decoder_block_(\d+)/gating_ffw_2/vars/0", + None), # merged with above + (r"layers/gemma_decoder_block_(\d+)/ffw_linear/vars/0", + r"layers.\1.mlp.proj.weight"), + (r"layers/gemma_decoder_block_(\d+)/pre_ffw_norm/vars/0", + r"layers.\1.post_layernorm.weight"), + (r"layers/rms_normalization/vars/0", r"ln_f.weight"), + (r"optimizer/vars/(\d+)", None), # Not used + ) + + for source, target in sub_patterns: + if re.match(source, name): + if target is None: + return target + else: + name = re.sub(source, target, name) + return ".".join((prefix, name)) + else: + raise ValueError(f"Don't know how to rename {prefix}.{name}") + + def flatten_params(self, params): + f_params = {} + + def walk(name, obj): + if isinstance(obj, h5py.Dataset): + f_params[name] = np.array(obj) + + params.visititems(walk) + return f_params + + +class TorchParser: + + def load_parameters(self, checkpoint_path: pathlib.Path): + ckpt_path = list(checkpoint_path.glob('*.ckpt'))[0] + model_params = torch.load(ckpt_path)['model_state_dict'] + model_params.pop('freqs_cis') + return model_params + + def embedding_weights(self, ckpt_params): + return ckpt_params["embedder.weight"] + + def get_config(self, checkpoint_path, ckpt_params, num_embed): + checkpoint_path = checkpoint_path.absolute() + config_file = "config.json" + with open(checkpoint_path / config_file, 'r') as f: + json_str = f.read() + json_str = json_str.replace("'", "\"") + json_str = json_str.replace(",\n}", "\n}") + config_old = json.loads(json_str) + config_new = {} + config_new["num_layers"] = config_old["num_hidden_layers"] + config_new["num_embed"] = config_old["vocab_size"] + config_new["embed_dim"] = config_old["hidden_size"] + config_new["hidden_dim"] = config_old["intermediate_size"] + config_new["num_heads"] = config_old["num_attention_heads"] + config_new["head_dim"] = config_old["head_dim"] + config_new["num_kv_heads"] = config_old["num_key_value_heads"] + return EasyDict(config_new) + + def rename_to_trt_llm(self, name: str): + """Rename a gemma parameter name by the corresponding TRT-LLM style name.""" + prefix = "transformer" + sub_patterns = ( + (r"embedder.weight", r"vocab_embedding.weight"), + (r"model.layers.(\d+).input_layernorm.weight", + r"layers.\1.input_layernorm.weight"), + (r"model.layers.(\d+).self_attn.qkv_proj.weight", + r"layers.\1.attention.qkv.weight"), + (r"model.layers.(\d+).self_attn.o_proj.weight", + r"layers.\1.attention.dense.weight"), + (r"model.layers.(\d+).mlp.gate_proj.weight", + r"layers.\1.mlp.fc.weight"), + (r"model.layers.(\d+).mlp.up_proj.weight", + None), # merged with above + (r"model.layers.(\d+).mlp.down_proj.weight", + r"layers.\1.mlp.proj.weight"), + (r"model.layers.(\d+).post_attention_layernorm.weight", + r"layers.\1.post_layernorm.weight"), + (r"model.norm.weight", r"ln_f.weight"), + ) + + for source, target in sub_patterns: + if re.match(source, name): + if target is None: + return target + else: + name = re.sub(source, target, name) + return ".".join((prefix, 
name)) + else: + raise ValueError(f"Don't know how to rename {name}") + + def flatten_params(self, params): + f_params = {} + for k, v in params.items(): + if v.dtype == torch.bfloat16: + v = v.float() + f_params[k] = torch_to_numpy(v) + return f_params + + +CKPT_PARSER = {'jax': JAXParser, 'keras': KerasParser, 'torch': TorchParser} + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + return np.split(v, tp_size, axis=dim)[idx] + + +def split_matrix_tp(v, tensor_parallel, rank, dim): + return split(v, tensor_parallel, rank, dim=dim) + + +def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray], + name: str, + param: np.ndarray, + dtype: typing.Optional[np.dtype] = None): + assert name not in weights, f"{name} is already added." + if dtype is not None: + param = param.astype(dtype) + param = np.ascontiguousarray(param) + weights[name] = param + + +def quantize(param: np.ndarray, + quant_mode: tensorrt_llm.quantization.QuantMode): + if quant_mode.is_int8_weight_only(): + quant_dtype = torch.int8 + elif quant_mode.is_int4_weight_only(): + quant_dtype = torch.quint4x2 + else: + raise ValueError(f"Invalid configuration got quant_mode={quant_mode}") + + if param.dtype == np.dtype("bfloat16"): + param = torch.from_numpy(param.astype(np.float32)).to(torch.bfloat16) + else: + param = torch.from_numpy(param) + param = param.t().contiguous() + + # previously this fn was available in torch.ops.fastertransformer namespace + ( + quantized_weights, + scales, + ) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + param, quant_dtype) + + if scales.dtype == torch.bfloat16: + scales = scales.to(torch.float32).numpy().astype("bfloat16") + else: + scales = scales.numpy() + return quantized_weights.numpy(), scales + + +def convert_from_checkpoint( + trt_llm_config: tensorrt_llm.models.modeling_utils.PretrainedConfig, + model_dir: typing.Union[str, pathlib.Path], + ckpt_parser, + rank=0, +): + print("Loading weights...") + tik = time.time() + + tp_rank = rank + tp_size = trt_llm_config.mapping.tp_size + hidden_size = trt_llm_config.hidden_size + head_dim = trt_llm_config.head_size + + weights = {} + for model_file in [model_dir]: + LOGGER.debug(f"Loading directory {str(model_file)}...") + model_params = ckpt_parser.load_parameters(model_file) + model_params = ckpt_parser.flatten_params(model_params) + + for name, param in model_params.items(): + LOGGER.debug(f"Converting weight {name}...") + trt_llm_name = ckpt_parser.rename_to_trt_llm(name) + if trt_llm_name is None: # omit as used with other params + continue + + if "attn.q_einsum" in name: + gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads + assert gqa_mode + + # initial shape: (num_q_heads, hidden_size, head_dim) + q_param = param.transpose(1, 0, 2) + q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1) + + # initial shape: (2, num_kv_heads, hidden_size, head_dim) + kv_name = name.replace("q_einsum", "kv_einsum") + kv_param = model_params[kv_name] + kv_param = kv_param.reshape( + trt_llm_config.num_key_value_heads * 2, + hidden_size, + head_dim, + ).transpose(1, 0, 2) + + # -> (hidden_size, num_q_heads / tp_size + 2, head_dim) + qkv_param = np.concatenate([q_param, kv_param], axis=1) + qkv_param = qkv_param.reshape(qkv_param.shape[0], -1) + qkv_param = qkv_param.transpose(1, 0) + + # If int8 kv enabled, weight-only quantization will be done later. 
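+                # qkv_param now has shape
+                # ((num_q_heads // tp_size + 2 * num_kv_heads) * head_dim, hidden_size):
+                # this rank's Q heads followed by the shared K/V heads, i.e. the fused
+                # QKV layout TRT-LLM expects. Plain weight-only (W8A16/W4A16) weights are
+                # quantized right below; AWQ (per-group scaling) and INT8 KV-cache
+                # variants keep the original dtype here and are handled later.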
+ if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ + not trt_llm_config.quant_mode.has_int8_kv_cache(): + qkv_param_quantized, qkv_param_scales = quantize( + qkv_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + qkv_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + qkv_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, qkv_param, + trt_llm_config.dtype) + elif "self_attn.qkv_proj" in name: + q_param, k_param, v_param = np.split(param, [ + trt_llm_config.num_attention_heads * + trt_llm_config.head_size, + trt_llm_config.num_attention_heads * + trt_llm_config.head_size + + trt_llm_config.num_key_value_heads * + trt_llm_config.head_size + ], + axis=0) + gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads + + q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0) + if not gqa_mode: + k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0) + v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0) + + qkv_param = np.concatenate([q_param, k_param, v_param], axis=0) + if trt_llm_config.quant_mode.is_weight_only( + ) and not trt_llm_config.quant_mode.has_per_group_scaling(): + qkv_param_quantized, qkv_param_scales = quantize( + qkv_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + qkv_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + qkv_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, qkv_param, + trt_llm_config.dtype) + elif "attn.qkv_einsum" in name: + gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads + assert not gqa_mode + # initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size] + qkv_param = param.transpose(0, 1, 3, 2) + qkv_param = qkv_param.reshape(qkv_param.shape[0], -1, + qkv_param.shape[3]) + qkv_param = split_matrix_tp(qkv_param, tp_size, tp_rank, dim=1) + qkv_param = qkv_param.reshape(-1, qkv_param.shape[2]) + if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() \ + and not trt_llm_config.quant_mode.has_int8_kv_cache(): + qkv_param_quantized, qkv_param_scales = quantize( + qkv_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + qkv_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + qkv_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, qkv_param, + trt_llm_config.dtype) + elif "attention/query_dense" in name: + # Keras specific KQV convert + gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads + if gqa_mode: + + # initial shape: (num_q_heads, hidden_size, head_dim) + q_param = param.transpose(1, 0, 2) + q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=1) + + # initial shape: (2, num_kv_heads, hidden_size, head_dim) + k_name = name.replace("query", "key") + k_param = model_params[k_name] + v_name = name.replace("query", "value") + v_param = model_params[v_name] + kv_param = np.stack((k_param, v_param), axis=0) + + kv_param = kv_param.reshape( + trt_llm_config.num_key_value_heads * 2, + hidden_size, + head_dim, + ).transpose(1, 0, 2) + + # -> (hidden_size, num_q_heads / tp_size + 2, head_dim) + qkv_param = np.concatenate([q_param, 
kv_param], axis=1) + qkv_param = qkv_param.reshape(qkv_param.shape[0], -1) + qkv_param = qkv_param.transpose(1, 0) + + if trt_llm_config.quant_mode.is_weight_only( + ) and not trt_llm_config.quant_mode.has_int8_kv_cache(): + qkv_param_quantized, qkv_param_scales = quantize( + qkv_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + qkv_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", + ".per_channel_scale"), + qkv_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, qkv_param, + trt_llm_config.dtype) + else: + q_param = param + k_name = name.replace("query", "key") + k_param = model_params[k_name] + v_name = name.replace("query", "value") + v_param = model_params[v_name] + # initial shape: [3, num_heads, hidden_size, head_dim] -> [3, num_heads, head_dim, hidden_size] + qkv_param = np.stack((q_param, k_param, v_param), axis=0) + qkv_param = qkv_param.transpose(0, 1, 3, 2) + qkv_param = qkv_param.reshape(qkv_param.shape[0], -1, + qkv_param.shape[3]) + qkv_param = split_matrix_tp(qkv_param, + tp_size, + tp_rank, + dim=1) + qkv_param = qkv_param.reshape(-1, qkv_param.shape[2]) + if trt_llm_config.quant_mode.is_weight_only( + ) and not trt_llm_config.quant_mode.has_int8_kv_cache(): + qkv_param_quantized, qkv_param_scales = quantize( + qkv_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + qkv_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", + ".per_channel_scale"), + qkv_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, qkv_param, + trt_llm_config.dtype) + elif "attention.dense.weight" in trt_llm_name: + # initial shape: (num_heads, head_dim, hidden_size) + if len(param.shape) == 3: + param = param.reshape(-1, param.shape[2]) + param = param.transpose( + 1, 0) # (hidden_size, num_heads * head_dum) + param = split_matrix_tp(param, tp_size, tp_rank, dim=1) + if trt_llm_config.quant_mode.is_weight_only( + ) and not trt_llm_config.quant_mode.has_int8_kv_cache(): + param_quantized, param_scales = quantize( + param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, param, + trt_llm_config.dtype) + elif "mlp.fc.weight" in trt_llm_name: + if isinstance(ckpt_parser, KerasParser): + # initial shape: (hidden_size, intermediate_size) + fc_param, gate_param = param, model_params[name.replace( + "gating_ffw", "gating_ffw_2")] + elif isinstance(ckpt_parser, TorchParser): + # initial shape: (intermediate_size, hidden_size) + fc_param, gate_param = param, model_params[name.replace( + "mlp.gate_proj", "mlp.up_proj")] + fc_param = fc_param.transpose(1, 0) + gate_param = gate_param.transpose(1, 0) + else: + # initial shape: (2, hidden_size, intermediate_size) + fc_param, gate_param = param[0], param[1] + fc_param = fc_param.transpose(1, 0) + fc_param = split_matrix_tp(fc_param, tp_size, tp_rank, dim=0) + if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ + not trt_llm_config.quant_mode.has_int8_kv_cache(): + fc_param_quantized, fc_param_scales = quantize( + fc_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + fc_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", 
".per_channel_scale"), + fc_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, fc_param, + trt_llm_config.dtype) + + gate_param = gate_param.transpose(1, 0) + gate_param = split_matrix_tp(gate_param, + tp_size, + tp_rank, + dim=0) + trt_llm_name = trt_llm_name.replace("mlp.fc.weight", + "mlp.gate.weight") + if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ + not trt_llm_config.quant_mode.has_int8_kv_cache(): + gate_param_quantized, gate_param_scales = quantize( + gate_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + gate_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + gate_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, gate_param, + trt_llm_config.dtype) + elif "mlp.proj.weight" in trt_llm_name: + if not isinstance(ckpt_parser, TorchParser): + # initial shape: (intermediate_size, hidden_size) + param = param.transpose(1, 0) + param = split_matrix_tp(param, tp_size, tp_rank, dim=1) + if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ + not trt_llm_config.quant_mode.has_int8_kv_cache(): + param_quantized, param_scales = quantize( + param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, param, + trt_llm_config.dtype) + elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name: + if not trt_llm_config.share_embedding_table: + # TODO: safetensor doesn't allow to save a shared tensor. + # Currently, we clone the weight but to save the disk, it + # would be better to skip saving lm_head weights and + # handle it at the loading phase. + lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0) + add_trt_llm_weight(weights, "lm_head.weight", + np.copy(lm_head), trt_llm_config.dtype) + + param = np.multiply( + param.astype(np.float32), + math.sqrt(trt_llm_config.hidden_size), + ) + if trt_llm_config.use_parallel_embedding: + assert trt_llm_config.vocab_size % tp_size == 0 + param = split_matrix_tp( + param, + tp_size, + tp_rank, + dim=trt_llm_config.embedding_sharding_dim, + ) + add_trt_llm_weight(weights, trt_llm_name, param, + trt_llm_config.dtype) + elif any(keyword in name for keyword in ( + "pre_attention_norm.scale", + "pre_ffw_norm.scale", + "final_norm.scale", + "pre_attention_norm/vars/0", + "pre_ffw_norm/vars/0", + "rms_normalization/vars/0", + "input_layernorm", + "post_attention_layernorm", + "model.norm.weight", + )): + param = param + 1.0 # upcasted to float32 in case of bfloat16 + add_trt_llm_weight(weights, trt_llm_name, param, + trt_llm_config.dtype) + else: + raise RuntimeError(f"Unhandled {name} module weights") + del model_params + + print( + f"Weights loaded. 
Total time: {time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))}" + ) + return weights + + +def convert(worker_rank, args, convert_kwargs): + for rank in range(worker_rank, args.world_size): + weights = convert_from_checkpoint(rank=rank, **convert_kwargs) + trt_llm_config = convert_kwargs.get("trt_llm_config") + if args.use_smooth_quant_plugin is not None or args.calibrate_kv_cache: + qkv_para = {} + smoother = {} + dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0') + tokenizer = sp.SentencePieceProcessor(model_file=args.tokenizer_dir) + hf_model = create_model_from_config(trt_llm_config, weights) + act_range = capture_activation_range(hf_model, tokenizer, dataset) + if args.use_smooth_quant_plugin is not None: + smooth_model(hf_model, act_range, args.use_smooth_quant_plugin, + qkv_para, smoother) + weights = convert_hf_model( + hf_model, trt_llm_config.mapping, trt_llm_config.vocab_size, + args.dtype, False, 0, + args.use_weight_only_with_precision != None, + torch.int8 if args.use_weight_only_with_precision == 'int8' else + torch.quint4x2, args.use_smooth_quant_plugin is not None, + args.per_channel, args.per_token, args.calibrate_kv_cache, + act_range, qkv_para, smoother) + safetensors.torch.save_file( + weights, args.output_model_dir / f"rank{rank}.safetensors") + return + + use_awq = False + if args.use_weight_only_with_precision: + if args.use_weight_only_with_precision.endswith("awq"): + use_awq = True + if use_awq: + weights = dummy_weights_awq( + weights=weights, + precision=args.use_weight_only_with_precision, + trt_llm_config=trt_llm_config, + group_size=128) + elif args.enable_fp8 or args.fp8_kv_cache: + weight_scales = quantize_fp8_weigths( + weights, trt_llm_config.num_hidden_layers, + trt_llm_config.mapping) + scales = load_from_fp8_llama(args.ammo_quant_ckpt_path, + trt_llm_config.num_hidden_layers, + trt_llm_config.mapping, + args.fp8_kv_cache, weight_scales) + weights.update(scales) + + safetensors.numpy.save_file( + weights, args.output_model_dir / f"rank{rank}.safetensors") + + +def main(): + args = parse_arguments() + + tik = time.time() + + print(f"Loading source parameters from {args.model_dir.absolute()}") + ckpt_parser = CKPT_PARSER[args.ckpt_type]() + ckpt_params = ckpt_parser.load_parameters(args.model_dir) + input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params) + num_embed, _ = input_embedding_weights.shape + ckpt_params_dtype = str( + input_embedding_weights.dtype).split(".")[-1] # np.bfloat16 -> bfloat16 + ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params, num_embed) + # 2B TransformerConfig(num_layers=18, num_embed=256128, embed_dim=2048, hidden_dim=16384, num_heads=8, head_dim=256, num_kv_heads=1) + # 7B TransformerConfig(...) 
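+    # If --dtype is not given, the checkpoint's native dtype (ckpt_params_dtype)
+    # determined above is reused when building the TRT-LLM config below.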
+ + print(f"Source configuration determined from parameters: {ckpt_config}") + + quant_mode = tensorrt_llm.quantization.QuantMode(0) + quant_kwargs = {} + quant_algo = None + kv_cache_quant_algo = None + if args.use_weight_only_with_precision: + quant_algo = { + "int8": "W8A16", + "int4": "W4A16", + "w4a8_awq": "W4A8_AWQ", + "w4a16_awq": "W4A16_AWQ", + }[args.use_weight_only_with_precision] + elif args.enable_fp8: + quant_algo = "FP8" + elif args.use_smooth_quant: + quant_algo = "W8A8_SQ_PER_CHANNEL" + + if args.fp8_kv_cache: + kv_cache_quant_algo = "FP8" + if args.calibrate_kv_cache: + kv_cache_quant_algo = "INT8" + if args.use_smooth_quant: + quant_algo = "W8A8_SQ_PER_CHANNEL" + elif args.use_smooth_quant_plugin is not None: + if args.per_token and args.per_channel: + quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN' + elif not args.per_token and not args.per_channel: + quant_algo = 'W8A8_SQ_PER_TENSOR_PLUGIN' + elif not args.per_token and args.per_channel: + quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN' + elif args.per_token and not args.per_channel: + quant_algo = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN' + quant_kwargs.update(sq_use_plugin=True) + + quant_kwargs.update(quant_algo=quant_algo, + kv_cache_quant_algo=kv_cache_quant_algo) + if quant_algo is not None or kv_cache_quant_algo is not None: + quant_mode = tensorrt_llm.quantization.QuantMode.from_quant_algo( + quant_algo, + kv_cache_quant_algo=kv_cache_quant_algo, + ) + if args.use_weight_only_with_precision: + if args.use_weight_only_with_precision.endswith("awq"): + quant_kwargs.update(has_zero_point=False, + pre_quant_scale=True, + exclude_modules=["lm_head"]) + + trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig( + architecture="GemmaForCausalLM", + dtype=args.dtype or ckpt_params_dtype, + logits_dtype="float32", + vocab_size=ckpt_config.num_embed, + max_position_embeddings=8192, + hidden_size=ckpt_config.embed_dim, + num_hidden_layers=ckpt_config.num_layers, + num_attention_heads=ckpt_config.num_heads, + num_key_value_heads=ckpt_config.num_kv_heads, + head_size=ckpt_config.head_dim, + hidden_act="gelu", + intermediate_size=ckpt_config.hidden_dim, + norm_epsilon=1e-6, # hard-coded in RMSNorm from gemma/layers.py + position_embedding_type="rope_gpt_neox", + world_size=args.world_size, + tp_size=args.world_size, + pp_size=1, + quant_mode=quant_mode, + quant_kwargs=quant_kwargs, + ) + + trt_llm_config_dict = trt_llm_config.to_dict() + print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}") + + config_path = args.output_model_dir / "config.json" + config_path.parent.mkdir(exist_ok=True, parents=True) + LOGGER.debug(f"Saving TensorRT-LLM configuration to {config_path}") + with config_path.open("w") as config_file: + json.dump(trt_llm_config_dict, config_file, indent=4) + + convert_args = dict(trt_llm_config=trt_llm_config, + model_dir=args.model_dir, + ckpt_parser=ckpt_parser) + convert(0, args, convert_args) + + elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik)) + print(f"Total time of converting checkpoints: {elapsed}") + + +if __name__ == "__main__": + main() diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt new file mode 100644 index 000000000..3d8ffcfa5 --- /dev/null +++ b/examples/gemma/requirements.txt @@ -0,0 +1,9 @@ +-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +flax~=0.8.0 +jax[cuda12_pip]~=0.4.19 +safetensors~=0.4.1 +sentencepiece~=0.1.99 +h5py~=3.10.0 +easydict~=1.11 +rouge_score +nltk diff --git 
a/examples/gemma/utils/__init__.py b/examples/gemma/utils/__init__.py new file mode 100644 index 000000000..c506269a5 --- /dev/null +++ b/examples/gemma/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/examples/gemma/utils/layers.py b/examples/gemma/utils/layers.py new file mode 100644 index 000000000..0c2f47129 --- /dev/null +++ b/examples/gemma/utils/layers.py @@ -0,0 +1,39 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base layers.""" + +import jax +import jax.numpy as jnp +from flax import linen as nn + + +class Einsum(nn.Module): + shape: tuple[int, ...] + + @nn.compact + def __call__(self, eqn: str, x: jax.Array) -> jax.Array: + w = self.param('w', nn.initializers.zeros_init(), self.shape) + return jnp.einsum(eqn, x, w) + + +class RMSNorm(nn.Module): + + @nn.compact + def __call__(self, x): + scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1])) + var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) + normed_inputs = normed_inputs * (1 + scale) + return normed_inputs diff --git a/examples/gemma/utils/modules.py b/examples/gemma/utils/modules.py new file mode 100644 index 000000000..2ec10855e --- /dev/null +++ b/examples/gemma/utils/modules.py @@ -0,0 +1,206 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer sub-modules. +""" + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from . import layers, positional_embeddings + +K_MASK = -2.3819763e38 # Set to a large negative number. 
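+# Attention logits at masked positions are replaced with K_MASK before the
+# softmax, so those positions receive (effectively) zero probability.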
+LayerCache = dict[str, jax.Array] + + +def init_layer_cache(cache_size: int, num_heads: int, head_dim: int, + batch_size: int) -> LayerCache: + return { + 'v': + jnp.zeros((batch_size, cache_size, num_heads, head_dim), + dtype=jnp.float32), + 'k': + jnp.zeros((batch_size, cache_size, num_heads, head_dim), + dtype=jnp.float32), + } + + +class Embedder(nn.Module): + """Embedder module.""" + + vocab_size: int + embed_dim: int + + def setup(self): + self.input_embedding_table = self.param( + 'input_embedding', + nn.initializers.zeros_init(), + (self.vocab_size, self.embed_dim), + ) + + def encode(self, x: jax.Array) -> jax.Array: + x = self.input_embedding_table[(x, )] + x *= jnp.sqrt(self.embed_dim).astype(x.dtype) + return x + + def decode(self, x: jax.Array) -> jax.Array: + return jnp.dot(x, self.input_embedding_table.T) + + +class Attention(nn.Module): + """Attention module.""" + + num_heads: int + num_kv_heads: int + features: int + head_dim: int + + @property + def use_qkv_einsum(self): + return self.num_kv_heads == self.num_heads + + def setup(self): + self.attn_vec_einsum = layers.Einsum(shape=(self.num_heads, + self.head_dim, + self.features), ) + + if self.use_qkv_einsum: + self.qkv_einsum = layers.Einsum(shape=(3, self.num_heads, + self.features, + self.head_dim), ) + else: + self.q_einsum = layers.Einsum(shape=(self.num_heads, self.features, + self.head_dim), ) + self.kv_einsum = layers.Einsum(shape=(2, self.num_kv_heads, + self.features, + self.head_dim), ) + + def __call__( + self, + x: jax.Array, + segment_pos: int, + cache: LayerCache, + attn_mask: jax.Array, + time_step: int, + ) -> tuple[LayerCache, jax.Array]: + + bsz = x.shape[0] + + if self.use_qkv_einsum: + query_proj, key_proj, value_proj = self.qkv_einsum( + 'BTD,SNDH->SBTNH', x) + else: + query_proj = self.q_einsum('BTD,NDH->BTNH', x) + key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x) + + query_proj = positional_embeddings.apply_rope( + query_proj, + segment_pos, + head_dim=self.head_dim, + ) + query_scaled = query_proj * self.head_dim**-0.5 + + key_proj = positional_embeddings.apply_rope( + key_proj, + segment_pos, + head_dim=self.head_dim, + ) + + # Cache is left aligned. 
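+        # Single-token decode step: the key/value projections for the current token
+        # are written into the fixed-size, left-aligned cache at `time_step`; the
+        # attention below then runs over the whole cache, and positions that have
+        # not been filled yet are suppressed by the additive `attn_mask` (K_MASK).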
+ cache['v'] = (cache['v'].at[:bsz, [time_step], :, :].set(value_proj) + ) # values + cache['k'] = (cache['k'].at[:bsz, [time_step], :, :].set(key_proj) + ) # rotated_keys + + logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, cache['k']) + logits = logits.astype(jnp.float32) + + padded_logits = jnp.where( + (jnp.expand_dims(attn_mask, -2) >= K_MASK * 0.5), logits, K_MASK) + probs = jax.nn.softmax(padded_logits, axis=-1).astype(cache['k'].dtype) + + encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, cache['v']) + attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded) + + return cache, attn_output + + +class FeedForward(nn.Module): + """Feed forward module.""" + + features: int + hidden_dim: int + + @nn.compact + def __call__(self, x): + w_gating = self.param( + 'gating_einsum', + nn.initializers.zeros_init(), + ((2, self.features, self.hidden_dim)), + ) + ff_gate = jnp.dot(x, w_gating[0]) + gate_value = nn.gelu(ff_gate) + + ff1 = jnp.dot(x, w_gating[1]) + activations = gate_value * ff1 + + w_linear = self.param( + 'linear', + nn.initializers.zeros_init(), + (self.hidden_dim, self.features), + ) + outputs = jnp.dot(activations, w_linear) + + return outputs + + +class Block(nn.Module): + """Transformer block.""" + + num_heads: int + num_kv_heads: int + embed_dim: int + head_dim: int + hidden_dim: int + + def setup(self): + self.pre_attention_norm = layers.RMSNorm() + self.attn = Attention( + num_heads=self.num_heads, + features=self.embed_dim, + head_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + ) + self.pre_ffw_norm = layers.RMSNorm() + self.mlp = FeedForward(features=self.embed_dim, + hidden_dim=self.hidden_dim) + + def __call__( + self, + x: jax.Array, + segment_pos: int, + cache: LayerCache, + attn_mask: jax.Array, + time_step: int, + ): + inputs_normalized = self.pre_attention_norm(x) + cache, attn_output = self.attn(inputs_normalized, segment_pos, cache, + attn_mask, time_step) + attn_output += x + residual = attn_output + attn_output = self.pre_ffw_norm(attn_output) + outputs = self.mlp(attn_output) + outputs = residual + outputs + return cache, outputs diff --git a/examples/gemma/utils/params.py b/examples/gemma/utils/params.py new file mode 100644 index 000000000..cfce8bcac --- /dev/null +++ b/examples/gemma/utils/params.py @@ -0,0 +1,73 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils for loading Gemma params. + +These utilities are just helpers for current development. They will not be +needed once Gemma switches to Orbax and changes checkpoint formats ahead of +open sourcing. 
+""" + +import functools +from typing import Any + +import orbax.checkpoint + +Params = dict[str, Any] + + +@functools.cache +def load_params(path: str) -> Params: + """Loads parameters from a checkpoint path.""" + checkpointer = orbax.checkpoint.PyTreeCheckpointer() + params = checkpointer.restore(path) + return params + + +def param_remapper(orig_params: Params) -> Params: + """Remaps params to new module layout. + + This is needed here because the model definition does not have a separate + `mlp` module. For the real code release, we will just save the params in a + different format and this will not be needed. + + Args: + orig_params: original dict of parameters in Gemma format. + + Returns: + dict of params with different names. + """ + new_params = {} + for k, v in orig_params.items(): + if 'mlp/' in k: + layer_name, param = k.rsplit('/', maxsplit=1) + if layer_name not in new_params: + new_params[layer_name] = {} + if 'w' in v: + new_params[layer_name][param] = v['w'] + else: + new_params[k] = v + return new_params + + +def nest_params(params: Params) -> Params: + """Nests params as a dict of dicts rather than a flat dict.""" + nested_params = {} + for path, param in params.items(): + *path, leaf = path.split('/') + subdict = nested_params + for key in path: + subdict = subdict.setdefault(key, {}) + subdict[leaf] = param + return nested_params diff --git a/examples/gemma/utils/positional_embeddings.py b/examples/gemma/utils/positional_embeddings.py new file mode 100644 index 000000000..25707692d --- /dev/null +++ b/examples/gemma/utils/positional_embeddings.py @@ -0,0 +1,92 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils for positional embeddings (including RoPE). 
+""" + +import jax +import jax.numpy as jnp + +_MAX_WAVELENGTH = 10_000 + + +def add_positional_embedding( + input_embedding: jax.Array, + position: int, + max_wavelength: int = _MAX_WAVELENGTH, +) -> jax.Array: + """Adds positional embeddings to input embeddings.""" + embed_dim = input_embedding.shape[-1] + num_timescales = embed_dim // 2 + log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum( + jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1) + inv_timescales = jnp.exp( + jnp.arange(num_timescales, dtype=jnp.float32) * + -log_timescale_increment) + scaled_time = position * inv_timescales + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)]) + signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]]) + position_embedding = signal.astype(jnp.float32) + + return input_embedding + position_embedding + + +def _rotary_embed( + inputs: jax.Array, # [B, 1, H, D] + position: jax.Array, # [B,] + head_dim: int, + max_wavelength: int = _MAX_WAVELENGTH, +) -> jax.Array: + """Helper for RoPE.""" + fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim + timescale = max_wavelength**fraction + timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :] + + sinusoid_inp = position[:, jnp.newaxis, jnp.newaxis, + jnp.newaxis] / timescale + sin = jnp.sin(sinusoid_inp) + cos = jnp.cos(sinusoid_inp) + + first_half, second_half = jnp.split(inputs, 2, axis=-1) + first_part = first_half * cos - second_half * sin + second_part = second_half * cos + first_half * sin + + return jnp.concatenate([first_part, second_part], axis=-1) + + +def apply_rope( + inputs: jax.Array, + position: int, + head_dim: int, + max_wavelength: int = _MAX_WAVELENGTH, +) -> jax.Array: + """Applies RoPE.""" + batch_size, seq_length = inputs.shape[0:2] + + position = jnp.broadcast_to(position, [batch_size])[:, jnp.newaxis] + prefix_position = jnp.arange(seq_length, dtype=jnp.int32) + prefix_position = (position - jnp.flip(prefix_position)[jnp.newaxis, :] + ) # [B, seq_len] + prefix_position = jnp.where(prefix_position < 0, + jnp.zeros_like(prefix_position), + prefix_position).reshape((batch_size, )) + + output = _rotary_embed( + inputs, + position=prefix_position, + head_dim=head_dim, + max_wavelength=max_wavelength, + ) + + return output diff --git a/examples/gemma/utils/sampler.py b/examples/gemma/utils/sampler.py new file mode 100644 index 000000000..4a65b90c6 --- /dev/null +++ b/examples/gemma/utils/sampler.py @@ -0,0 +1,190 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Sampler for Gemma transformer. + +An example of a sampling class for a Gemma model. +""" + +import chex +import jax +import jax.numpy as jnp +import sentencepiece as spm + +from . import modules +from . import params as params_lib +from . 
import transformer as transformer_lib + + +def _compute_attention_masks(time_step: jax.Array, seq_len: int, + input_mask: jax.Array) -> jax.Array: + """Computes causal attention mask.""" + bsz = input_mask.shape[0] + batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32) + causal_padding = jnp.greater(jnp.expand_dims(jnp.arange(seq_len), 0), + batch_time_step) + causal_padding = causal_padding * jnp.expand_dims(input_mask, axis=-1) + attention_mask = ( + causal_padding[:, jnp.newaxis, jnp.newaxis, :].astype(jnp.float32) * + modules.K_MASK) + attention_mask = jnp.squeeze(attention_mask, axis=1) + return attention_mask + + +@chex.dataclass +class _SamplingState: + + # Number of tokens in the prompt. + num_input_tokens: jnp.int32 # [B] + + # Fixed-size buffer for accumulating the output tokens. + token_buffer: jnp.ndarray # [B, L] + + # Model state for conditioning the model on autoregressively. + cache: dict[str, modules.LayerCache] + + +class Sampler: + """Sampler for Gemma transformer.""" + + def __init__( + self, + transformer_config: transformer_lib.TransformerConfig, + vocab: spm.SentencePieceProcessor, + params: params_lib.Params, + cache_size: int, + buffer_size: int, + max_decode_steps: int, + ): + self.transformer = transformer_lib.Transformer( + config=transformer_config) + self.vocab = vocab + self.params = params + self.cache_size = cache_size + self.buffer_size = buffer_size + self.max_decode_steps = max_decode_steps + self._compiled_sample_fn = jax.jit(self._sample_fn) + + def _sample_step(self, params, time_step, + sampler_state: _SamplingState) -> _SamplingState: + """Performs a single sampling step.""" + time_step = jnp.asarray(time_step, dtype=jnp.int32) + last_token = sampler_state.token_buffer[:, time_step] + input_mask = last_token != self.vocab.pad_id() + attention_mask = _compute_attention_masks( + time_step, self.cache_size, input_mask).astype(jnp.float32) + + logits, cache = self.transformer.apply( + {'params': params}, + last_token, + time_step, + sampler_state.cache, + attention_mask, + time_step, + ) + + next_token_candidate = jnp.argmax(logits, axis=-1) # [B, 1] + next_token_candidate = next_token_candidate[:, 0] # [B,] + + next_token_candidate = jnp.where( + time_step < sampler_state.num_input_tokens - 1, + sampler_state.token_buffer[:, time_step + 1], + next_token_candidate, + ) + + token_buffer = sampler_state.token_buffer.at[:, time_step + 1].set( + next_token_candidate) + + return _SamplingState( + num_input_tokens=sampler_state.num_input_tokens, + token_buffer=token_buffer, + cache=cache, + ) + + def init_cache(self, bsz) -> dict[str, modules.LayerCache]: + """Initializes the attention cache for each layer.""" + return { + f'layer_{i}': modules.init_layer_cache( + self.cache_size, + self.transformer.config.num_heads, + self.transformer.config.head_dim, + bsz, + ) + for i in range(self.transformer.config.num_layers) + } + + def init_sample_state(self, + all_input_ids: list[jax.Array]) -> _SamplingState: + """Initializes the sampling state given input prompts.""" + bsz = len(all_input_ids) + num_input_tokens = [len(input_ids) for input_ids in all_input_ids] + + token_buffer = jnp.full( + ( + bsz, + self.buffer_size, + ), + self.vocab.pad_id(), + dtype=jnp.int32, + ) + for i, (input_ids, + num_tokens) in enumerate(zip(all_input_ids, num_input_tokens)): + token_buffer = token_buffer.at[i, :num_tokens].set(input_ids) + + return _SamplingState( + num_input_tokens=jnp.array(num_input_tokens, dtype=jnp.int32), + token_buffer=token_buffer, + 
cache=self.init_cache(bsz), + ) + + def tokenize(self, input_string: str) -> jax.Array: + """Tokenizes the input string.""" + input_ids = self.vocab.EncodeAsIds(input_string) + input_ids = jnp.array([self.vocab.bos_id()] + + jnp.array(input_ids).tolist(), + dtype=jnp.int32) + return input_ids + + def _sample_fn( + self, + params: params_lib.Params, + initial_sampling_state: _SamplingState, + ) -> _SamplingState: + + def sample_with_params(time_step: int, sampler_state: _SamplingState): + return self._sample_step(params, time_step, sampler_state) + + return jax.lax.fori_loop(0, self.max_decode_steps, sample_with_params, + initial_sampling_state) + + def __call__(self, input_strings: list[str] | str) -> list[str]: + """Samples a completion of the input string.""" + if isinstance(input_strings, str): + input_strings = [input_strings] + all_input_ids = [self.tokenize(x) for x in input_strings] + initial_sampling_state = self.init_sample_state(all_input_ids) + + sampling_state = self._compiled_sample_fn(self.params, + initial_sampling_state) + + out_tokens = [ + buffer[num_tokens:num_tokens + self.max_decode_steps] + for buffer, num_tokens in zip(sampling_state.token_buffer, + sampling_state.num_input_tokens) + ] + decoded_outputs = [ + self.vocab.DecodeIds(out_tokens.tolist()) + for out_tokens in out_tokens + ] + return decoded_outputs diff --git a/examples/gemma/utils/transformer.py b/examples/gemma/utils/transformer.py new file mode 100644 index 000000000..9fb89fcbb --- /dev/null +++ b/examples/gemma/utils/transformer.py @@ -0,0 +1,113 @@ +"""Gemma transformer.""" + +import dataclasses + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from . import layers, modules +from . import params as params_lib + +Cache = dict[str, modules.LayerCache] + + +@dataclasses.dataclass +class TransformerConfig: + """Configuration for the Gemma transformer.""" + + num_layers: int + num_embed: int + embed_dim: int + hidden_dim: int + num_heads: int + head_dim: int + num_kv_heads: int + + @classmethod + def from_params(cls, params: params_lib.Params, + num_embed: int) -> 'TransformerConfig': + """Creates a TransformerConfig from loaded parameters.""" + num_layers = (max([ + int(k.split('_')[1]) + for k in params['transformer'].keys() if 'layer_' in k + ]) + 1) + hidden_dim, embed_dim = ( + params['transformer']['layer_0']['mlp']['linear'].shape) + num_heads, head_dim, _ = (params['transformer']['layer_0']['attn'] + ['attn_vec_einsum']['w'].shape) + use_qkv_einsum = 'qkv_einsum' in params['transformer']['layer_0'][ + 'attn'] + if use_qkv_einsum: + num_kv_heads = num_heads + else: + num_kv_heads = params['transformer']['layer_0']['attn'][ + 'kv_einsum']['w'].shape[1] + return cls( + num_layers=num_layers, + num_embed=num_embed, + embed_dim=embed_dim, + hidden_dim=hidden_dim, + num_heads=num_heads, + head_dim=head_dim, + num_kv_heads=num_kv_heads, + ) + + +def init_cache(config: TransformerConfig, cache_size: int, + batch_size: int) -> Cache: + """Initializes a new Transformer cache.""" + return { + f'layer_{i}': modules.init_layer_cache(cache_size, config.num_heads, + config.head_dim, batch_size) + for i in range(config.num_layers) + } + + +class Transformer(nn.Module): + """Gemma transformer.""" + + config: TransformerConfig + + def setup(self): + self.embedder = modules.Embedder( + vocab_size=self.config.num_embed, + embed_dim=self.config.embed_dim, + ) + self.blocks = [ + modules.Block( + name=f'layer_{i}', + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_kv_heads, + 
embed_dim=self.config.embed_dim, + head_dim=self.config.head_dim, + hidden_dim=self.config.hidden_dim, + ) for i in range(self.config.num_layers) + ] + self.final_norm = layers.RMSNorm() + + def __call__( + self, + last_tokens: jax.Array, # [B,] + current_token_position: int, + cache: Cache, + attention_mask: jax.Array, # [B, 1, L] + time_step: int, + ) -> tuple[jax.Array, Cache]: + input_emb = self.embedder.encode(last_tokens) + x = jnp.expand_dims(input_emb, axis=1) # adding temporal dimension + + for i, block in enumerate(self.blocks): + layer_name = f'layer_{i}' + cache[layer_name], x = block( + x, + current_token_position, + cache[layer_name], + attention_mask, + time_step, + ) + + x = self.final_norm(x) + logits = self.embedder.decode(x) + + return logits, cache diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py index a1ca61740..cc10378e0 100644 --- a/examples/llama/convert_checkpoint.py +++ b/examples/llama/convert_checkpoint.py @@ -30,11 +30,6 @@ from tensorrt_llm.models.modeling_utils import PretrainedConfig from tensorrt_llm.runtime.lora_manager import LoraConfig -try: - from transformers import MixtralForCausalLM -except ImportError: - MixtralForCausalLM = None - try: from transformers import LlavaConfig, LlavaForConditionalGeneration except ImportError: diff --git a/examples/mixtral/README.md b/examples/mixtral/README.md index bc640e392..cb4fa945d 100644 --- a/examples/mixtral/README.md +++ b/examples/mixtral/README.md @@ -34,8 +34,7 @@ Here are some examples: python convert_checkpoint.py --model_dir ./Mixtral-8x7B-v0.1 \ --output_dir ./tllm_checkpoint_mixtral_2gpu \ --dtype float16 \ - --world_size 2 \ - --Pp_size 2 + --pp_size 2 trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \ --output_dir ./trt_engines/mixtral/pp2 \ --gemm_plugin float16 @@ -47,7 +46,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \ python convert_checkpoint.py --model_dir ./Mixtral-8x7B-v0.1 \ --output_dir ./tllm_checkpoint_mixtral_2gpu \ --dtype float16 \ - --world_size 2 \ --tp_size 2 trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \ --output_dir ./trt_engines/mixtral/tp2 \ diff --git a/examples/mmlu.py b/examples/mmlu.py index 24cebcfcc..ac560ff48 100644 --- a/examples/mmlu.py +++ b/examples/mmlu.py @@ -248,7 +248,7 @@ def __init__(self, tokenizer, model, model_name, pad_id, end_id, def __call__(self, prompt): # Run the model in batch size 1 and beam size 1 - if self.model_name == 'SpecialForCausalLM': + if self.model_name == 'GemmaForCausalLM': inputs = self.tokenizer.encode(prompt, add_special_tokens=False) inputs = torch.tensor([self.tokenizer.bos_token_id] + inputs) else: diff --git a/examples/quantization/README.md b/examples/quantization/README.md index 0b48fdfe2..6506d91d1 100644 --- a/examples/quantization/README.md +++ b/examples/quantization/README.md @@ -20,8 +20,8 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --shm-size=20g -it pip install -r requirements.txt @@ -29,7 +29,7 @@ pip install -r requirements.txt ## APIs -[`ammo.py`](../../tensorrt_llm/models/quantized/ammo.py) uses the quantization toolkit to calibrate the PyTorch models, and generate a model config, saved as a json (for the model structure) and npz files (for the model weights) that TensorRT-LLM could parse. The model config includes everything needed by TensorRT-LLM to build the TensorRT inference engine, as explained below. 
+[`quantize.py`](./quantize.py) uses the quantization toolkit to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). The checkpoints can be directly used by `trtllm-build` command to build TensorRT-LLM engines. See this [`doc`](../../docs/source/new_workflow.md) for more details on the TensorRT-LLM checkpoint format. > *This quantization step may take a long time to finish and requires large GPU memory. Please use a server grade GPU if a GPU out-of-memory error occurs* @@ -41,33 +41,35 @@ pip install -r requirements.txt PTQ can be achieved with simple calibration on a small set of training or evaluation data (typically 128-512 samples) after converting a regular PyTorch model to a quantized model. ```python +import torch +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM import ammo.torch.quantization as atq -model = AutoModelForCausalLM.from_pretrained("...") +model = AutoModelForCausalLM.from_pretrained(...) # Select the quantization config, for example, FP8 config = atq.FP8_DEFAULT_CFG # Prepare the calibration set and define a forward loop -def forward_loop(): - for data in calib_set: +calib_dataloader = DataLoader(...) +def calibrate_loop(): + for data in calib_dataloader: model(data) # PTQ with in-place replacement to quantized modules with torch.no_grad(): - atq.quantize(model, config, forward_loop) + atq.quantize(model, config, forward_loop=calibrate_loop) ``` ### Export Quantized Model -After the model is quantized, the model config can be stored. The model config files include all the information needed by TensorRT-LLM to generate the deployable engine, including the quantized scaling factors. +After the model is quantized, it can be exported to a TensorRT-LLM checkpoint, which includes -The exported model config are stored as - -- A single JSON file recording the model structure and metadata and -- A group of npz files each recording the model on a single tensor parallel rank (model weights, scaling factors per GPU). +- One json file recording the model structure and metadata, and +- One or several rank weight files storing quantized model weights and scaling factors. The export API is @@ -80,6 +82,8 @@ with torch.inference_mode(): decoder_type, # The type of the model as str, e.g gptj, llama or gptnext. dtype, # The exported weights data type as torch.dtype. export_dir, # The directory where the exported files will be stored. - inference_gpus, # The number of GPUs used in the inference time for tensor parallelism. + inference_tensor_parallel=tp_size, # The tensor parallelism size for inference. + inference_pipeline_parallel=pp_size, # The pipeline parallelism size for inference. + export_tensorrt_llm_config=True, # Enable exporting TensorRT-LLM checkpoint config file. 
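+        # `tp_size` and `pp_size` are assumed to be plain integers chosen for the
+        # target deployment (both 1 for a single-GPU engine), defined alongside the
+        # quantized model earlier in the script.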
) ``` diff --git a/examples/quantization/quantize.py b/examples/quantization/quantize.py index 71a327efa..757811ecd 100644 --- a/examples/quantization/quantize.py +++ b/examples/quantization/quantize.py @@ -110,6 +110,7 @@ "Bloom": "bloom", "ChatGLM": "chatglm", "QWen": "qwen", + "Gemma": "gemma", } @@ -296,7 +297,7 @@ def main(args): torch.save(model.state_dict(), export_path) else: export_npz = (model_type not in [ - 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan' + 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan', 'gemma' ]) export_model_config(model, model_type, @@ -320,19 +321,6 @@ def main(args): with open(f"{export_path}/config.json", "w") as f: json.dump(tensorrt_llm_config, f, indent=4) - # TODO(enweiz): Remove if a newer AMMO version is released - # Workaround for baichuan - if model_type == 'baichuan': - with open(f"{export_path}/config.json", 'r') as f: - tensorrt_llm_config = json.load(f) - if hasattr(model.model, "alibi_mask"): - tensorrt_llm_config["position_embedding_type"] = 'alibi' - else: - tensorrt_llm_config[ - "position_embedding_type"] = 'rope_gpt_neox' - with open(f"{export_path}/config.json", "w") as f: - json.dump(tensorrt_llm_config, f, indent=4) - end_time = time.time() print( "Quantized model exported to {} \nTotal time used {:.2f} s.".format( diff --git a/examples/run.py b/examples/run.py index f7f3a662e..76e452b98 100644 --- a/examples/run.py +++ b/examples/run.py @@ -214,7 +214,7 @@ def parse_input(tokenizer, else: print('Input file format not supported.') raise SystemExit - if model_name == 'SpecialForCausalLM': + if model_name == 'GemmaForCausalLM': batch_input_ids[0] = [tokenizer.bos_token_id] + batch_input_ids[0] if num_prepend_vtokens: diff --git a/examples/skywork/README.md b/examples/skywork/README.md index c3a3826bd..3655ca0e6 100644 --- a/examples/skywork/README.md +++ b/examples/skywork/README.md @@ -57,9 +57,9 @@ python3 convert_checkpoint.py --model_dir ./Skywork-13B-base \ ```bash # fp16 trtllm-build --checkpoint_dir ./skywork-13b-base/trt_ckpt/fp16 \ - --use_gemm_plugin float16 \ - --use_gpt_attention_plugin float16 \ - --enable_context_fmha \ + --gemm_plugin float16 \ + --gpt_attention_plugin float16 \ + --context_fmha enable \ --max_batch_size 32 \ --max_input_len 512 \ --max_output_len 512 \ @@ -67,9 +67,9 @@ trtllm-build --checkpoint_dir ./skywork-13b-base/trt_ckpt/fp16 \ # bf16 trtllm-build --checkpoint_dir ./skywork-13b-base/trt_ckpt/bf16 \ - --use_gemm_plugin bfloat16 \ - --use_gpt_attention_plugin bfloat16 \ - --enable_context_fmha \ + --gemm_plugin bfloat16 \ + --gpt_attention_plugin bfloat16 \ + --context_fmha enable \ --max_batch_size 32 \ --max_input_len 512 \ --max_output_len 512 \ @@ -85,23 +85,23 @@ After building TRT engines, we can use them to perform various tasks. 
TensorRT-L python ../summarize.py --hf_model_dir ./Skywork-13B-base \ --test_hf \ --batch_size 32 \ - --max_input_length 512 + --max_input_length 512 \ --output_len 512 \ --test_trt_llm \ --engine_dir ./skywork-13b-base/trt_engine/fp16 \ --data_type fp16 \ - -check_accuracy \ + --check_accuracy \ --tensorrt_llm_rouge1_threshold=14 # bf16 python ../summarize.py --hf_model_dir ./Skywork-13B-base \ --test_hf \ --batch_size 32 \ - --max_input_length 512 + --max_input_length 512 \ --output_len 512 \ --test_trt_llm \ --engine_dir ./skywork-13b-base/trt_engine/bf16 \ --data_type bf16 \ - -check_accuracy \ + --check_accuracy \ --tensorrt_llm_rouge1_threshold=14 ``` diff --git a/examples/summarize.py b/examples/summarize.py index 560d2440b..9671af4db 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -157,12 +157,13 @@ def _prepare_inputs(batch_input_texts, max_input_length=test_token_num, ) input_ids = torch.tensor(input_id_list) - elif model_name == 'SpecialForCausalLM': + elif model_name == 'GemmaForCausalLM': input_ids = tokenizer.encode( curr_text, add_special_tokens=add_special_tokens, truncation=True, - max_length=test_token_num) + max_length=test_token_num - + 1) # minus 1 to add bos_token_id input_ids = torch.tensor([tokenizer.bos_token_id] + input_ids) else: input_ids = tokenizer.encode( diff --git a/examples/utils.py b/examples/utils.py index dfb3d812d..7ab587314 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -21,25 +21,25 @@ import tensorrt_llm -# TODO(enweiz): Update for refactered models +# TODO(enweiz): Update for refactored models DEFAULT_HF_MODEL_DIRS = { - 'baichuan': 'baichuan-inc/Baichuan-13B-Chat', + 'BaichuanForCausalLM': 'baichuan-inc/Baichuan-13B-Chat', 'BloomForCausalLM': 'bigscience/bloom-560m', 'ChatGLMForCausalLM': 'THUDM/chatglm3-6b', 'FalconForCausalLM': 'tiiuae/falcon-rw-1b', 'gpt': 'gpt2-medium', 'GPTJForCausalLM': 'EleutherAI/gpt-j-6b', 'GPTNeoXForCausalLM': 'EleutherAI/gpt-neox-20b', - 'internlm': 'internlm/internlm-chat-7b', - 'llama': 'meta-llama/Llama-2-7b-hf', - 'mpt': 'mosaicml/mpt-7b', + 'InternLMForCausalLM': 'internlm/internlm-chat-7b', + 'LlamaForCausalLM': 'meta-llama/Llama-2-7b-hf', + 'MPTForCausalLM': 'mosaicml/mpt-7b', 'PhiForCausalLM': 'microsoft/phi-2', 'OPTForCausalLM': 'facebook/opt-350m', 'qwen': 'Qwen/Qwen-7B', } DEFAULT_PROMPT_TEMPLATES = { - 'internlm': + 'InternLMForCausalLM': "<|User|>:{input_text}\n<|Bot|>:", 'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n", @@ -110,7 +110,7 @@ def load_tokenizer(tokenizer_dir: Optional[str] = None, elif model_name == 'ChatGLMForCausalLM' and model_version == 'glm': pad_id = tokenizer.pad_token_id end_id = tokenizer.eop_token_id - elif model_name == 'SpecialForCausalLM': + elif model_name == 'GemmaForCausalLM': tokenizer.eos_token_id = tokenizer.sp_model.eos_id() tokenizer.bos_token_id = tokenizer.sp_model.bos_id() pad_id = tokenizer.pad_token_id diff --git a/examples/whisper/README.md b/examples/whisper/README.md index 00d37c8b6..63fa83ec7 100755 --- a/examples/whisper/README.md +++ b/examples/whisper/README.md @@ -37,10 +37,10 @@ TensorRT-LLM Whisper builds TensorRT engine(s) from the pytorch checkpoint. pip install -r requirements.txt # Build the large-v3 model using a single GPU with plugins. 
-python3 build.py --output_dir whisper_large_v3 --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin +python3 build.py --output_dir whisper_large_v3 --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --enable_context_fmha # Build the large-v3 model using a single GPU with plugins and weight-only quantization. -python3 build.py --output_dir whisper_large_weight_only --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --use_weight_only +python3 build.py --output_dir whisper_large_weight_only --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --enable_context_fmha --use_weight_only ``` ### Run diff --git a/examples/whisper/build.py b/examples/whisper/build.py index a963c1eb3..ab9602939 100644 --- a/examples/whisper/build.py +++ b/examples/whisper/build.py @@ -26,6 +26,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.models import quantize_model from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode MODEL_ENCODER_NAME = "whisper_encoder" @@ -116,6 +117,9 @@ def parse_arguments(): action="store_true", help='Quantize weights for the various GEMMs to INT4/INT8.' 'See --weight_only_precision to set the precision') + parser.add_argument('--enable_context_fmha', + default=False, + action='store_true') parser.add_argument( '--weight_only_precision', const='int8', @@ -203,9 +207,11 @@ def build_encoder(model, args): if args.use_weight_only: tensorrt_llm_whisper_encoder = quantize_model( tensorrt_llm_whisper_encoder, args.quant_mode) + use_gemm_woq_plugin = args.use_gemm_plugin and args.use_weight_only load_encoder_weight(tensorrt_llm_whisper_encoder, model_metadata, - model_params, model_metadata['n_audio_layer']) + model_params, model_metadata['n_audio_layer'], + use_gemm_woq_plugin) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -215,6 +221,8 @@ def build_encoder(model, args): if args.use_bert_attention_plugin: network.plugin_config.set_bert_attention_plugin( dtype=args.use_bert_attention_plugin) + if args.enable_context_fmha: + network.plugin_config.set_context_fmha(ContextFMHAType.enabled) if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() if args.use_weight_only: @@ -310,11 +318,10 @@ def build_decoder(model, args): if args.use_weight_only: tensorrt_llm_whisper_decoder = quantize_model( tensorrt_llm_whisper_decoder, args.quant_mode) + use_gemm_woq_plugin = args.use_gemm_plugin and args.use_weight_only - load_decoder_weight( - tensorrt_llm_whisper_decoder, - model_params, - ) + load_decoder_weight(tensorrt_llm_whisper_decoder, model_params, + use_gemm_woq_plugin) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -324,8 +331,13 @@ def build_decoder(model, args): if args.use_gpt_attention_plugin: network.plugin_config.set_gpt_attention_plugin( dtype=args.use_gpt_attention_plugin) + if args.enable_context_fmha: + network.plugin_config.set_context_fmha(ContextFMHAType.enabled) if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() + if args.use_weight_only: + network.plugin_config.set_weight_only_quant_matmul_plugin( + dtype=args.dtype) with net_guard(network): inputs = tensorrt_llm_whisper_decoder.prepare_inputs( diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index 5bb070e2a..d5dd1dcab 100644 --- a/examples/whisper/requirements.txt +++ 
b/examples/whisper/requirements.txt @@ -3,3 +3,4 @@ datasets kaldialign openai-whisper soundfile +safetensors diff --git a/examples/whisper/run.py b/examples/whisper/run.py index 45a763786..f9b199151 100644 --- a/examples/whisper/run.py +++ b/examples/whisper/run.py @@ -19,13 +19,14 @@ from collections import OrderedDict from pathlib import Path +import numpy as np import torch from datasets import load_dataset from tokenizer import get_tokenizer from torch.utils.data import DataLoader from whisper.normalizers import EnglishTextNormalizer -from whisper_utils import (log_mel_spectrogram, store_transcripts, - write_error_stats) +from whisper_utils import (N_SAMPLES, log_mel_spectrogram, pad_or_trim, + store_transcripts, write_error_stats) import tensorrt_llm import tensorrt_llm.logger as logger @@ -291,12 +292,18 @@ def decode_wav_file( def collate_wrapper(batch): - speeches, labels, ids = [], [], [] + speeches, durations, labels, ids = [], [], [], [] for item in batch: - speeches.append(item["audio"]["array"]) + speech = item["audio"]["array"] + duration = speech.shape[-1] + speech = pad_or_trim(speech, N_SAMPLES) + speech = speech.astype(np.float32) + speech = torch.from_numpy(speech) + speeches.append(speech) + durations.append(duration) labels.append(item["text"]) ids.append(item["id"]) - return speeches, labels, ids + return speeches, durations, labels, ids def decode_dataset( @@ -319,9 +326,12 @@ def decode_dataset( results = [] total_duration = 0 for batch in data_loader: - waveforms, texts, ids = batch - total_duration += sum([wave.shape[0] - for wave in waveforms]) / sample_rate + waveforms, durations, texts, ids = batch + total_duration += sum(durations) / sample_rate + + for wave in waveforms: + assert wave.is_pinned() + features = [ log_mel_spectrogram(wave, model.n_mels, diff --git a/examples/whisper/weight.py b/examples/whisper/weight.py index bd72adeed..12100632e 100644 --- a/examples/whisper/weight.py +++ b/examples/whisper/weight.py @@ -47,8 +47,11 @@ def trans_weight(weight): return np.ascontiguousarray(weight) -def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict, - model_params: dict, n_layer: int): +def load_encoder_weight(tensorrt_llm_whisper, + model_metadata: dict, + model_params: dict, + n_layer: int, + use_gemm_woq_plugin=True): tensorrt_llm.logger.info('Loading encoder weights from PT...') quant_mode = getattr(tensorrt_llm_whisper, 'quant_mode', QuantMode(0)) @@ -59,6 +62,8 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict, use_weight_only = quant_mode.is_weight_only() + param_dtype = 'float16' + tensorrt_llm_whisper.positional_embedding.value = sinusoids( model_metadata['n_audio_ctx'], model_metadata['n_audio_state']).numpy() @@ -92,7 +97,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict, processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = tensorrt_llm_whisper.encoder_layers[ i].attention.qkv.per_channel_scale scales.value = torch_weight_scales.numpy() @@ -120,7 +130,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict, processed_torch_weights, torch_weight_scales = 
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = tensorrt_llm_whisper.encoder_layers[ i].attention.dense.per_channel_scale scales.value = torch_weight_scales.numpy() @@ -147,7 +162,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict, processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = tensorrt_llm_whisper.encoder_layers[ i].mlp.fc.per_channel_scale scales.value = torch_weight_scales.numpy() @@ -164,7 +184,12 @@ def load_encoder_weight(tensorrt_llm_whisper, model_metadata: dict, processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = processed_torch_weights.numpy() + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = tensorrt_llm_whisper.encoder_layers[ i].mlp.proj.per_channel_scale scales.value = torch_weight_scales.numpy() @@ -186,10 +211,9 @@ def fuse_qkv(q, k, v): return qkv_weight -def load_decoder_weight( - tllm_model, - model_params: dict, -): +def load_decoder_weight(tllm_model, + model_params: dict, + use_gemm_woq_plugin=True): tensorrt_llm.logger.info('Loading decoder weights from PT...') quant_mode = getattr(tllm_model, 'quant_mode', QuantMode(0)) @@ -201,6 +225,8 @@ def load_decoder_weight( plugin_weight_only_quant_type = torch.quint4x2 use_weight_only = quant_mode.is_weight_only() + use_int8_kv_cache = quant_mode.has_int8_kv_cache() + tllm_model.embedding.vocab_embedding.weight.value = trans_weight( model_params['decoder.token_embedding.weight'].numpy()) tllm_model.lm_head.weight.value = trans_weight( @@ -225,8 +251,12 @@ def load_decoder_weight( processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = torch.tensor(np.ascontiguousarray(t.transpose( - 1, 0))).numpy().astype(str_dtype_to_np(param_dtype)) + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = layer.self_attention.qkv.per_channel_scale scales.value = torch_weight_scales.numpy() else: @@ -241,8 +271,12 @@ def load_decoder_weight( processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = torch.tensor(np.ascontiguousarray(t.transpose( - 1, 
0))).numpy().astype(str_dtype_to_np(param_dtype)) + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = layer.self_attention.dense.per_channel_scale scales.value = torch_weight_scales.numpy() else: @@ -263,6 +297,12 @@ def load_decoder_weight( model_params['decoder.blocks.' + str(i) + '.attn.out.bias'].numpy()) + if use_int8_kv_cache: + t = fromfile( + "quantize/1-gpu", 'model.decoder.blocks.' + str(i) + + '.attn.query_key_value.scale_y_quant_orig.bin', [1], np.float32) + layer.self_attention.kv_cache_scaling_factor.value = t + layer.self_attention_layernorm.weight.value = trans_weight( model_params['decoder.blocks.' + str(i) + '.attn_ln.weight'].numpy()) @@ -284,17 +324,17 @@ def load_decoder_weight( processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = torch.tensor(np.ascontiguousarray(t.transpose( - 1, 0))).numpy().astype(str_dtype_to_np(param_dtype)) + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = layer.cross_attention.qkv.per_channel_scale scales.value = torch_weight_scales.numpy() else: dst.value = t - layer.cross_attention.dense.weight.value = trans_weight( - model_params['decoder.blocks.' + str(i) + - '.cross_attn.out.weight'].numpy()) - t = trans_weight(model_params['decoder.blocks.' + str(i) + '.cross_attn.out.weight'].numpy()) @@ -304,8 +344,12 @@ def load_decoder_weight( processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = torch.tensor(np.ascontiguousarray(t.transpose( - 1, 0))).numpy().astype(str_dtype_to_np(param_dtype)) + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = layer.cross_attention.dense.per_channel_scale scales.value = torch_weight_scales.numpy() else: @@ -329,6 +373,12 @@ def load_decoder_weight( model_params['decoder.blocks.' + str(i) + '.cross_attn.out.bias'].numpy()) + if use_int8_kv_cache: + t = fromfile( + "quantize/1-gpu", 'model.decoder.blocks.' + str(i) + + '.attn.query_key_value.scale_y_quant_orig.bin', [1], np.float32) + layer.self_attention.kv_cache_scaling_factor.value = t + layer.cross_attention_layernorm.weight.value = trans_weight( model_params['decoder.blocks.' 
+ str(i) + '.cross_attn_ln.weight'].numpy()) @@ -345,8 +395,12 @@ def load_decoder_weight( processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = torch.tensor(np.ascontiguousarray(t.transpose( - 1, 0))).numpy().astype(str_dtype_to_np(param_dtype)) + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = layer.mlp.fc.per_channel_scale scales.value = torch_weight_scales.numpy() else: @@ -361,8 +415,12 @@ def load_decoder_weight( processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( torch.tensor(np.ascontiguousarray(t.transpose(1, 0))), plugin_weight_only_quant_type) - dst.value = torch.tensor(np.ascontiguousarray(t.transpose( - 1, 0))).numpy().astype(str_dtype_to_np(param_dtype)) + if not use_gemm_woq_plugin: + dst.value = torch.tensor( + np.ascontiguousarray(t.transpose(1, 0))).numpy().astype( + str_dtype_to_np(param_dtype)) + else: + dst.value = processed_torch_weights.numpy() scales = layer.mlp.proj.per_channel_scale scales.value = torch_weight_scales.numpy() else: diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 6e4f09ae1..63c09a398 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -27,6 +27,8 @@ def _add_trt_llm_dll_directory(): _add_trt_llm_dll_directory() +import sys + import tensorrt_llm.functional as functional import tensorrt_llm.models as models import tensorrt_llm.quantization as quantization @@ -80,4 +82,6 @@ def _add_trt_llm_dll_directory(): _init(log_level="error") -print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}", end='') +print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}") + +sys.stdout.flush() diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 1ff7df85e..a8eb9a29f 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -19,6 +19,7 @@ from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder from .falcon.model import FalconForCausalLM, FalconModel +from .gemma.model import GemmaForCausalLM from .gpt.model import GPTLMHeadModel, GPTModel from .gptj.model import GPTJForCausalLM, GPTJModel from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel @@ -68,6 +69,7 @@ 'MPTForCausalLM', 'MPTModel', 'SkyworkForCausalLM', + 'GemmaForCausalLM', ] MODEL_MAP = { @@ -87,4 +89,5 @@ 'MedusaForCausalLM': MedusaForCausalLm, 'BaichuanForCausalLM': BaichuanForCausalLM, 'SkyworkForCausalLM': LLaMAForCausalLM, + 'GemmaForCausalLM': GemmaForCausalLM, } diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index 178c017ad..4b32d2a6f 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -1107,7 +1107,7 @@ def prepare_inputs( # No enable_two_optimization_profiles support yet encoder_input_len_range = [ - 0, (max_encoder_input_len + 1) // 2, max_encoder_input_len + 1, (max_encoder_input_len + 1) // 2, max_encoder_input_len ] past_key_value = [] sequence_length = None diff --git a/tensorrt_llm/models/gemma/__init__.py b/tensorrt_llm/models/gemma/__init__.py new file mode 100644 index 000000000..71bf6d298 --- /dev/null +++ b/tensorrt_llm/models/gemma/__init__.py @@ 
-0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py new file mode 100644 index 000000000..5b6e3b12e --- /dev/null +++ b/tensorrt_llm/models/gemma/model.py @@ -0,0 +1,456 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +from pathlib import Path +from typing import Optional + +from transformers import AutoConfig + +from tensorrt_llm import profiler +from tensorrt_llm._utils import pad_vocab_size +from tensorrt_llm.functional import RotaryScalingType, Tensor, recv, send +from tensorrt_llm.layers import (MOE, Attention, AttentionMaskType, + ColumnLinear, Embedding, FusedGatedMLP, + GatedMLP, MoeConfig, PositionEmbeddingType, + PromptTuningEmbedding, RmsNorm) +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import (DecoderLayerList, + DecoderModelForCausalLM) +from tensorrt_llm.module import Module +from tensorrt_llm.plugin import init_all_reduce_helper +from tensorrt_llm.quantization import QuantMode +from tensorrt_llm.runtime.lora_manager import LoraConfig +from tensorrt_llm.top_model_mixin import TopModelMixin + +from .weight import load_from_fp8_llama, load_from_hf_llama + + +class GemmaDecoderLayer(Module): + + def __init__(self, config, layer_idx): + super().__init__() + self.layer_idx = layer_idx + self.config = config + + self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + self.attention = Attention( + config.hidden_size, + config.num_attention_heads, + config.num_key_value_heads, + attention_head_size=config.head_size, + max_position_embeddings=config.max_position_embeddings, + dtype=config.dtype, + attention_mask_type=AttentionMaskType.causal, + bias=config.attn_bias, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + rotary_embedding_base=config.rotary_base, + rotary_embedding_scaling=config.rotary_scaling, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + quant_mode=config.quant_mode, + enable_pos_shift=config.enable_pos_shift, + dense_context_fmha=config.dense_context_fmha, + ) + # max_lora_rank=config.max_lora_rank) + + mlp_hidden_size = config.hidden_size 
* 4 if config.intermediate_size is None else config.intermediate_size + + ClsMLP = GatedMLP + mlp_kwargs = {} + if config.moe_num_experts > 1: + ClsMLP = MOE + mlp_kwargs = { + "moe_config": + MoeConfig( + config.moe_num_experts, + config.moe_top_k, + config.moe_tp_mode, + config.moe_normalization_mode, + ), + "tp_rank": + config.mapping.tp_rank, + } + elif config.use_fused_mlp: + ClsMLP = FusedGatedMLP + + self.mlp = ClsMLP( + hidden_size=config.hidden_size, + ffn_hidden_size=mlp_hidden_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + bias=config.mlp_bias, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + quant_mode=config.quant_mode, + # max_lora_rank=config.max_lora_rank, + **mlp_kwargs) + self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + def forward( + self, + hidden_states, + attention_mask=None, + medusa_packed_mask=None, # For Medusa support + medusa_position_offsets=None, + use_cache=False, + kv_cache_params=None, + attention_params=None, + lora_layer_params=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + attention_output = self.attention( + hidden_states, + attention_mask=attention_mask, + medusa_packed_mask=medusa_packed_mask, # For Medusa support + medusa_position_offsets=medusa_position_offsets, + use_cache=use_cache, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_layer_params=lora_layer_params) + + if use_cache: + attention_output, presents = attention_output + + hidden_states = residual + attention_output + + residual = hidden_states + hidden_states = self.post_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states, + lora_layer_params=lora_layer_params) + + hidden_states = residual + hidden_states + if use_cache: + return (hidden_states, presents) + return hidden_states + + +class GemmaModel(Module): + + def __init__(self, config) -> None: + super().__init__() + init_all_reduce_helper() + + self.mapping = config.mapping + self.use_prompt_tuning = config.use_prompt_tuning + EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding + if self.mapping.is_first_pp_rank(): + self.vocab_embedding = EmbeddingCls( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype, + tp_size=self.mapping.tp_size + if config.use_parallel_embedding else 1, + tp_group=self.mapping.tp_group + if config.use_parallel_embedding else None, + sharding_dim=config.embedding_sharding_dim, + tp_rank=self.mapping.tp_rank, + ) + + self.layers = DecoderLayerList(GemmaDecoderLayer, config) + + if self.mapping.is_last_pp_rank(): + self.ln_f = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + def forward( + self, + input_ids, + position_ids=None, + use_cache=False, + attention_mask=None, + medusa_position_offsets=None, # For Medusa support + medusa_packed_mask=None, # For Medusa support + kv_cache_params=None, + attention_params=None, + hidden_states=None, + prompt_embedding_table: Optional[Tensor] = None, + prompt_tasks: Optional[Tensor] = None, + prompt_vocab_size: Optional[Tensor] = None, + lora_params=None): + + kv_cache_params.fill_none_tensor_list(len(self.layers)) + + if use_cache: + presents = [] + + ptuning_args = [] + # if self.use_prompt_tuning: + # ptuning_args = [ + # prompt_embedding_table, prompt_tasks, prompt_vocab_size + # ] + + if self.mapping.is_first_pp_rank(): + hidden_states = 
self.vocab_embedding(input_ids, *ptuning_args) + else: + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + hidden_states = self.layers.forward( + hidden_states, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + # all_reduce_workspace=all_reduce_workspace, + lora_params=lora_params, + # medusa_position_offsets=medusa_position_offsets, + # medusa_packed_mask=medusa_packed_mask, + ) + + if use_cache: + hidden_states, presents = hidden_states + + if self.mapping.is_last_pp_rank(): + hidden_states = self.ln_f(hidden_states) + else: + hidden_states = send(hidden_states, self.mapping.next_pp_rank()) + + if use_cache: + return (hidden_states, tuple(presents)) + return hidden_states + + +class GemmaForCausalLM(DecoderModelForCausalLM, TopModelMixin): + + def __init__(self, config): + + self.check_config(config) + transformer = GemmaModel(config) + + vocab_size_padded = pad_vocab_size(config.vocab_size, + config.mapping.tp_size) + if config.mapping.is_last_pp_rank(): + lm_head = ColumnLinear(config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True) + else: + lm_head = None + self.quant_mode = config.quant_mode + self.mapping = config.mapping + + super().__init__(config, transformer, lm_head) + + @classmethod + def from_hugging_face(cls, + hf_model_dir, + dtype='float16', + mapping: Optional[Mapping] = None, + quant_mode: Optional[QuantMode] = None, + **kwargs): + import transformers + from transformers import LlamaConfig + + from ...models.modeling_utils import PretrainedConfig + cfg = LlamaConfig.from_pretrained(hf_model_dir) + + num_kv_heads = cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads") \ + else cfg.num_attention_heads + if mapping is None: + mapping = Mapping() + if quant_mode is None: + quant_mode = QuantMode(0) + + cfg.mapping = mapping + + cfg.dtype = dtype + cfg.quant_mode = quant_mode + moe_config = kwargs.get("moe_config", MoeConfig()) + + cfg.norm_epsilon = cfg.rms_norm_eps + + config = { + 'architecture': cfg.architectures[0], + 'dtype': cfg.dtype, + 'logits_dtype': 'float32', + 'num_hidden_layers': cfg.num_hidden_layers, + 'num_attention_heads': cfg.num_attention_heads, + 'hidden_size': cfg.hidden_size, + 'intermediate_size': cfg.intermediate_size, + 'num_key_value_heads': cfg.num_key_value_heads, + 'vocab_size': cfg.vocab_size, + 'position_embedding_type': 'rope_gpt_neox', + 'max_position_embeddings': cfg.max_position_embeddings, + 'hidden_act': cfg.hidden_act, + 'rotary_base': getattr(cfg, 'rotary_base', 10000.0), + 'rotary_scaling': getattr(cfg, 'rotary_scaling', None), + 'norm_epsilon': cfg.rms_norm_eps, + 'quantization': quant_mode.to_dict(), + 'mapping': { + 'world_size': mapping.world_size, + 'tp_size': mapping.world_size, + }, + 'use_parallel_embedding': kwargs.get("use_parallel_embedding", + False), + 'embedding_sharding_dim': kwargs.get("embedding_sharding_dim", 0), + 'use_prompt_tuning': kwargs.get("use_prompt_tuning", False), + 'moe_num_experts': moe_config.num_experts, + 'moe_top_k': moe_config.top_k, + 'moe_tp_mode': moe_config.tp_mode, + 'moe_normalization_mode': moe_config.normalization_mode, + 'use_fused_mlp': kwargs.get("use_fused_mlp", False), + 'enable_pos_shift': kwargs.get("enable_pos_shift", False), + 'dense_context_fmha': kwargs.get("dense_context_fmha", False), + } + if quant_mode.is_int4_weight_only_per_group(): + 
config['quantization'].update({ + 'zero': False, + 'pre_quant_scale': True, + 'exclude_modules': [], + }) + + tllm_llama = GemmaForCausalLM(PretrainedConfig.from_dict(config)) + q_weights = {} + if quant_mode.has_any_quant(): + q_weights = tllm_llama._quantize(hf_model_dir, dtype, cfg, **kwargs) + + # For debug purpose, skip weights loading to be faster + if kwargs.get("skip_loading_weights", False): + return tllm_llama + + # TODO: support mixtral + + # weights already loaded in _quantize for int4 weight only + if not quant_mode.is_int4_weight_only_per_group(): + hf_model = transformers.LlamaForCausalLM + profiler.start("Loading weights from HF") + hf_llama = hf_model.from_pretrained( + hf_model_dir, + device_map={ + "model": "cpu", + "lm_head": "cpu", + "embed_tokens": "cpu", + "layers": "cpu", + "norm": "cpu", + }, # Load to CPU memory + torch_dtype='auto', + ) + + weights = load_from_hf_llama( + tllm_llama, + hf_llama, + mapping=mapping, + dtype=dtype, + # TODO: these shall be outside from_hugging_face too. + use_gemm_woq_plugin=kwargs.get("use_gemm_woq_plugin", False), + lora_config=kwargs.get("lora_config", LoraConfig()), + ) + profiler.stop("Loading weights from HF") + del hf_llama + weights.update(q_weights) + tllm_llama.load(weights) + else: + tllm_llama.load(q_weights) + return tllm_llama + + def _quantize(self, hf_model_dir, dtype, cfg, **kwargs): + '''Given the quant_mode set in the Module object, read from given hf model + call AMMO to generate quantization scales, and set the scales back the module parameters. + ''' + # use self destructed temporary path if kwargs[quantization_cache_dir] is not specified + # sometimes the quantization checkpoint path needs to be saved for debug purpose + quantized_temp_dir = tempfile.TemporaryDirectory("llama-quantized") + quantized_checkpoint_path = kwargs.get("quantization_cache_dir", + quantized_temp_dir.name) + quantize_lm_head = kwargs.get("quantize_lm_head", False) + quant_mode = cfg.quant_mode + ammo_qformat = None + calib_size = None + if quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache(): + ammo_qformat = 'fp8' + calib_size = 512 + # TODO: how to distinguish from quant_mode about int4_awq or int4_gptq? 
+ elif quant_mode.is_int4_weight_only_per_group(): + ammo_qformat = 'int4_awq' + calib_size = 32 + assert ammo_qformat is not None + + # local import to avoid pytest issue when importing AMMO and transformers lib + from .quantize import quantize_llama_and_export + quantize_llama_and_export(hf_model_dir, + quantized_checkpoint_path, + ammo_qformat, + dtype, + calib_size=calib_size, + quantize_lm_head=quantize_lm_head) + + ckpt = Path(quantized_checkpoint_path) / "llama_tp1_rank0.npz" + assert ckpt.exists(), f"The expecting checkpoint path {ckpt} does not exist" \ + "it's likely quantization failed, pls check error logs" + hf_config = AutoConfig.from_pretrained(hf_model_dir, + trust_remote_code=True) + if ammo_qformat == 'fp8': + return load_from_fp8_llama( + str(ckpt), + hf_config, + cfg.mapping, + fp8_kv_cache=quant_mode.has_fp8_kv_cache()) + else: + return load_from_awq_llama(str(ckpt), + hf_config, + cfg.mapping, + dtype=dtype) + + # llama specific setters, user shall has the chance to change the module attributes after + # from_hugging_face factory method created the model when these attributes is not included in the huggingface checkpoint + + def rotary_base(self, val): + for decoder in self.layers: + decoder.attention.rotary_embedding_base = val + return self + + def rotary_scaling(self, scaling_type, factor): + # TODO: what if there are some other behaviors triggered by the these changes? + # should implement these assignment as setters of the Attention Module + assert scaling_type in ("linear", "dynamic"), f"Got {scaling_type}" + assert factor > 1.0, f"Got {factor}" + for decoder in self.layers: + decoder.attention.rotary_embedding_scale_type = RotaryScalingType.linear if scaling_type == "linear" else RotaryScalingType.dynamic + decoder.attention.rotary_embedding_scale = factor + return self + + def default_plugin_config(self, **kwargs): + plugin_config = super().default_plugin_config(**kwargs) + if self.quant_mode.is_int4_weight_only_per_group(): + plugin_config.set_weight_only_groupwise_quant_matmul_plugin() + return plugin_config + + def check_config(self, config): + config.set_if_not_exist('use_parallel_embedding', False) + config.set_if_not_exist('embedding_sharding_dim', 0) + config.set_if_not_exist('mlp_bias', False) + config.set_if_not_exist('attn_bias', False) + config.set_if_not_exist('rotary_base', 10000.0) + config.set_if_not_exist('rotary_scaling', None) + config.set_if_not_exist('enable_pos_shift', False) + config.set_if_not_exist('dense_context_fmha', False) + config.set_if_not_exist('use_fused_mlp', False) + config.set_if_not_exist('moe_num_experts', 0) + config.set_if_not_exist('moe_top_k', 0) + config.set_if_not_exist('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL) + config.set_if_not_exist( + 'moe_normalization_mode', + MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE) diff --git a/tensorrt_llm/models/gemma/smoothquant.py b/tensorrt_llm/models/gemma/smoothquant.py new file mode 100644 index 000000000..385ffebd0 --- /dev/null +++ b/tensorrt_llm/models/gemma/smoothquant.py @@ -0,0 +1,1072 @@ +import copy +import functools +import math +import time +import warnings +from collections import defaultdict +from typing import Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +from transformers import Cache, LlamaConfig, LlamaForCausalLM +from transformers.models.llama.modeling_llama import (LlamaAttention, + LlamaDecoderLayer, + apply_rotary_pos_emb, + repeat_kv) +from 
transformers.pytorch_utils import Conv1D + + +def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): + """ + This function has two purposes: + - compute quantized weights, scaled either per-tensor or per-column + - compute scaling factors + + Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. + CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. + CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. + + Here is the list of what we need (T means per-tensor, C per-column): + - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) + - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) + - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) + - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) + to quant range (int8) (used for CUBLAS) (T, C) + + Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, + but then the model would change depending on the number of GPUs used. + + For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it + as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. + For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns. + """ + + # compute weight scaling factors for fp->int8 and int8->fp + if is_qkv and not multi_query_mode: + scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( + dim=-1, keepdims=True)[0].cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, + -1).cpu().numpy() + elif is_qkv and multi_query_mode: + hidden_dim = weights.shape[0] + local_dim = act_range["w"].shape[0] + kv_dim = (local_dim - hidden_dim) // 2 + scale_w_q = act_range["w"][0:hidden_dim] + scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim] + scale_w_v = act_range["w"][-kv_dim:] + + scale_w_qkv_t = torch.concat([ + scale_w_q.max(dim=0, keepdim=True)[0], + scale_w_k.max(dim=0, keepdim=True)[0], + scale_w_v.max(dim=0, keepdim=True)[0] + ]) + + scale_w_orig_quant_t = 127. / scale_w_qkv_t.cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + else: + scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() + scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t + scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c + + scale_w_orig_quant_c = scale_w_orig_quant_c.astype(np.float32) + scale_w_orig_quant_t = scale_w_orig_quant_t.astype(np.float32) + # compute the rest of needed scaling factors + scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) + scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) + scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) 
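+ # Worked example with hypothetical numbers: if act_range["w"].max() == 0.5, then
+ # scale_w_orig_quant_t == 127 / 0.5 == 254.0, a weight of 0.25 quantizes to
+ # round(0.25 * 254.0) == 64, and scale_w_quant_orig_t == 1 / 254.0 dequantizes it
+ # back to ~0.252 (the per-column factors follow the same scheme, one value per column).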
+ scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * + scale_w_orig_quant_t) + scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * + scale_w_orig_quant_c) + if is_qkv and not multi_query_mode: + scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, + scale_w_orig_quant_c.shape) + scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, + scale_w_orig_quant_c.shape) + if is_qkv and multi_query_mode: + scale_q_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[0], + scale_w_q.shape) + scale_k_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[1], + scale_w_k.shape) + scale_v_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[2], + scale_w_v.shape) + scale_y_accum_quant_t = np.concatenate( + [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]) + scale_w_quant_orig_t = np.concatenate([ + np.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape), + np.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape), + np.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape) + ]) + + to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) + + if is_qkv and multi_query_mode: + weight_int8 = to_i8(weights / scale_w_quant_orig_t) + else: + weight_int8 = to_i8(weights * scale_w_orig_quant_t) + + return { + "weight.int8": weight_int8, + "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), + "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), + "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), + "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), + "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), + "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), + "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), + } + + +@torch.no_grad() +def apply_smoothing(scales, + gemm_weights, + layernorm_weights=None, + layernorm_bias=None, + dtype=torch.float32, + layernorm_1p=False): + if not isinstance(gemm_weights, list): + gemm_weights = [gemm_weights] + + if layernorm_weights is not None: + assert layernorm_weights.numel() == scales.numel() + layernorm_weights.div_(scales).to(dtype) + if layernorm_bias is not None: + assert layernorm_bias.numel() == scales.numel() + layernorm_bias.div_(scales).to(dtype) + if layernorm_1p: + layernorm_weights += (1 / scales) - 1 + + for gemm in gemm_weights: + gemm.mul_(scales.view(1, -1)).to(dtype) + + +@torch.no_grad() +def smooth_gemm(gemm_weights, + act_scales, + layernorm_weights=None, + layernorm_bias=None, + alpha=0.5, + weight_scales=None): + if not isinstance(gemm_weights, list): + gemm_weights = [gemm_weights] + orig_dtype = gemm_weights[0].dtype + + for gemm in gemm_weights: + # gemm_weights are expected to be transposed + assert gemm.shape[1] == act_scales.numel() + + if weight_scales is None: + weight_scales = torch.cat( + [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], + dim=0) + weight_scales = weight_scales.max(dim=0)[0] + weight_scales.to(float).clamp(min=1e-5) + scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / + weight_scales.pow(1 - alpha)).clamp(min=1e-5) + + apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, + orig_dtype) + + return scales + + +@torch.no_grad() +def capture_activation_range(model, + tokenizer, + dataset, + num_samples=1, + seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None}) + + # tokenizer.pad_token = tokenizer.eos_token + + def stat_tensor(name, tensor, 
act_scales, key): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float() + + if act_scales[name][key] is None: + act_scales[name][key] = comming_max + else: + act_scales[name][key] = torch.max(act_scales[name][key], + comming_max) + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x, act_scales, "x") + stat_tensor(name, y, act_scales, "y") + + if act_scales[name]["w"] is None: + act_scales[name]["w"] = m.weight.abs().clip( + 1e-8, None).max(dim=1)[0].float() + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear) or isinstance(m, Conv1D): + hooks.append( + m.register_forward_hook( + functools.partial(stat_input_hook, name=name))) + + for i in tqdm(range(num_samples), desc="calibrating model"): + datapoint = dataset['train'][i:i + 1] + line = copy.copy(datapoint['article']) + line[0] = line[0] + ' TL;DR: ' + line[0] = line[0].strip() + line[0] = line[0].replace(" n't", "n't") + # input_ids = tokenizer(line, + # return_tensors="pt", + # max_length=seq_len, + # padding=True, + # truncation=True).input_ids.to(device) + inputs = tokenizer.EncodeAsIds(line[0]) + inputs = np.array([[tokenizer.bos_id()] + inputs], dtype=np.int32) + input_ids = torch.tensor(inputs, dtype=torch.int32).to(device) + model(input_ids) + + for h in hooks: + h.remove() + + return act_scales + + +@torch.no_grad() +def smooth_gemm_fc1_gate(fc1_weights, + gate_weights, + act_scales, + layernorm_weights=None, + layernorm_bias=None, + alpha=0.5, + weight_scales=None): + gemm_weights = [] + if not isinstance(fc1_weights, list): + fc1_weights = [fc1_weights] + if not isinstance(gate_weights, list): + gate_weights = [gate_weights] + + for i in range(len(fc1_weights)): + gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0) + gemm_weights.append(gemm_weight) + + orig_dtype = gemm_weights[0].dtype + + for gemm in gemm_weights: + # gemm_weights are expected to be transposed + assert gemm.shape[1] == act_scales.numel() + + if weight_scales is None: + weight_scales = torch.cat( + [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], + dim=0) + weight_scales = weight_scales.max(dim=0)[0] + weight_scales.to(float).clamp(min=1e-5) + scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / + weight_scales.pow(1 - alpha)).clamp(min=1e-5) + + apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights, + layernorm_bias, orig_dtype) + + return scales + + +@torch.no_grad() +def smooth_model(model, scales, alpha, qkv_para, smoother_dict): + # Smooth the activation and weights with smoother = $\diag{s}$ + for name, module in model.named_modules(): + if not isinstance(module, LlamaDecoderLayer): + continue + # qkv_proj + layer_name_q = name + ".self_attn.q_proj" + layer_name_k = name + ".self_attn.k_proj" + layer_name_v = name + ".self_attn.v_proj" + layer_name_qkv = name + ".self_attn.qkv_proj" + + weight = torch.cat([ + module.self_attn.q_proj.weight, module.self_attn.k_proj.weight, + module.self_attn.v_proj.weight + ], + dim=0) + + smoother = smooth_gemm(weight, scales[layer_name_q]["x"], + module.input_layernorm.weight, None, alpha) + + scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother + scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0] + scales[layer_name_qkv]["y"] = torch.cat([ + scales[layer_name_q]["y"], scales[layer_name_k]["y"], + scales[layer_name_v]["y"] + ], + dim=0) + + # see transpose_weights function + 
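+ # (The fused QKV weight is cached transposed to shape [hidden_size, qkv_out_dim]
+ # so that convert_hf_model below can pass it straight to generate_int8.)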
qkv_para[layer_name_qkv] = weight.transpose(0, 1) + + # ================================================================= + layer_name = name + ".self_attn.o_proj" + smoother = smooth_gemm(module.self_attn.o_proj.weight, + scales[layer_name]["x"], None, None, alpha) + smoother_dict[layer_name] = smoother.float() + + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max( + dim=1)[0] + + # ================================================================== + fc1_layer_name = name + ".mlp.gate_proj" + gate_layer_name = name + ".mlp.up_proj" + + smoother = smooth_gemm_fc1_gate(module.mlp.gate_proj.weight, + module.mlp.up_proj.weight, + scales[fc1_layer_name]["x"], + module.post_attention_layernorm.weight, + None, alpha) + + scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother + scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max( + dim=1)[0] + + scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother + scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max( + dim=1)[0] + + # ================================================================== + layer_name = name + ".mlp.down_proj" + smoother = smooth_gemm(module.mlp.down_proj.weight, + scales[layer_name]["x"], None, None, alpha) + smoother_dict[layer_name] = smoother.float() + scales[layer_name]["x"] = scales[layer_name]["x"] / smoother + scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max( + dim=1)[0] + + +def get_tllm_linear_sq_weight(vals, + prefix, + shape, + tensor_parallel, + is_qkv=False, + per_token=False, + per_channel=False, + last_prefix=None, + bias=None, + smoother_value=None, + smoother_shape=None, + rank=0, + cat_dim=0, + multi_query_mode=False): + results = {} + + def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): + q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) + q_split = np.split(q, tp_size, axis=-1) + k_split = np.split(k, tp_size, axis=-1) + v_split = np.split(v, tp_size, axis=-1) + return [ + np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) + for ii in range(tp_size) + ][cur_rank] + + col_shape = shape if (is_qkv or per_channel) else [1, 1] + + if per_token: + if per_channel: + original_weights = np.array(vals["weight.int8.col"]) + else: + original_weights = np.array(vals["weight.int8"]) + local_dim = original_weights.shape[0] + head_size = (original_weights.shape[1] - local_dim) // 2 + + if multi_query_mode: + cur_weights = multi_query_split(original_weights, local_dim, + head_size, tensor_parallel, rank) + else: + cur_weights = np.split(original_weights, + tensor_parallel, + axis=cat_dim)[rank] + if is_qkv: + hidden_dim = cur_weights.shape[0] + cur_weights = cur_weights.reshape(hidden_dim, -1) + results[prefix + + 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + if smoother_value is None: + results[last_prefix] = torch.from_numpy( + np.array([1.0], dtype=np.float32)) + + if per_channel: + cur_per_channel_value = vals["scale_w_quant_orig.col"] + if smoother_value is None: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_w_quant_orig.col"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split( + vals["scale_w_quant_orig.col"], + tensor_parallel, + axis=cat_dim)[rank] + else: + cur_per_channel_value = vals["scale_w_quant_orig"] + if is_qkv: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_w_quant_orig"], 
local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split(vals["scale_w_quant_orig"], + tensor_parallel, + axis=cat_dim)[rank] + + results[prefix + 'per_channel_scale'] = torch.from_numpy( + np.array(cur_per_channel_value, + dtype=np.float32).reshape(col_shape)).contiguous() + else: + if per_channel: + original_weights = np.array(vals["weight.int8.col"]) + else: + original_weights = np.array(vals["weight.int8"]) + local_dim = original_weights.shape[0] + head_size = (original_weights.shape[1] - local_dim) // 2 + + if multi_query_mode: + cur_weights = multi_query_split(original_weights, local_dim, + head_size, tensor_parallel, rank) + else: + cur_weights = np.split(original_weights, + tensor_parallel, + axis=cat_dim)[rank] + if is_qkv: + hidden_dim = cur_weights.shape[0] + cur_weights = cur_weights.reshape(hidden_dim, -1) + results[prefix + + 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + + if per_channel: + cur_per_channel_value = vals["scale_y_accum_quant.col"] + if smoother_value is None: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_y_accum_quant.col"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split( + vals["scale_y_accum_quant.col"], + tensor_parallel, + axis=cat_dim)[rank] + else: + cur_per_channel_value = vals["scale_y_accum_quant"] + # QKV is always per_channel + if is_qkv: + if multi_query_mode: + cur_per_channel_value = multi_query_split( + vals["scale_y_accum_quant"], local_dim, head_size, + tensor_parallel, rank) + else: + cur_per_channel_value = np.split( + vals["scale_y_accum_quant"], + tensor_parallel, + axis=cat_dim)[rank] + + results[prefix + 'per_channel_scale'] = torch.from_numpy( + np.array([cur_per_channel_value], + dtype=np.float32).reshape(col_shape)).contiguous() + + results[last_prefix] = torch.from_numpy( + np.array([vals['scale_x_orig_quant']], + dtype=np.float32)).contiguous() + + results[prefix + 'act_scale'] = torch.from_numpy( + np.array([[vals["scale_y_quant_orig"]]], + dtype=np.float32)).contiguous() + + if smoother_value is not None: + cur_smoother_value = np.split(smoother_value, + tensor_parallel, + axis=cat_dim)[rank] + results[prefix + 'smoother'] = cur_smoother_value.reshape( + smoother_shape).contiguous().to(torch.float32) + + if bias is not None: + results[prefix + 'bias'] = bias + + return results + + +def split(weight: torch.Tensor, + tp_size: int, + rank: int = 0, + dim: int = 0) -> torch.Tensor: + if tp_size == 1: + return weight + elif weight.ndim == 1: + return torch.chunk(weight, tp_size)[rank].contiguous() + else: + return torch.chunk(weight, tp_size, dim=dim)[rank].contiguous() + + +def split_qkv_tp(qkv, n_head, n_kv_heads, head_size, tensor_parallel, rank): + """ + Splits the QKV matrix according to tensor parallelism + """ + kv_head_size = n_kv_heads * head_size + q, k, v = torch.split(qkv, [n_head * head_size, kv_head_size, kv_head_size], + dim=0) + q = split(q, tensor_parallel, rank, dim=0) + k = split(k, tensor_parallel, rank, dim=0) + v = split(v, tensor_parallel, rank, dim=0) + return torch.concatenate([q, k, v], dim=0).contiguous() + + +def split_matrix_tp(weight: torch.Tensor, tp_size: int, rank: int, + dim: int) -> torch.Tensor: + return split(weight, tp_size, rank, dim=dim) + + +def get_weight(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> torch.Tensor: + if f'{prefix}.weight' not in params: + return None + return params[f'{prefix}.weight'].to(dtype).detach().cpu() + + +def 
get_bias(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> torch.Tensor: + if f'{prefix}.bias' not in params: + return None + return params[f'{prefix}.bias'].to(dtype).detach().cpu() + + +def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> Tuple[torch.Tensor]: + return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype) + + +def get_tllm_linear_weight( + weight: torch.Tensor, + prefix: str, + bias: Optional[torch.Tensor] = None, + use_weight_only: bool = False, + plugin_weight_only_quant_type: torch.dtype = torch.int8 +) -> Dict[str, torch.Tensor]: + results = {} + if use_weight_only: + v = weight.t().contiguous() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + v, plugin_weight_only_quant_type) + results[f'{prefix}weight'] = processed_torch_weights + results[f'{prefix}per_channel_scale'] = torch_weight_scales + else: + results[f'{prefix}weight'] = weight.contiguous() + + if bias is not None: + results[f'{prefix}bias'] = bias + + return results + + +class LlamaAttentionExtend(LlamaAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.head_dim = self.config.head_size + self.q_proj = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=False) + self.k_proj = nn.Linear(self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False) + self.v_proj = nn.Linear(self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, + self.hidden_size, + bias=False) + self._init_rope() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * + self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, + dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) + for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) + for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) + for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index.") + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, + key_states, cos, sin, + position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.attention_dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + # Here is what we extend. 
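+ # Unlike the stock LlamaAttention, which reshapes to self.hidden_size here, the
+ # output is flattened to num_heads * head_dim: for Gemma-7B that product is larger
+ # than hidden_size, and o_proj maps it back down to hidden_size.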
+ attn_output = attn_output.reshape(bsz, q_len, + self.num_heads * self.head_dim) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // + self.config.pretraining_tp, + dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // + self.config.pretraining_tp, + dim=1) + attn_output = sum([ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def create_model_from_config(trt_llm_config, weights): + model_config = LlamaConfig() + model_config.vocab_size = trt_llm_config.vocab_size + model_config.dtype = trt_llm_config.dtype + model_config.max_position_embeddings = trt_llm_config.max_position_embeddings + model_config.hidden_size = trt_llm_config.hidden_size + model_config.num_hidden_layers = trt_llm_config.num_hidden_layers + model_config.num_attention_heads = trt_llm_config.num_attention_heads + model_config.num_key_value_heads = trt_llm_config.num_key_value_heads + model_config.hidden_act = trt_llm_config.hidden_act + model_config.head_size = trt_llm_config.head_size + model_config.intermediate_size = trt_llm_config.intermediate_size + model = LlamaForCausalLM(model_config) + # Hack attention module since head_dim * num_heads > hidden_size for 7B. + for i in range(model_config.num_hidden_layers): + module = model.model.layers[i].self_attn + model.model.layers[i].self_attn = LlamaAttentionExtend( + module.config, module.layer_idx) + # Copy wegiht to LLAMA model. + replace_name_dict = { + 'attention.dense': 'self_attn.o_proj', + 'mlp.proj': 'mlp.down_proj', + 'mlp.gate': 'mlp.up_proj', + 'mlp.fc': 'mlp.gate_proj', + 'ln_f': 'norm', + 'post_layernorm': 'post_attention_layernorm', + 'vocab_embedding': 'embed_tokens', + } + for name in list(weights): + if model_config.dtype == "bfloat16": + param = torch.from_numpy(weights[name].astype(np.float32)).to( + torch.bfloat16) + else: + param = torch.from_numpy(weights[name]) + weights.pop(name) + new_name = name.replace('transformer', 'model') + for _name in replace_name_dict: + if _name in new_name: + new_name = new_name.replace(_name, replace_name_dict[_name]) + if 'attention.qkv' in name: + qw, kw, vw = torch.split(param, [ + model_config.num_attention_heads * model_config.head_size, + model_config.num_key_value_heads * model_config.head_size, + model_config.num_key_value_heads * model_config.head_size, + ], + dim=0) + weights[new_name.replace('attention.qkv', 'self_attn.q_proj')] = qw + weights[new_name.replace('attention.qkv', 'self_attn.k_proj')] = kw + weights[new_name.replace('attention.qkv', 'self_attn.v_proj')] = vw + else: + weights[new_name] = param + model.load_state_dict(weights) + return model + + +def convert_hf_model(hf_model, + mapping, + vocab_size=32000, + dtype='float32', + use_parallel_embedding=False, + sharding_dim=0, + use_weight_only=False, + plugin_weight_only_quant_type=torch.int8, + use_smooth_quant=False, + per_channel=False, + per_token=False, + int8_kv_cache=False, + act_range=[], + qkv_para=[], + smoother=[]): + + weights = {} + tik = time.time() + tensor_parallel = mapping.tp_size + model_params = dict(hf_model.named_parameters()) + dtype = getattr(torch, dtype) + num_attention_heads = hf_model.config.num_attention_heads + hidden_size = hf_model.config.hidden_size + intermediate_size = hf_model.config.intermediate_size + head_size = hf_model.config.head_size + 
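+ # mha_mode (below) is True only when every attention head has its own K/V head;
+ # otherwise the GQA/MQA path duplicates or splits K/V heads per tensor-parallel rank.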
num_key_value_heads = hf_model.config.num_key_value_heads + mha_mode = (num_key_value_heads == num_attention_heads) + + layers_per_pipeline_stage = hf_model.config.num_hidden_layers // mapping.pp_size + layers_range = list( + range(mapping.pp_rank * layers_per_pipeline_stage, + (mapping.pp_rank + 1) * layers_per_pipeline_stage, 1)) + for l in range(hf_model.config.num_hidden_layers): + if l not in layers_range: + continue + print("Processing layer", l) + prefix = f'model.layers.{l}.' + idx = int(l) - mapping.pp_rank * layers_per_pipeline_stage + tllm_prex = f'transformer.layers.{idx}.' + + if use_smooth_quant: + qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj'] + qkv_out_dim = qkv_weight.shape[1] + + if not mha_mode: + hidden_size = qkv_weight.shape[0] + local_dim = hidden_size + head_size = (qkv_weight.shape[-1] - local_dim) // 2 + qkv_weight = qkv_weight.reshape(hidden_size, + local_dim + 2 * head_size) + else: + qkv_weight = qkv_weight.reshape(hidden_size, 3, + head_size * num_attention_heads) + + int8_weights = generate_int8(qkv_weight.numpy(), + act_range.get(prefix + + 'self_attn.qkv_proj'), + is_qkv=True, + multi_query_mode=bool(not mha_mode)) + weights.update( + get_tllm_linear_sq_weight(int8_weights, + tllm_prex + 'attention.qkv.', + [1, qkv_out_dim // tensor_parallel], + tensor_parallel, + is_qkv=True, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + + 'input_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None, + rank=mapping.tp_rank, + cat_dim=-1, + multi_query_mode=bool(not mha_mode))) + else: + q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', + dtype) + k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', + dtype) + v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', + dtype) + if not mha_mode: + if num_key_value_heads < tensor_parallel: + # duplicate the KV heads up to tensor_parallel + k_weight = dup_kv_weight(k_weight, num_key_value_heads, + tensor_parallel) + v_weight = dup_kv_weight(v_weight, num_key_value_heads, + tensor_parallel) + assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0 + assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0 + + wq = split(q_weight, mapping.tp_size, mapping.tp_rank) + wk = split(k_weight, mapping.tp_size, mapping.tp_rank) + wv = split(v_weight, mapping.tp_size, mapping.tp_rank) + + split_v = torch.concat((wq, wk, wv)) + + else: + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + + split_v = split_qkv_tp(qkv_weight, num_attention_heads, + num_key_value_heads, head_size, + tensor_parallel, mapping.tp_rank) + weights.update( + get_tllm_linear_weight(split_v, tllm_prex + 'attention.qkv.', + None, use_weight_only, + plugin_weight_only_quant_type)) + + if int8_kv_cache: + qkv_y = torch.cat([ + act_range.get(prefix + 'self_attn.q_proj')["y"], + act_range.get(prefix + 'self_attn.k_proj')["y"], + act_range.get(prefix + 'self_attn.v_proj')["y"] + ], + dim=0) + int8_kv_scales = qkv_y.max() / 127. + kv_cache_weights = {} + kv_cache_weights[ + tllm_prex + + 'attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape( + [1]) + + weights.update(kv_cache_weights) + + # Attention dense. 
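+ # o_proj is a row-parallel GEMM, so its weight is split along the input dimension
+ # (dim=1) per tensor-parallel rank, unlike the column-parallel q/k/v and MLP
+ # up/gate projections, which are split along dim=0.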
+ attn_dense_weight = get_weight(model_params, + prefix + 'self_attn.o_proj', dtype) + if use_smooth_quant: + attn_dense_weight = attn_dense_weight.t().numpy() + int8_weights = generate_int8( + attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'attention.dense.', [1, hidden_size], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + + 'attention.quantization_scaling_factor', + smoother_value=smoother[(prefix + 'self_attn.o_proj')], + smoother_shape=[ + 1, head_size * num_attention_heads // tensor_parallel + ], + rank=mapping.tp_rank, + cat_dim=0)) + else: + attn_dense_weight = split_matrix_tp(attn_dense_weight, + tensor_parallel, + mapping.tp_rank, + dim=1) + weights.update( + get_tllm_linear_weight(attn_dense_weight, + tllm_prex + 'attention.dense.', None, + use_weight_only, + plugin_weight_only_quant_type)) + # MLP hf up to trt gate + mlp_up_weight = get_weight(model_params, prefix + 'mlp.up_proj', dtype) + if use_smooth_quant: + mlp_up_weight = mlp_up_weight.t().numpy() + int8_weights = generate_int8(mlp_up_weight, + act_range.get(prefix + 'mlp.up_proj')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'mlp.gate.', + [1, intermediate_size // tensor_parallel], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'post_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None, + rank=mapping.tp_rank, + cat_dim=-1)) + else: + mlp_up_weight = split_matrix_tp(mlp_up_weight, + tensor_parallel, + mapping.tp_rank, + dim=0) + weights.update( + get_tllm_linear_weight(mlp_up_weight, tllm_prex + 'mlp.gate.', + None, use_weight_only, + plugin_weight_only_quant_type)) + + # MLP trt Gate to mlp fc + mlp_gate_weight = get_weight(model_params, prefix + 'mlp.gate_proj', + dtype) + if use_smooth_quant: + mlp_gate_weight = mlp_gate_weight.t().numpy() + int8_weights = generate_int8( + mlp_gate_weight, act_range.get(prefix + 'mlp.gate_proj')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'mlp.fc.', + [1, intermediate_size // tensor_parallel], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'post_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None, + rank=mapping.tp_rank, + cat_dim=-1)) + else: + mlp_gate_weight = split_matrix_tp(mlp_gate_weight, + tensor_parallel, + mapping.tp_rank, + dim=0) + weights.update( + get_tllm_linear_weight(mlp_gate_weight, tllm_prex + 'mlp.fc.', + None, use_weight_only, + plugin_weight_only_quant_type)) + + # MLP down + mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj', + dtype) + if use_smooth_quant: + mlp_proj_weight = mlp_proj_weight.t().numpy() + int8_weights = generate_int8( + mlp_proj_weight, act_range.get(prefix + 'mlp.down_proj')) + weights.update( + get_tllm_linear_sq_weight( + int8_weights, + tllm_prex + 'mlp.proj.', [1, hidden_size], + tensor_parallel, + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=tllm_prex + 'mlp.quantization_scaling_factor', + smoother_value=smoother[prefix + 'mlp.down_proj'], + smoother_shape=[1, intermediate_size // tensor_parallel], + rank=mapping.tp_rank, + cat_dim=0)) + else: + mlp_proj_weight = split_matrix_tp(mlp_proj_weight, + tensor_parallel, + mapping.tp_rank, + dim=1) + weights.update( + get_tllm_linear_weight(mlp_proj_weight, tllm_prex 
+ 'mlp.proj.', + None, use_weight_only, + plugin_weight_only_quant_type)) + + # Layer norms do not use tensor parallelism + input_ln_weight = get_weight(model_params, prefix + 'input_layernorm', + dtype) + weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight + + post_ln_weight = get_weight(model_params, + prefix + 'post_attention_layernorm', dtype) + weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight + + v = get_weight(model_params, 'model.embed_tokens', dtype) + + if use_parallel_embedding: + v = split_matrix_tp(v, + mapping.tp_size, + mapping.tp_rank, + dim=sharding_dim) + + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = v + + lm_head_weights = get_weight(model_params, 'lm_head', dtype) + + if mapping.is_last_pp_rank(): + + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + + lm_head_weights = torch.from_numpy( + np.pad(lm_head_weights.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = split_matrix_tp(lm_head_weights, + tensor_parallel, + mapping.tp_rank, + dim=0) + ln_f_w = get_weight(model_params, 'model.norm', dtype) + weights['transformer.ln_f.weight'] = ln_f_w + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights diff --git a/tensorrt_llm/models/gemma/weight.py b/tensorrt_llm/models/gemma/weight.py new file mode 100644 index 000000000..5a7852c2b --- /dev/null +++ b/tensorrt_llm/models/gemma/weight.py @@ -0,0 +1,681 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import configparser +import os +import time +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import torch + +import tensorrt_llm +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales +from tensorrt_llm.quantization import QuantMode + + +def get_scaling_factors( + model_path: Union[str, Path], + num_layers: int, + quant_mode: Optional[QuantMode] = None, +) -> Optional[Dict[str, List[int]]]: + """ Get the scaling factors for LLaMA model + + Returns a dictionary of scaling factors for the selected layers of the + LLaMA model. + + Args: + model_path (str): Path to the quantized LLaMA model + layers (list): List of layers to get the scaling factors for. If None, + all layers are selected. + + Returns: + dict: Dictionary of scaling factors for the selected layers of the + LLaMA model. 
+ + example: + + { + 'qkv_act': qkv_act_scale, + 'qkv_weights': qkv_weights_scale, + 'qkv_output' : qkv_outputs_scale, + 'dense_act': dense_act_scale, + 'dense_weights': dense_weights_scale, + 'fc_act': fc_act_scale, + 'fc_weights': fc_weights_scale, + 'gate_act': gate_act_scale, + 'gate_weights': gate_weights_scale, + 'proj_act': proj_act_scale, + 'proj_weights': proj_weights_scale, + } + """ + + if model_path is None: + tensorrt_llm.logger.warning( + f"--quantized_fp8_model_path not specified. " + f"Initialize quantization scales automatically.") + return get_dummy_quant_scales(num_layers) + weight_dict = np.load(model_path) + # yapf: disable + scaling_factor = { + 'qkv_act': [], + 'qkv_weights': [], + 'dense_act': [], + 'dense_weights': [], + 'fc_act': [], + 'fc_weights': [], + 'gate_act': [], + 'gate_weights': [], + 'proj_act': [], + 'proj_weights': [], + } + + if quant_mode is not None and quant_mode.has_fp8_kv_cache(): + scaling_factor['qkv_output'] = [] + + for layer in range(num_layers): + scaling_factor['qkv_act'].append(max( + weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item() + )) + scaling_factor['qkv_weights'].append(max( + weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item() + )) + if quant_mode is not None and quant_mode.has_fp8_kv_cache(): + # Not calibrarting KV cache. + scaling_factor['qkv_output'].append(1.0) + scaling_factor['dense_act'].append( + weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item()) + scaling_factor['dense_weights'].append( + weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item()) + scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item()) + scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item()) + scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item()) + scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item()) + scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item()) + scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item()) + # yapf: enable + for k, v in scaling_factor.items(): + assert len(v) == num_layers, \ + f'Expect scaling factor {k} of length {num_layers}, got {len(v)}' + + return scaling_factor + + +def gen_suffix(rank, use_smooth_quant, quant_per_channel): + suffix = f"{rank}.bin" + if use_smooth_quant: + sq_prefix = "int8." + if quant_per_channel: + sq_prefix += "col." 
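+ # e.g. rank 0 with SmoothQuant and per-channel scaling yields "int8.col.0.bin";
+ # plain SmoothQuant yields "int8.0.bin"; without quantization it is just "0.bin".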
+ suffix = sq_prefix + suffix + return suffix + + +def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v: Union[np.ndarray, torch.Tensor], + tp_size: int, + tp_rank: int, + dim=0): + if tp_size == 1: + return v + assert len(v.shape) > 1 or dim == 0 + if isinstance(v, np.ndarray): + return np.ascontiguousarray( + np.split(v, tp_size, axis=dim)[tp_rank].copy()) + else: + assert v.shape[dim] % tp_size == 0, \ + 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' + split_size = v.shape[dim] // tp_size + return v.split(split_size, dim=dim)[tp_rank].clone().detach() + + +def dup_kv_weight(v, num_head, tp_size): + assert tp_size % num_head == 0 + reps = tp_size // num_head + head_size = v.shape[0] // num_head + v = v.reshape(num_head, head_size, + -1)[:, None, :, :].expand(num_head, reps, head_size, + v.shape[1]) + return v.reshape(num_head * reps * head_size, -1).clone().detach() + + +def parse_bin_config(ini_file): + model_config = configparser.ConfigParser() + model_config.read(ini_file) + + n_embd = model_config.getint('gemma', 'hidden_size') + n_head = model_config.getint('gemma', 'num_attention_heads') + n_head_size = model_config.getint('gemma', + 'head_size', + fallback=n_embd // n_head) + n_layer = model_config.getint('gemma', 'num_hidden_layers') + n_positions = model_config.getint('gemma', 'max_position_embeddings') + vocab_size = model_config.getint('gemma', 'vocab_size') + hidden_act = model_config.get('gemma', 'hidden_act') + inter_size = model_config.getint('gemma', + 'intermediate_size', + fallback=None) + n_kv_head = model_config.getint('gemma', + 'num_key_value_heads', + fallback=None) + + if inter_size is None: + inter_size = 4 * n_embd + + return n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head, n_head_size + + +def load_from_binary(tensorrt_llm_gemma, + dir_path, + mapping=Mapping(), + fp16=False, + multi_query_mode=False): + tensorrt_llm.logger.info('Loading weights from binary...') + tik = time.time() + + quant_mode = getattr(tensorrt_llm_gemma, 'quant_mode', QuantMode(0)) + + n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head, n_head_size = parse_bin_config( + Path(dir_path) / 'config.ini') + np_dtype = np.float16 if fp16 else np.float32 + + def fromfile(dir_path, name, shape=None, dtype=None): + dtype = np_dtype if dtype is None else dtype + p = dir_path + '/' + name + if Path(p).exists(): + t = np.fromfile(p, dtype=dtype) + if shape is not None: + t = t.reshape(shape) + return t + return None + + def set_smoothquant_scale_factors(module, + pre_scale_weight, + dir_path, + basename, + shape, + per_tok_dyn, + per_channel, + is_qkv=False, + rank=None): + suffix = "bin" + if per_channel: + if rank is not None: + suffix = f"{rank}." + suffix + suffix = "col." 
+ suffix + + col_shape = shape if (per_channel or is_qkv) else [1, 1] + + if per_tok_dyn: + if pre_scale_weight is not None: + pre_scale_weight.value = np.array([1.0], dtype=np.float32) + if is_qkv and not per_channel: + t = fromfile(dir_path, + f"{basename}scale_w_quant_orig.{rank}.{suffix}", + col_shape, np.float32) + else: + t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}", + col_shape, np.float32) + module.per_channel_scale.value = t + else: + t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1], + np.float32) + pre_scale_weight.value = t + if is_qkv: + t = fromfile(dir_path, + f"{basename}scale_y_accum_quant.{rank}.{suffix}", + col_shape, np.float32) + else: + t = fromfile(dir_path, + f"{basename}scale_y_accum_quant.{suffix}", + col_shape, np.float32) + module.per_channel_scale.value = t + t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1], + np.float32) + module.act_scale.value = t + + def set_smoother(module, dir_path, base_name, shape, rank): + suffix = f"{rank}.bin" + t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape, + np.float32) + module.smoother.value = t + + # Determine the quantization mode. + quant_mode = getattr(tensorrt_llm_gemma, "quant_mode", QuantMode(0)) + if quant_mode.is_int8_weight_only(): + plugin_weight_only_quant_type = torch.int8 + elif quant_mode.is_int4_weight_only(): + plugin_weight_only_quant_type = torch.quint4x2 + # Do we use SmoothQuant? + use_smooth_quant = quant_mode.has_act_and_weight_quant() + # Do we use quantization per token? + quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling() + # Do we use quantization per channel? + quant_per_channel = quant_mode.has_per_channel_scaling() + + # Do we use INT4/INT8 weight-only? + use_weight_only = quant_mode.is_weight_only() + + # Int8 KV cache + use_int8_kv_cache = quant_mode.has_int8_kv_cache() + + # Debug + suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel) + # The type of weights. + w_type = np_dtype if not use_smooth_quant else np.int8 + + if mapping.is_first_pp_rank(): + tensorrt_llm_gemma.vocab_embedding.weight.value = (fromfile( + dir_path, 'vocab_embedding.weight.bin', [vocab_size, n_embd])) + + if mapping.is_last_pp_rank(): + tensorrt_llm_gemma.ln_f.weight.value = (fromfile( + dir_path, 'ln_f.weight.bin')) + # share input embedding + lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin', + [vocab_size, n_embd]) + + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = tensorrt_llm_gemma.lm_head.out_features * mapping.tp_size + pad_width = vocab_size_padded - vocab_size + lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), + 'constant', + constant_values=0) + if mapping.is_last_pp_rank(): + tensorrt_llm_gemma.lm_head.weight.value = np.ascontiguousarray( + split(lm_head_weight, mapping.tp_size, mapping.tp_rank)) + + layers_per_pipeline_stage = tensorrt_llm_gemma.num_layers // mapping.pp_size + layers_range = list( + range(mapping.pp_rank * layers_per_pipeline_stage, + (mapping.pp_rank + 1) * layers_per_pipeline_stage, 1)) + + # This code does not support the case where the number of ranks is greater than the number of K/V heads for GQA. + assert (n_kv_head % mapping.tp_size == 0) or (n_kv_head == 1) + + # Compute the number of K/V heads per rank. It's 1 for MQA. + kv_heads_per_rank = min(1, n_kv_head // mapping.tp_size) + # The N-dimension for each rank of the QKV matrix is number of columns for Q + 2 * number of columns for K/V. 
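+ # Example with hypothetical shapes: n_head = 16, n_head_size = 256, tp_size = 2 and
+ # MQA (kv_heads_per_rank = 1) gives c_attn_out_dim = 16 * 256 // 2 + 2 * 1 * 256 = 2560;
+ # the MHA case would give 3 * (16 * 256) // 2 = 6144.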
+ if multi_query_mode: + c_attn_out_dim = n_head * n_head_size // mapping.tp_size + 2 * kv_heads_per_rank * n_head_size + else: + c_attn_out_dim = 3 * (n_head * n_head_size) // mapping.tp_size + + for i in layers_range: + idx = i - mapping.pp_rank * layers_per_pipeline_stage + tensorrt_llm_gemma.layers[idx].input_layernorm.weight.value = (fromfile( + dir_path, 'model.layers.' + str(i) + '.input_layernorm.weight.bin')) + t = fromfile( + dir_path, 'model.layers.' + str(i) + + '.attention.query_key_value.weight.' + suffix, + [n_embd, c_attn_out_dim], w_type) + if t is not None: + dst = tensorrt_llm_gemma.layers[idx].attention.qkv.weight + if use_smooth_quant: + dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) + set_smoothquant_scale_factors( + tensorrt_llm_gemma.layers[idx].attention.qkv, + tensorrt_llm_gemma.layers[idx].input_layernorm.scale_to_int, + dir_path, + 'model.layers.' + str(i) + '.attention.query_key_value.', + [1, c_attn_out_dim], + quant_per_token_dyn, + quant_per_channel, + rank=mapping.tp_rank, + is_qkv=True) + elif use_weight_only: + processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type) + dst.value = processed_torch_weights.numpy() + scales = tensorrt_llm_gemma.layers[ + idx].attention.qkv.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) + + dst = tensorrt_llm_gemma.layers[idx].attention.dense.weight + t = fromfile( + dir_path, + 'model.layers.' + str(i) + '.attention.dense.weight.' + suffix, + [(n_head * n_head_size) // mapping.tp_size, n_embd], w_type) + if use_smooth_quant: + dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) + dense_scale = getattr(tensorrt_llm_gemma.layers[idx].attention, + "quantization_scaling_factor", None) + set_smoothquant_scale_factors( + tensorrt_llm_gemma.layers[idx].attention.dense, dense_scale, + dir_path, 'model.layers.' + str(i) + '.attention.dense.', + [1, n_embd], quant_per_token_dyn, quant_per_channel) + set_smoother(tensorrt_llm_gemma.layers[idx].attention.dense, + dir_path, + 'model.layers.' + str(i) + '.attention.dense', + [1, n_embd // mapping.tp_size], mapping.tp_rank) + elif use_weight_only: + processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type) + dst.value = processed_torch_weights.numpy() + scales = tensorrt_llm_gemma.layers[ + idx].attention.dense.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) + + dst = tensorrt_llm_gemma.layers[idx].post_layernorm.weight + dst.value = fromfile( + dir_path, 'model.layers.' + str(i) + '.post_layernorm.weight.bin') + + t = fromfile(dir_path, + 'model.layers.' + str(i) + '.mlp.fc.weight.' + suffix, + [n_embd, inter_size // mapping.tp_size], w_type) + + if use_smooth_quant: + tensorrt_llm_gemma.layers[ + idx].mlp.fc.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0])) + set_smoothquant_scale_factors( + tensorrt_llm_gemma.layers[idx].mlp.fc, + tensorrt_llm_gemma.layers[idx].post_layernorm.scale_to_int, + dir_path, + 'model.layers.' 
+ str(i) + '.mlp.fc.', + [1, inter_size // mapping.tp_size], + quant_per_token_dyn, + quant_per_channel, + rank=mapping.tp_rank) + elif use_weight_only: + dst = tensorrt_llm_gemma.layers[idx].mlp.fc.weight + processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type) + + dst.value = processed_torch_weights.numpy() + scales = tensorrt_llm_gemma.layers[idx].mlp.fc.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + tensorrt_llm_gemma.layers[ + idx].mlp.fc.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0])) + + t = fromfile(dir_path, + 'model.layers.' + str(i) + '.mlp.gate.weight.' + suffix, + [n_embd, inter_size // mapping.tp_size], w_type) + if use_smooth_quant: + tensorrt_llm_gemma.layers[ + idx].mlp.gate.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0])) + set_smoothquant_scale_factors( + tensorrt_llm_gemma.layers[idx].mlp.gate, + tensorrt_llm_gemma.layers[idx].post_layernorm.scale_to_int, + dir_path, + 'model.layers.' + str(i) + '.mlp.gate.', + [1, inter_size // mapping.tp_size], + quant_per_token_dyn, + quant_per_channel, + rank=mapping.tp_rank) + elif use_weight_only: + dst = tensorrt_llm_gemma.layers[idx].mlp.gate.weight + processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type) + dst.value = processed_torch_weights.numpy() + scales = tensorrt_llm_gemma.layers[idx].mlp.gate.per_channel_scale + + scales.value = torch_weight_scales.numpy() + else: + tensorrt_llm_gemma.layers[ + idx].mlp.gate.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0])) + + t = fromfile(dir_path, + 'model.layers.' + str(i) + '.mlp.proj.weight.' + suffix, + [inter_size // mapping.tp_size, n_embd], w_type) + if use_smooth_quant: + tensorrt_llm_gemma.layers[ + idx].mlp.proj.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0])) + proj_scale = getattr(tensorrt_llm_gemma.layers[idx].mlp, + "quantization_scaling_factor", None) + set_smoothquant_scale_factors( + tensorrt_llm_gemma.layers[idx].mlp.proj, proj_scale, dir_path, + 'model.layers.' + str(i) + '.mlp.proj.', [1, n_embd], + quant_per_token_dyn, quant_per_channel) + set_smoother(tensorrt_llm_gemma.layers[idx].mlp.proj, dir_path, + 'model.layers.' + str(i) + '.mlp.proj', + [1, inter_size // mapping.tp_size], mapping.tp_rank) + elif use_weight_only: + dst = tensorrt_llm_gemma.layers[idx].mlp.proj.weight + processed_torch_weights, torch_weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type) + + dst.value = processed_torch_weights.numpy() + scales = tensorrt_llm_gemma.layers[idx].mlp.proj.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + tensorrt_llm_gemma.layers[idx].mlp.proj.weight.value = ( + np.ascontiguousarray(np.transpose(t, [1, 0]))) + + if use_int8_kv_cache: + t = fromfile( + dir_path, 'model.layers.' + str(i) + + '.attention.query_key_value.scale_y_quant_orig.bin', [1], + np.float32) + tensorrt_llm_gemma.layers[ + idx].attention.kv_cache_scaling_factor.value = t + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f'Weights loaded. 
Total time: {t}') + + +def load_from_hf_llama(): + # empty stub kept so existing imports keep working + pass + + +def quantize_fp8_weigths(weights, num_layers, mapping): + + def get_scaling_factor(weight): + amax = weight.max() + scale = 448.0 / amax + return scale + + layers_range = mapping.pp_layers(num_layers) + scaling_factors = {} + scaled_weights = {} + trt_llm_prefix = "transformer.layers" + for l in layers_range: + # weights that receive per-tensor FP8 scaling factors + for name in [ + "attention.qkv", "attention.dense", "mlp.fc", "mlp.gate", + "mlp.proj" + ]: + trt_llm_name = ".".join((trt_llm_prefix, str(l), name, "weight")) + scale_name = ".".join( + (trt_llm_prefix, str(l), name, "weights_scaling_factor")) + weight = weights[trt_llm_name] + dtype = weights[trt_llm_name].dtype + scale = get_scaling_factor(weight) + scaled_weights[trt_llm_name] = np.ascontiguousarray( + (weight * scale).astype(dtype)) + scaling_factors[scale_name] = np.asarray([1 / scale + ]).astype(np.float32) + return scaling_factors + + +def load_from_fp8_llama(quant_ckpt_path: str, num_layers: int, mapping: Mapping, + fp8_kv_cache: bool, weight_scales: dict): + """ + Get the fp8 scaling factors. + """ + fake_fp8_sf_dt = torch.float32 + + if quant_ckpt_path is not None and os.path.isfile(quant_ckpt_path): + fp8_llama = np.load(quant_ckpt_path) + else: + fp8_llama = None + tensorrt_llm.logger.info( + "No quantized checkpoint found, using dummy fp8 scaling factors instead." + ) + weights = {} + + def get_fp8_llama(name): + if fp8_llama is not None: + return fp8_llama[name] + else: + return torch.tensor([1.0], dtype=fake_fp8_sf_dt).numpy() + + layers_range = mapping.pp_layers(num_layers) + for l in layers_range: + prefix = f'_np:layers:{l}' + tllm_prex = f'transformer.layers.{l-layers_range[0]}' + + weights[f'{tllm_prex}.attention.qkv.activation_scaling_factor'] = max( + get_fp8_llama( + f'{prefix}:attention:qkv:q:activation_scaling_factor'), + get_fp8_llama( + f'{prefix}:attention:qkv:k:activation_scaling_factor'), + get_fp8_llama( + f'{prefix}:attention:qkv:v:activation_scaling_factor')) + weights[f'{tllm_prex}.attention.qkv.weights_scaling_factor'] = max( + get_fp8_llama(f'{prefix}:attention:qkv:q:weights_scaling_factor'), + get_fp8_llama(f'{prefix}:attention:qkv:k:weights_scaling_factor'), + get_fp8_llama(f'{prefix}:attention:qkv:v:weights_scaling_factor')) + weights[ + f'{tllm_prex}.attention.dense.activation_scaling_factor'] = get_fp8_llama( + f'{prefix}:attention:dense:activation_scaling_factor') + weights[ + f'{tllm_prex}.attention.dense.weights_scaling_factor'] = get_fp8_llama( + f'{prefix}:attention:dense:weights_scaling_factor') + + weights[ + f'{tllm_prex}.mlp.fc.activation_scaling_factor'] = get_fp8_llama( + f'{prefix}:mlp:fc:activation_scaling_factor') + weights[f'{tllm_prex}.mlp.fc.weights_scaling_factor'] = get_fp8_llama( + f'{prefix}:mlp:fc:weights_scaling_factor') + + weights[ + f'{tllm_prex}.mlp.gate.activation_scaling_factor'] = get_fp8_llama( + f'{prefix}:mlp:gate:activation_scaling_factor') + weights[f'{tllm_prex}.mlp.gate.weights_scaling_factor'] = get_fp8_llama( + f'{prefix}:mlp:gate:weights_scaling_factor') + + weights[ + f'{tllm_prex}.mlp.proj.activation_scaling_factor'] = get_fp8_llama( + f'{prefix}:mlp:proj:activation_scaling_factor') + weights[f'{tllm_prex}.mlp.proj.weights_scaling_factor'] = get_fp8_llama( + f'{prefix}:mlp:proj:weights_scaling_factor') + + if fp8_kv_cache: + # Not calibrating KV cache.
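+ # No calibration data is used at this point; a placeholder scaling factor of 1.0 is written for each layer's KV cache, whereas a calibrated checkpoint would supply real per-layer values.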
+ scaling_factor = 1.0 + weights[ + f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.tensor( + [scaling_factor], dtype=fake_fp8_sf_dt).numpy() + if fp8_llama is None: + weights.update(weight_scales) + + return weights + + +def dummy_scaling_facotr_sq(weights): + for name in list(weights): + if any([ + _name in name for _name in [ + 'mlp.proj.weight', 'mlp.gate.weight', 'mlp.fc.weight', + 'attention.qkv.weight', 'attention.dense.weight' + ] + ]): + print("Processing:", name) + weight = weights[name] + out_dim, in_dim = weight.shape + weights_scaling_factor = (np.abs(weight).max(1, keepdims=True) / + 127.) + prequant_scaling_factor = np.ones([in_dim], dtype=weight.dtype) + activation_scaling_factor = np.array([0.1], dtype=np.float32) + int_weight = (weight / weights_scaling_factor).round().astype( + np.int8) + weights[name.replace( + 'weight', 'prequant_scaling_factor')] = prequant_scaling_factor + weights[name.replace( + 'weight', + 'weights_scaling_factor')] = weights_scaling_factor.astype( + np.float32).squeeze(1) + weights[name.replace( + 'weight', + 'activation_scaling_factor')] = activation_scaling_factor + weights[name] = int_weight + return weights + + +def dummy_scaling_facotr_kv_cache(weights): + for name in list(weights): + if 'attention.qkv.weight' in name: + kv_cache_scaling_factor = np.array([0.1], dtype=np.float32) + weights[name.replace( + 'qkv.weight', + 'kv_cache_scaling_factor')] = kv_cache_scaling_factor + + +def dummy_weights_awq(weights, precision, trt_llm_config, group_size): + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + use_fp8_kv_cache = trt_llm_config.quant_mode.has_fp8_kv_cache() + use_int8_kv_cache = trt_llm_config.quant_mode.has_int8_kv_cache() + num_layers = trt_llm_config.num_hidden_layers + for name in list(weights): + if any([ + _name in name for _name in [ + 'mlp.proj.weight', 'mlp.gate.weight', 'mlp.fc.weight', + 'attention.qkv.weight', 'attention.dense.weight' + ] + ]): + print("Processing:", name) + weight = np.ascontiguousarray(weights[name].T) + in_dim, out_dim = weight.shape + scale = np.amax(weight) / 7 + weights_scaling_factor = np.ones([out_dim, in_dim // group_size + ]) * scale.astype(np.float32) + weight_smoothed = (weight.astype(np.float32) / scale).astype( + np.int8) + weight_smoothed[weight_smoothed < -8] = -8 + weight_smoothed[weight_smoothed > 7] = 7 + prequant_scaling_factor = np.ones([in_dim], dtype=weight.dtype) + weights[name] = packer( + torch.from_numpy(weight_smoothed)).T.contiguous().numpy() + weights[name.replace( + 'weight', 'prequant_scaling_factor')] = prequant_scaling_factor + weights[name.replace( + 'weight', + 'weights_scaling_factor')] = weights_scaling_factor.astype( + weight.dtype) + if precision == "w4a8_awq": + alpha = np.array([1], dtype=np.float32) + weights[name.replace('weight', 'alpha')] = alpha + if use_fp8_kv_cache or use_int8_kv_cache: + for l in range(num_layers): + t = np.array([1], dtype=np.float32) + weights[ + f"transformer.layers.{l}.attention.kv_cache_scaling_factor"] = t + + return weights diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index a679ea958..7a5814656 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -992,7 +992,7 @@ def __init__( self.rotary_embedding_base = rotary_embedding_base self.rotary_embedding_dim = 0 if self.position_embedding_type.is_rope(): - self.rotary_embedding_dim = hidden_size // num_attention_heads + self.rotary_embedding_dim = self.attention_head_size self.quant_mode = 
quant_mode self.dtype = dtype diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index a680f1750..813b4cb9e 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -1013,11 +1013,16 @@ def __setup_decoder(self, input_ids: torch.Tensor, if scfg.output_log_probs: self.log_probs = torch.zeros( - (self.max_new_tokens, batch_size, scfg.num_beams), + (batch_size, scfg.num_beams, self.max_seq_length), + dtype=torch.float32, + device=self.device) + self.log_probs_tiled = torch.zeros( + (self.max_seq_length, batch_size, scfg.num_beams), dtype=torch.float32, device=self.device) else: self.log_probs = None + self.log_probs_tiled = None self.finished = torch.zeros((batch_size, scfg.num_beams), dtype=torch.uint8, @@ -2422,7 +2427,7 @@ def handle_per_step( this_src_cache_indirection, self.output_ids, self.new_tokens, self.finished, self.finished, self.sequence_length_buffer, self.cum_log_probs, - self.log_probs, self.parent_ids, + self.log_probs, self.log_probs_tiled, self.parent_ids, this_tgt_cache_indirection, self.beam_hyps_output_ids_tgt, self.beam_hyps_sequence_lengths_tgt, @@ -2527,6 +2532,10 @@ def decode_regular(self, def get_outputs_dict(output_ids): outputs = {} outputs['output_ids'] = output_ids + if scfg.output_log_probs: + outputs['log_probs'] = self.log_probs + if scfg.output_cum_log_probs: + outputs['cum_log_probs'] = self.cum_log_probs if output_sequence_lengths: outputs[ 'sequence_lengths'] = self.sequence_length_buffer.reshape( diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 03877c3cb..d77fff5f0 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -451,7 +451,10 @@ def from_dir(cls, max_medusa_tokens=pretrained_config.max_draft_len if hasattr( pretrained_config, 'max_draft_len') else 0, num_medusa_heads=pretrained_config.num_medusa_heads if hasattr( - pretrained_config, 'num_medusa_heads') else 0) + pretrained_config, 'num_medusa_heads') else 0, + use_custom_all_reduce=build_config.plugin_config. 
+ use_custom_all_reduce, + ) max_batch_size = build_config.max_batch_size max_input_len = build_config.max_input_len max_output_len = build_config.max_output_len From d5a651d5667afbe54825ff8c9b5df4e921ad11d1 Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Wed, 21 Feb 2024 12:52:27 +0000 Subject: [PATCH 2/2] update --- .../aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a | 3 +++ .../libtensorrt_llm_batch_manager_static.pre_cxx11.a | 3 +++ cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt | 3 +++ .../aarch64-linux-gnu/libtensorrt_llm_executor_static.a | 3 +++ .../libtensorrt_llm_executor_static.pre_cxx11.a | 3 +++ cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt | 3 +++ tensorrt_llm/version.py | 2 +- 7 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a create mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a create mode 100644 cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt create mode 100644 cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a create mode 100644 cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a create mode 100644 cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a new file mode 100644 index 000000000..87e4bfec2 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ba61c04ed7623fc44b5364802c1893fa824467455f4a9fe8245d5d51fef97e6 +size 2172266 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a new file mode 100644 index 000000000..ad0b8b08a --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf4afdfd281029c8e4bf0af548529b94a4a6d0f9bb5148ae10423e5e0275db06 +size 2195822 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt new file mode 100644 index 000000000..2b35efd9f --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -0,0 +1,3 @@ +4c405d39a0cbb93d44a5758480a1a223 libtensorrt_llm_batch_manager_static.a +68aea75a2ed5b219eec5a0f77ce33482 libtensorrt_llm_batch_manager_static.pre_cxx11.a +9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a new file mode 100644 index 000000000..7cba3bc7e --- /dev/null +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49c84a22cee9e6c3a975db08d8d0d8dbe88867e2eb4fc12a4b3ff6c1c90e8c21 +size 586202 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a 
b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a new file mode 100644 index 000000000..d979e1584 --- /dev/null +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fea93ae7d09e74b073a65d5d0ac34aec9ccc8f8299af1abd6826e97e9c8427f4 +size 589652 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt new file mode 100644 index 000000000..422368df6 --- /dev/null +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -0,0 +1,3 @@ +73999f4c2b3a4328db454b7ab6fe86d3 libtensorrt_llm_executor_static.a +df53aa83848b5ed75550a7b536ca02a4 libtensorrt_llm_executor_static.pre_cxx11.a +9b63c754d2a1edc7a17106e83c3e131d312f0a80 commit diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index d76cf5f0f..a8ad32389 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.9.0.dev2024020600" +__version__ = "0.9.0.dev2024022000"