diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index e7835cd8d..07232e0be 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -93,16 +93,55 @@ cd cpp/build ./benchmarks/gptManagerBenchmark --help ``` -Take GPT-350M as an example for 2-GPU inflight batching -``` -mpirun -n 2 ./benchmarks/gptManagerBenchmark \ - --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \ - --request_rate 10 \ - --dataset ../../benchmarks/cpp/preprocessed_dataset.json - --max_num_samples 500 -``` +`gptManagerBenchmark` now supports decoder-only models and encoder-decoder models. + +1. Decoder-only Models + + To benchmark decoder-only models, pass in the engine path with `--engine_dir` as an executable input argument. + + Take GPT-350M as an example for 2-GPU inflight batching + ``` + mpirun -n 2 ./benchmarks/gptManagerBenchmark \ + --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \ + --request_rate 10 \ + --dataset ../../benchmarks/cpp/preprocessed_dataset.json \ + --max_num_samples 500 + ``` + + `gptManagerBenchmark` by default uses the high-level C++ API defined by the `executor::Executor` class (see `cpp/include/tensorrt_llm/executor/executor.h`). + +2. Encoder-Decoder Models + To benchmark encoder-decoder models, pass in the encoder engine path with `--encoder_engine_dir` and the decoder engine path with `--decoder_engine_dir` as executable input arguments. `--decoder_engine_dir` is an alias of `--engine_dir`. + + Currently encoder-decoder engines only support `--api executor`, `--type IFB`, `--enable_kv_cache_reuse false`, which are all default values, so no specific settings are required. + + Prepare the t5-small engine from [examples/enc_dec](/examples/enc_dec/README.md#convert-and-split-weights) for the encoder-decoder 4-GPU inflight batching example. + + Prepare a dataset suitable for the engine input lengths. + ``` + python prepare_dataset.py \ + --tokenizer <path/to/tokenizer> \ + --output cnn_dailymail.json \ + dataset \ + --dataset-name cnn_dailymail \ + --dataset-split validation \ + --dataset-config-name 3.0.0 \ + --dataset-input-key article \ + --dataset-prompt "Summarize the following article:" \ + --dataset-output-key "highlights" \ + --num-requests 100 \ + --max-input-len 512 \ + --output-len-dist 128,20 + ``` + + Run the benchmark + ``` + mpirun --allow-run-as-root -np 4 ./benchmarks/gptManagerBenchmark \ + --encoder_engine_dir ../../examples/enc_dec/tmp/trt_engines/t5-small-4gpu/bfloat16/encoder \ + --decoder_engine_dir ../../examples/enc_dec/tmp/trt_engines/t5-small-4gpu/bfloat16/decoder \ + --dataset cnn_dailymail.json + ``` -`gptManagerBenchmark` by default uses the high-level C++ API defined by the `executor::Executor` class (see `cpp/include/tensorrt_llm/executor/executor.h`).
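For readers who want to drive the same engines programmatically, here is a minimal sketch (not part of this patch) of how the two `texec::Executor` constructors exercised later in this diff are selected from the supplied engine directories. The helper name `makeExecutor`, the default-constructed config, and the overall wrapper are illustrative assumptions; the constructor overloads and `texec::ModelType` values are taken from the `gptManagerBenchmark.cpp` changes below.

```
// Sketch only: pick the executor::Executor constructor that matches the engines given.
// Mirrors the ExecutorServer logic in gptManagerBenchmark.cpp; makeExecutor is a made-up helper.
#include <filesystem>
#include <memory>
#include <optional>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

std::unique_ptr<texec::Executor> makeExecutor(std::optional<std::filesystem::path> const& decoderDir,
    std::optional<std::filesystem::path> const& encoderDir, texec::ExecutorConfig const& config)
{
    if (encoderDir && decoderDir)
    {
        // --encoder_engine_dir together with --decoder_engine_dir (alias of --engine_dir)
        return std::make_unique<texec::Executor>(
            encoderDir.value(), decoderDir.value(), texec::ModelType::kENCODER_DECODER, config);
    }
    if (decoderDir)
    {
        // --engine_dir only: decoder-only model
        return std::make_unique<texec::Executor>(decoderDir.value(), texec::ModelType::kDECODER_ONLY, config);
    }
    // --encoder_engine_dir only: encoder-only model (assumes at least one directory was provided)
    return std::make_unique<texec::Executor>(encoderDir.value(), texec::ModelType::kENCODER_ONLY, config);
}
```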
#### Emulated static batching @@ -125,7 +164,7 @@ Take GPT-350M as an example for single GPU with static batching ``` ./benchmarks/gptManagerBenchmark \ --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \ - --request-rate -1 \ + --request_rate -1 \ --static_emulated_batch_size 32 \ --static_emulated_timeout 100 \ --dataset ../../benchmarks/cpp/tokens-fixed-lengths.json @@ -221,7 +260,7 @@ python benchmarks/cpp/utils/generate_rand_loras.py ${CPP_LORA} ${EG_DIR}/loras 1 # First run inference without LoRAs mkdir -p ${EG_DIR}/log-base-lora mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \ - cpp/build_Debug/benchmarks/gptManagerBenchmark \ + cpp/build/benchmarks/gptManagerBenchmark \ --engine_dir $LORA_ENGINE \ --type IFB \ --dataset "${EG_DIR}/data/token-norm-dist.json" \ @@ -239,7 +278,7 @@ mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \ for nloras in ${NUM_LORAS[@]}; do mkdir -p ${EG_DIR}/log-lora-${nloras} mpirun -n ${TP} --output-filename "${EG_DIR}/log-lora-${nloras}" \ - cpp/build_Debug/benchmarks/gptManagerBenchmark \ + cpp/build/benchmarks/gptManagerBenchmark \ --engine_dir $LORA_ENGINE \ --type IFB \ --dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \ diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index 71e83e127..645a3f238 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -155,6 +155,7 @@ struct BenchmarkParams std::optional maxNumTokens{std::nullopt}; int randomSeed = 430; std::optional maxAttentionWindow{std::nullopt}; + std::optional sinkTokenLength{std::nullopt}; bool multiBlockMode{false}; // lora / peft params @@ -784,10 +785,12 @@ class Recorder class ExecutorServer { public: - ExecutorServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth, - texec::CapacitySchedulerPolicy capacitySchedulerPolicy, BenchmarkParams const& benchmarkParams, - std::shared_ptr recorder, std::chrono::milliseconds waitSleep, - std::optional const staticEmulatedBatchSize, bool logIterationData) + ExecutorServer(std::optional const& decoderTrtEnginePath, + std::optional const& encoderTrtEnginePath, TrtGptModelType modelType, + int32_t maxBeamWidth, texec::CapacitySchedulerPolicy capacitySchedulerPolicy, + BenchmarkParams const& benchmarkParams, std::shared_ptr recorder, std::chrono::milliseconds waitSleep, + std::optional const staticEmulatedBatchSize, bool logIterationData, + texec::ModelType executorModelType) : mRecorder(std::move(recorder)) , mWaitSleep(waitSleep) , mStaticEmulatedBatchSize(staticEmulatedBatchSize) @@ -799,7 +802,7 @@ class ExecutorServer texec::SchedulerConfig schedulerConfig(capacitySchedulerPolicy); texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache, - benchmarkParams.maxAttentionWindow, std::nullopt, benchmarkParams.freeGpuMemoryFraction, + benchmarkParams.maxAttentionWindow, benchmarkParams.sinkTokenLength, benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks); texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8, std::nullopt, benchmarkParams.loraHostCacheSize); @@ -823,7 +826,25 @@ class ExecutorServer std::nullopt, benchmarkParams.medusaChoices)); executorConfig.setMultiBlockMode(benchmarkParams.multiBlockMode); - mExecutor = std::make_unique(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig); + if (executorModelType == 
texec::ModelType::kDECODER_ONLY) + { + mExecutor + = std::make_unique(decoderTrtEnginePath.value(), executorModelType, executorConfig); + } + else if (executorModelType == texec::ModelType::kENCODER_DECODER) + { + mExecutor = std::make_unique( + encoderTrtEnginePath.value(), decoderTrtEnginePath.value(), executorModelType, executorConfig); + } + else if (executorModelType == texec::ModelType::kENCODER_ONLY) + { + mExecutor + = std::make_unique(encoderTrtEnginePath.value(), executorModelType, executorConfig); + } + else + { + TLLM_LOG_ERROR("not a supported executor model type in executor server."); + } if (logIterationData) { @@ -1347,7 +1368,8 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamWidth, std::optional const& eosId, std::optional const& padId, bool streaming = false, bool const& returnContextLogits = false, bool const& returnGenerationLogits = false, - std::optional const& loraConfig = std::nullopt) + std::optional const& loraConfig = std::nullopt, + std::optional encoderInputTokenIds = std::nullopt) { auto samplingConfig = texec::SamplingConfig{beamWidth}; auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false}; @@ -1357,7 +1379,9 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW std::nullopt, // embeddingBias std::nullopt, // speculativeDecoding std::nullopt, // pTuning - loraConfig); + loraConfig, + std::nullopt, // logitsPostProcessorName + encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt); } void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType, @@ -1383,6 +1407,10 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType { optionalParams.kvCacheConfig.maxAttentionWindow = benchmarkParams.maxAttentionWindow; } + if (benchmarkParams.sinkTokenLength) + { + optionalParams.kvCacheConfig.sinkTokenLength = benchmarkParams.sinkTokenLength; + } optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse; optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext; optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap; @@ -1526,12 +1554,13 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType gptServer->waitBatchManager(); } -void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType modelType, +void benchmarkExecutor(std::optional const& decoderEngineDir, + std::optional const& encoderEngineDir, TrtGptModelType modelType, std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp, std::optional const& eosId, std::optional const& padId, BenchmarkParams const& benchmarkParams, texec::CapacitySchedulerPolicy capacitySchedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits, std::optional const staticEmulatedBatchSize, - bool logIterationData, std::optional const maxPromptLen) + bool logIterationData, std::optional const maxPromptLen, texec::ModelType executorModelType) { auto const& world = tensorrt_llm::mpi::MpiComm::world(); auto worldRank = world.getRank(); @@ -1541,9 +1570,57 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m auto const numSamples = samples.size(); auto recorder = std::make_shared(opCsvFile, benchmarkParams.streaming, beamWidth); + int32_t decoderStartTokenId = 0; + std::shared_ptr 
executorServer; - auto executorServer = std::make_shared(engineDir, modelType, beamWidth, capacitySchedulerPolicy, - benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData); + if (executorModelType == texec::ModelType::kDECODER_ONLY) + { + TLLM_CHECK_WITH_INFO( + decoderEngineDir.has_value(), "decoder models require a path to decoder engine in executor benchmark."); + executorServer = std::make_shared(decoderEngineDir.value(), std::nullopt, modelType, beamWidth, + capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData, + executorModelType); + } + else if (executorModelType == texec::ModelType::kENCODER_DECODER) + { + TLLM_CHECK_WITH_INFO(encoderEngineDir.has_value(), + "encoder-decoder models require a path to encoder engine in executor benchmark."); + executorServer = std::make_shared(decoderEngineDir.value(), encoderEngineDir.value(), modelType, + beamWidth, capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, + logIterationData, executorModelType); + try + { + std::ifstream decoderJsonConfigPath(decoderEngineDir.value() / "config.json"); + auto const decoderPretrainedConfig + = nlohmann::json::parse(decoderJsonConfigPath, nullptr, true, true).at("pretrained_config"); + decoderStartTokenId = decoderPretrainedConfig.at("decoder_start_token_id").template get(); + } + catch (nlohmann::json::out_of_range& e) + { + TLLM_LOG_ERROR( + "Parameter %s cannot be read from decoder config.json in pretrained_config. Using default id %d.", + std::string("decoder_start_token_id").c_str(), decoderStartTokenId); + } + catch (nlohmann::json::type_error const& e) + { + TLLM_LOG_ERROR( + "Parameter %s has error type in decoder config.json in pretrained_config. 
Using default id %d.", + std::string("decoder_start_token_id").c_str(), decoderStartTokenId); + } + } + else if (executorModelType == texec::ModelType::kENCODER_ONLY) + { + TLLM_CHECK_WITH_INFO( + encoderEngineDir.has_value(), "encoder models require a path to encoder engine in executor benchmark."); + executorServer = std::make_shared(std::nullopt, encoderEngineDir.value(), modelType, beamWidth, + capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData, + executorModelType); + } + else + { + TLLM_LOG_ERROR("not a supported executor model type in executor benchmark."); + return; + } if (worldRank == 0) { @@ -1559,8 +1636,18 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m p.second->squeeze(0); texec::LoraConfig loraConfig( taskId, texec::detail::ofITensor(p.first), texec::detail::ofITensor(p.second)); - Sample s{std::vector{1, 2, 3, 4, 5}, 1, static_cast(taskId)}; - requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig)); + if (executorModelType == texec::ModelType::kENCODER_DECODER) + { + Sample s{std::vector{decoderStartTokenId}, 1, static_cast(taskId)}; + requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, + loraConfig, std::vector{1, 2, 3, 4, 5})); + } + else + { + Sample s{std::vector{1, 2, 3, 4, 5}, 1, static_cast(taskId)}; + requests.emplace_back( + makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig)); + } } executorServer->enqueue(std::move(requests), true); executorServer->waitForResponses(loras.getLoras().size(), true); @@ -1573,8 +1660,17 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m std::vector requests; for (auto i = 0; i < warmUp; ++i) { - requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId, - benchmarkParams.streaming, returnContextLogits, returnGenerationLogits)); + if (executorModelType == texec::ModelType::kENCODER_DECODER) + { + Sample s{std::vector{decoderStartTokenId}, samples[0].outputLen, samples[0].taskId}; + requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming, + returnContextLogits, returnGenerationLogits, std::nullopt, samples[0].inputIds)); + } + else + { + requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId, + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits)); + } } executorServer->enqueue(std::move(requests), true); executorServer->waitForResponses(warmUp, true); @@ -1595,8 +1691,17 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m { loraConfig = texec::LoraConfig(samples[i].taskId); } - requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId, - benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig)); + if (executorModelType == texec::ModelType::kENCODER_DECODER) + { + Sample s{std::vector{decoderStartTokenId}, samples[i].outputLen, samples[i].taskId}; + requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming, + returnContextLogits, returnGenerationLogits, loraConfig, samples[i].inputIds)); + } + else + { + requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId, + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig)); + } } bool const hasDelay @@ -1687,7 +1792,8 @@ int main(int argc, char* argv[]) cxxopts::Options options( 
"TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models."); options.add_options()("h,help", "Print usage"); - options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value()); + options.add_options()("engine_dir, decoder_engine_dir", "Directory that store the engines of decoder models.", + cxxopts::value()); options.add_options()( "api", "API type: gptManager or executor.", cxxopts::value()->default_value("executor")); options.add_options()("type", "Batching type: IFB, UIFB (unfused IFB) or V1 (non-IFB) batching.", @@ -1707,6 +1813,7 @@ int main(int argc, char* argv[]) options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value()); options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value()); options.add_options()("max_attention_window", "Max KV cache length per sequence", cxxopts::value()); + options.add_options()("sink_token_len", "Sink token length in kv cache per sequence.", cxxopts::value()); options.add_options()( "random_seed", "integer random seed for exponential time delays.", cxxopts::value()->default_value("420")); options.add_options()( @@ -1781,6 +1888,8 @@ int main(int argc, char* argv[]) options.add_options()("multi_block_mode", "Distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel", cxxopts::value()->default_value("false")); + options.add_options()( + "encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value()); auto result = options.parse(argc, argv); @@ -1791,7 +1900,7 @@ int main(int argc, char* argv[]) } // Argument: Engine directory - if (!result.count("engine_dir")) + if (!result.count("engine_dir") && !result.count("encoder_engine_dir")) { std::cout << options.help() << std::endl; TLLM_LOG_ERROR("Please specify engine directory."); @@ -1848,6 +1957,12 @@ int main(int argc, char* argv[]) benchmarkParams.maxAttentionWindow = result["max_attention_window"].as(); } + // Argument: Sink token length + if (result.count("sink_token_len")) + { + benchmarkParams.sinkTokenLength = result["sink_token_len"].as(); + } + if (result.count("random_seed")) { benchmarkParams.randomSeed = result["random_seed"].as(); @@ -2049,12 +2164,33 @@ int main(int argc, char* argv[]) } else if (api == "executor") { + texec::ModelType executorModelType; + std::optional decoderEngineDir = std::nullopt, encoderEngineDir = std::nullopt; + if (result.count("encoder_engine_dir") && result.count("engine_dir")) + { + TLLM_CHECK_WITH_INFO(api == "executor", "encoder-decoder only support executor api."); + TLLM_CHECK_WITH_INFO( + modelType == TrtGptModelType::InflightFusedBatching, "encoder-decoder only support inflight batching."); + executorModelType = texec::ModelType::kENCODER_DECODER; + decoderEngineDir = result["engine_dir"].as(); + encoderEngineDir = result["encoder_engine_dir"].as(); + } + else if (result.count("engine_dir")) + { + executorModelType = texec::ModelType::kDECODER_ONLY; + decoderEngineDir = result["engine_dir"].as(); + } + else + { + executorModelType = texec::ModelType::kENCODER_ONLY; + encoderEngineDir = result["encoder_engine_dir"].as(); + } try { - benchmarkExecutor(result["engine_dir"].as(), modelType, datasetPath, opCsvFile, maxNumSamples, + benchmarkExecutor(decoderEngineDir, encoderEngineDir, modelType, datasetPath, opCsvFile, maxNumSamples, beamWidth, result["warm_up"].as(), eosId, padId, benchmarkParams, capacitySchedulerPolicy, waitSleep, returnContextLogits, 
returnGenerationLogits, staticEmulatedBatchSize, logIterationData, - maxPromptLen); + maxPromptLen, executorModelType); } catch (std::exception const& e) { diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py b/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py index d7302098f..4dd6797f3 100644 --- a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py +++ b/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py @@ -16,7 +16,7 @@ "meta-llama/Llama-2-70b-hf": "rope_gpt_neox", "meta-llama/Meta-Llama-3-8B": "rope_gpt_neox", "meta-llama/Meta-Llama-3-70B": "rope_gpt_neox", - "EleutherAI/gpt-j-6b": "rope_gptj", + "gpt-j-6b": "rope_gptj", "bigscience/bloom-560m": "alibi", "mistralai/Mistral-7B-v0.1": "rope_gpt_neox", "mistralai/Mixtral-8x7B-v0.1": "rope_gpt_neox", @@ -126,7 +126,7 @@ class TRTLLMConfig(BaseModel): moe_top_k: Optional[int] = Field( default=0, validation_alias=AliasChoices("num_experts_per_tok")) rotary_base: Optional[float] = Field( - default=None, validation_alias=AliasChoices("rope_theta")) + default=10000.0, validation_alias=AliasChoices("rope_theta")) mapping: TRTLLM_Mapping quantization: TRTLLM_Quantization @@ -176,8 +176,9 @@ def populate_build_config(cls, "quant_algo": quant_dtype, "kv_cache_quant_algo": kv_cache_quant_dtype, } - if model_name in PET_dict: - build_config["position_embedding_type"] = PET_dict.get(model_name) + for name, pet in PET_dict.items(): + if name in str(model_name): + build_config["position_embedding_type"] = pet return build_config @classmethod diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index b83e2b741..0249dc982 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -718,16 +718,9 @@ class GenericLlmRequest mReturnGenerationLogits = returnGenerationLogits; } - // Return all generation logits for model w/o draft token [[nodiscard]] bool getReturnGenerationLogits() const { - return mReturnGenerationLogits && (getNumDraftTokens() == 0); - } - - // Return accepted tokens logits for target model - [[nodiscard]] bool getReturnTargetModelAcceptedLogits() const - { - return mReturnGenerationLogits && (getNumDraftTokens() > 0); + return mReturnGenerationLogits; } [[nodiscard]] TensorPtr const& getContextLogitsHost() const @@ -735,6 +728,7 @@ class GenericLlmRequest return mContextLogitsHost; } + /// @param contextLogitsHost Expected shape [promtLen, vocabSizePadded] void setContextLogitsHost(TensorPtr contextLogitsHost) { mContextLogitsHost = std::move(contextLogitsHost); @@ -751,6 +745,9 @@ class GenericLlmRequest return mGenerationLogitsHost; } + /// @param generationLogitsHost Expected shape + /// * [beamWidth, maxNewTokens, vocabSizePadded] for non-speculative decoding + /// * [1, numDraftTokens + 1, vocabSizePadded] for speculative decoding void setGenerationLogitsHost(TensorPtr generationLogitsHost) { mGenerationLogitsHost = std::move(generationLogitsHost); @@ -765,7 +762,7 @@ class GenericLlmRequest void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType) { mGenerationLogitsHost = runtime::BufferManager::pinnedPool( - runtime::ITensor::makeShape({getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType); + runtime::ITensor::makeShape({1, getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType); } [[nodiscard]] std::vector const& getGenerationLogitsFragments() const @@ -966,18 +963,6 @@ class GenericLlmRequest 
result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost()); } - if (getReturnTargetModelAcceptedLogits()) - { - auto targetModelAcceptedTokenLogitsShape = getGenerationLogitsHost()->getShape(); - TLLM_CHECK(targetModelAcceptedTokenLogitsShape.nbDims == 2); - auto numAcceptedToken = targetModelAcceptedTokenLogitsShape.d[0]; - auto vocabSizePadded = targetModelAcceptedTokenLogitsShape.d[1]; - // Align the shape of accepted token logits and generation logits - TensorPtr targetModelAcceptedTokenLogitsHostView = runtime::ITensor::view( - getGenerationLogitsHost(), runtime::ITensor::makeShape({1, numAcceptedToken, vocabSizePadded})); - result.generationLogits = executor::detail::ofITensor(targetModelAcceptedTokenLogitsHostView); - } - if (getReturnEncoderOutput()) { result.encoderOutput = executor::detail::ofITensor(getEncoderOutputHost()); diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 49bac7b02..923e150a8 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -349,7 +349,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture #ifndef ENABLE_FP8 static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version"); #endif - bool should_skip_unsupported_fp8 = getSMVersion() < 90 && FP8; + bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8; return should_skip_unsupported_fp8; } diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 3841fb940..6b1d112ae 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7eec52cb658f033cf3146017cbaa3ea1554942ee7ece49329ddf7b01361fa080 -size 4293100 +oid sha256:96130d5dc94a0373331ce58d169e5769d90fea8c103c4ee800cb7ac84b5a901d +size 4285916 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 18938e885..c2b3bd530 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cf65778d6469a5a85bf2191fb104094aa4e606b370a25475a16017329e27fd95 -size 4395148 +oid sha256:1e1256cca1ddde6dbf716906494d5371aa63dc144ebe8f6e868169ed771405ec +size 4387206 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index d8baf2cde..82cd69565 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -08d59f31da00044ae21995c6573a55da libtensorrt_llm_batch_manager_static.a -abdb9b58e0a4587d2d2ce6bc83655f8a libtensorrt_llm_batch_manager_static.pre_cxx11.a -315e9f5ccd286e906d4c0d402fefbf2f69a1febe commit \ No newline at end of file +ea8bb6e3a155175a0dcfc0e87d1e7f25 libtensorrt_llm_batch_manager_static.a +42ebee6d5349709e33ab36a82f6fef4d 
libtensorrt_llm_batch_manager_static.pre_cxx11.a +8baf57c648b66a48dbe29f766c6fdff505045f24 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 10cbcad26..3141a4bd5 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e339bca2212b46c6227b328fc376db4628a0a96636b5f2b5b3ae387e884b7f01 -size 4155892 +oid sha256:32f765a854cd199462ea24be09a08c8f3596c7f96f8c7f137b8137712909369c +size 4146754 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 8f5ebab37..04764bee4 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7503446c4ef7b959970fc02b33ca81dd0dece0663d9a0f8b881c60ff66006000 -size 4136818 +oid sha256:f5fc2638623716f1067b1ba7cd8fabc651bc4fec40a716625739f819ebfb8270 +size 4128508 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib index 11c9d75bc..e1f890302 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib +++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:51174b20ed939662c92d21cdd5a0fd652a6592947270182ff026eb3a4153e4cf -size 24015602 +oid sha256:b7217580ab669c0f41d94acf7b9a7e27f99b18db5f660643572e9cb98511a588 +size 23991606 diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index 3db9bf532..3f8347027 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -21,6 +21,7 @@ #include "cutlass/bfloat16.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" #include "cutlass/layout/matrix.h" #include "cutlass_extensions/arch/mma.h" @@ -138,11 +139,7 @@ struct MixedGemmArchTraits struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value -#ifdef ENABLE_FP8 - || cutlass::platform::is_same::value -#endif - >::type> + || cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -162,6 +159,32 @@ struct MixedGemmArchTraits +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = 
__nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index a1712431e..8ac984faf 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -77,6 +77,23 @@ struct LayoutDetailsB +struct LayoutDetailsB +{ + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + // Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, // which signals that we want to dequantize after loading from smem. template diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index 08c373870..557700705 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -446,8 +446,6 @@ struct MoeFCGemm // Epilogue // - EpilogueOutputOp output_op(params.output_op); - ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; @@ -468,7 +466,20 @@ struct MoeFCGemm Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); // Execute the epilogue operator to update the destination tensor. 
- epilogue(output_op, iterator_D, accumulators, iterator_C); + if constexpr (platform::is_same>::value) + { + EpilogueOutputOp output_op(params.output_op, problem_idx); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + else + { + EpilogueOutputOp output_op(params.output_op); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } // Next tile problem_visitor.advance(gridDim.x); @@ -501,8 +512,19 @@ struct MoeFCGemm run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900) + constexpr bool isFp8 = platform::is_same::value + || platform::is_same::value; + if constexpr (isFp8) + { + run_kernel(params, shared_storage); + } + else + { // reuse sm80 kernel for other types, align with dispatchToArch + run_kernel(params, shared_storage); + } #elif (__CUDA_ARCH__ >= 900) run_kernel(params, shared_storage); #else diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index 94b457d16..2fd01d2a9 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -60,7 +60,11 @@ enum class CutlassTileConfig CtaShape256x128x64_WarpShape64x64x64, // TensorCore config CTA_N = 256, CTA_K = 64 - CtaShape16x256x64_WarpShape16x64x64 + CtaShape16x256x64_WarpShape16x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 + }; enum class SplitKStyle @@ -129,6 +133,7 @@ struct CutlassGemmConfig INT8_ONLY = 1u << 2, HOPPER = 1u << 3, GROUPED_GEMM = 1u << 4, + FP8_ONLY = 1u << 5, }; CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a index 85f8cbe8b..88170f4a2 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:19fdeb78169c29492026b62bf147481e2b0d893916d9a20333d83fb61c0abe36 -size 1428026 +oid sha256:bb26beeee542e271d25f23c6942d9769f75cd6fc0d011a80920f246218189282 +size 1422872 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 8f0790654..44d7e32bd 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1d7f36c49f24730e4038c2252b966870789d9c9cff698ccd50d0f61ae85fcc9d -size 1455538 +oid sha256:0df6e73d5a2d81b82b35721347791f7493288f83c05894cfae01f84924a65524 +size 1450382 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt index bc69891ba..a822ad769 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt +++ 
b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -5bdad7b823b79b1b91439693aa25cff5 libtensorrt_llm_executor_static.a -566734842bb731319971850583fdc9c7 libtensorrt_llm_executor_static.pre_cxx11.a -315e9f5ccd286e906d4c0d402fefbf2f69a1febe commit \ No newline at end of file +c5aa4f49d7673d8115ffc35803304698 libtensorrt_llm_executor_static.a +3109c235a62431315f72ab8fc652f3b0 libtensorrt_llm_executor_static.pre_cxx11.a +8baf57c648b66a48dbe29f766c6fdff505045f24 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a index 3e5a2132d..d3250ba94 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:58e3e6d7414ab730ba54c8aabdc5f193787b44699e1289279428087cbb2e46d4 -size 1478178 +oid sha256:6371e085289077f9b70b9a4fe964f4265dd0f78c7f1921e0467f7b1a3d610f7e +size 1475170 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index be1049abe..221c1b4f1 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f6598d6c2dafd9b97edfeb8fc424607374e8791c4e334cfaaf5cae865da15c6 -size 1410466 +oid sha256:88f31eeb0486e8d6918d37b124a1684b97990c812fe3d00686dd4b388d4c9d02 +size 1405742 diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib index 1051e2be6..d232b0fc2 100644 --- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib +++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:93e0c81a8d00db0e860cdfdafbae7391e0d2956c2301da1f22ef6419bcb4e02f -size 14321264 +oid sha256:dd532185af010b1972175ba5c1efa5ec5f336b5c0e03c312fa02f01dfb37f1c7 +size 14313722 diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu b/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu index d6e054ab1..5b230443b 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu @@ -1117,9 +1117,14 @@ void AllReduceDispatchType(AllReduceParams& params, AllReduceStrategyType strat, } } -AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value) +AllReduceParams AllReduceParams::deserialize(int64_t* buffer, size_t tpSize, size_t tpRank) { void* const* buffer_ptrs = reinterpret_cast(buffer); + auto const flag_ptr = &buffer[4 * tpSize]; + // cannot use 0 since 0 represents released state for barrier + *flag_ptr += 1; + TLLM_LOG_TRACE("AllReduceParams's flag value is %d", *flag_ptr); + uint32_t flag_value = *flag_ptr; AllReduceParams params; // Even plugins use ping buffers, odd plugins use pong. 
// That way, we don't need to wait for other GPUs to be done diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h index 652397764..ebe6b8795 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h @@ -91,7 +91,7 @@ struct AllReduceParams AllReduceFusionParams fusion_params; - static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value); + static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank); }; bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 0b8fd224c..54260b2a1 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -66,6 +66,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128}; case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileShape{16, 256}; default: TLLM_THROW("[get_grid_shape_for_config] Invalid config"); } } @@ -118,7 +119,8 @@ std::vector get_candidate_tiles( Default, WeightOnly, Simt, - Int8 + Int8, + Fp8 }; CutlassGemmType gemm_type = CutlassGemmType::Default; @@ -134,6 +136,10 @@ std::vector get_candidate_tiles( { gemm_type = CutlassGemmType::Int8; } + else if (config_type_param & CutlassGemmConfig::FP8_ONLY) + { + gemm_type = CutlassGemmType::Fp8; + } std::vector base_configs{ CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; @@ -166,6 +172,25 @@ std::vector get_candidate_tiles( CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::Fp8: + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) + { + if (sm == 89) + { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } + else + { + // no valid ampere style fp8 configs for sm90 + return {}; + } + } default: return base_configs; } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 18b746e72..f1b7bb0f7 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -548,6 +548,7 @@ template CutlassFpAIntBGemmRunner::getConfigs() const { + static constexpr bool is_weight_only = !std::is_same::value; tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = 
tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h index 48c5ffdf8..ab9aad716 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -157,14 +157,19 @@ class MoeGemmRunner public: MoeGemmRunner(); - void moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig chosen_conf); + using HopperGemmOutputType = typename HopperGroupedGemmInput::OutputTypeAdaptor_t; + static constexpr bool use_fp8 = std::is_same::value || std::is_same::value; - void moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t const* total_rows_before_expert, + void moeGemmBiasAct(T const* A, WeightType const* B, HopperGemmOutputType const* weight_scales, + HopperGemmOutputType const* biases, HopperGemmOutputType* C, int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - bool use_fused_moe, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); + ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + + void moeGemm(T const* A, WeightType const* B, HopperGemmOutputType const* weight_scales, HopperGemmOutputType* C, + int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); std::vector getConfigs() const; static std::vector getConfigs(int sm); @@ -183,15 +188,17 @@ class MoeGemmRunner private: template - void dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, + void dispatchToArch(T const* A, WeightType const* B, HopperGemmOutputType const* weight_scales, + HopperGemmOutputType const* biases, HopperGemmOutputType* C, int64_t const* total_rows_before_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr); template - void runGemm(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream, + void runGemm(T const* A, WeightType const* B, HopperGemmOutputType const* weight_scales, + HopperGemmOutputType const* biases, HopperGemmOutputType* C, int64_t const* total_rows_before_expert, + HopperGroupedGemmInput layout_info, int64_t 
total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); private: diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h index 8f1650261..1657ca911 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -70,14 +70,20 @@ namespace kernels::cutlass_kernels { // ============================= Variable batched Gemm things =========================== -template -void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, - cudaStream_t stream, int* kernel_occupancy = nullptr) +template +void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t num_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int const multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* kernel_occupancy = nullptr) { -#ifdef ENABLE_BF16 +#if defined(ENABLE_FP8) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for fp8, bfloat16, half, float"); +#elif defined(ENABLE_BF16) static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, "Specialized for bfloat16, half, float"); @@ -96,6 +102,7 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. using ElementType = typename TllmToCutlassTypeAdapter::type; + using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; if (!use_fused_moe) { @@ -104,15 +111,28 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; using ElementAccumulator = typename MixedGemmArchTraits::AccType; - using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + + if constexpr ((std::is_same_v + || std::is_same_v) &&std::is_same_v) + { + TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array, + "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 " + "Ada"); + epilogue_op.alpha_ptr_array = alpha_scale_ptr_array; + } + // Finally, set up the kernel. 
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; @@ -133,14 +153,12 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); int const threadblock_count = multi_processor_count * occupancy; - typename EpilogueOp::Params epilogue_op( - ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - int const group_size = gemm_k; typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, reinterpret_cast(A), reinterpret_cast(B), - reinterpret_cast(weight_scales), reinterpret_cast(biases), - reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); + reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), + total_rows_before_expert, gemm_n, gemm_k); GemmGrouped gemm; @@ -173,20 +191,24 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig } // namespace kernels::cutlass_kernels -template -static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - cudaStream_t stream, int* occupancy = nullptr) +template +static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases, + GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) { + static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); constexpr bool isFp8 = std::is_same_v || std::is_same_v; - if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) && !isFp8) + + if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) + && (!isFp8 || std::is_same_v) ) { - kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); } else { @@ -195,29 +217,30 @@ static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T } } -template -void dispatchGemmConfig(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - cudaStream_t stream, int* occupancy = nullptr) +template +void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t num_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float 
const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.stages) { case 2: - dispatch(A, B, weight_scales, biases, C, - total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, stream, occupancy); + dispatch(A, B, weight_scales, + biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); break; case 3: - dispatch(A, B, weight_scales, biases, C, - total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, stream, occupancy); + dispatch(A, B, weight_scales, + biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); break; case 4: - dispatch(A, B, weight_scales, biases, C, - total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, stream, occupancy); + dispatch(A, B, weight_scales, + biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); break; default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; } @@ -225,12 +248,14 @@ void dispatchGemmConfig(T const* A, WeightType const* B, T const* weight_scales, // This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. // This overload is only enabled when T == WeightType. -template ::value && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - cudaStream_t stream, int* occupancy = nullptr) +template ::value && !std::is_same::value + && !std::is_same::value && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { @@ -238,36 +263,39 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); if constexpr (arch::kMinComputeCapability >= 75) { - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, - occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, + alpha_scale_ptr_array, stream, occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); if constexpr (arch::kMinComputeCapability >= 75) { - 
dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, - occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, + alpha_scale_ptr_array, stream, occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: @@ -280,12 +308,13 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s // Tensorop GEMM overload // Overload for quantize MoE GEMMs. 
We disable some warp configs here since they will not be used and we can improve // compile time -template ::value && !std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - cudaStream_t stream, int* occupancy = nullptr) +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { @@ -293,36 +322,39 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); if constexpr (arch::kMinComputeCapability >= 75) { - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, - occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, + alpha_scale_ptr_array, stream, occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); if constexpr (arch::kMinComputeCapability >= 75) { - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, - occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, + alpha_scale_ptr_array, stream, occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); 
+ gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: @@ -332,20 +364,85 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s } } +// This overload will handle tensorop gemms. +// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2 +template ::value || std::is_same::value) + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, 
alpha_scale_ptr_array, + stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} + // This overload will handle simt gemms. It is disabled via SFINAE for tensorop. -template ::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, - int64_t const* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - cudaStream_t stream, int* occupancy = nullptr) +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, GemmOutputType* C, int64_t const* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - dispatchGemmConfig, + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, + stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: @@ -379,12 +476,13 @@ std::vector MoeGemmRunner: = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; static constexpr auto simt_only_flag = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; int const max_split_k = 1; int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; int const enable_hopper = CutlassGemmConfig::NONE; auto config_type_param = static_cast( - weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper); + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) { @@ -407,9 +505,9 @@ std::vector MoeGemmRunner: int const max_split_k = 1; int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; int const enable_hopper = CutlassGemmConfig::HOPPER; - + static constexpr auto fp8_only_flag = use_fp8 ? 
CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; auto config_type_param = static_cast( - weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper); + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { @@ -467,34 +565,46 @@ MoeGemmRunner::MoeGemmRunner() template template -void MoeGemmRunner::dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, - T const* biases, T* C, int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, cudaStream_t stream, int* occupancy) +void MoeGemmRunner::dispatchToArch(T const* A, WeightType const* B, + HopperGemmOutputType const* weight_scales, HopperGemmOutputType const* biases, HopperGemmOutputType* C, + int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy) { TLLM_CHECK_WITH_INFO( - sm_ == 90 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); + sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); TLLM_CHECK_WITH_INFO( sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture"); if (sm_ >= 70 && sm_ < 75) { - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - use_fused_moe, stream, occupancy); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); } else if (sm_ >= 75 && sm_ < 80) { - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - use_fused_moe, stream, occupancy); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); } else if (sm_ >= 80 && sm_ < 90) { - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - use_fused_moe, stream, occupancy); + + if constexpr (use_fp8) + { + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else + { + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } } else if (sm_ >= 90) { @@ -524,10 +634,9 @@ void MoeGemmRunner::dispatchToArch(T const* A, Weigh "information is not required"); TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90, "GEMM config is 
for SM90 configuration, but this configuration is not valid for Hppper"); - - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - use_fused_moe, stream, occupancy); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); } else { @@ -590,46 +699,54 @@ size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const template template -void MoeGemmRunner::runGemm(T const* A, WeightType const* B, T const* weight_scales, - T const* biases, T* C, int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream, +void MoeGemmRunner::runGemm(T const* A, WeightType const* B, HopperGemmOutputType const* weight_scales, + HopperGemmOutputType const* biases, HopperGemmOutputType* C, int64_t const* total_rows_before_expert, + HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) { dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, hopper_input, total_rows, - gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, stream); + gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array, stream, nullptr); } template -void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, - T const* biases, T* C, int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, - bool use_fused_moe, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, + HopperGemmOutputType const* weight_scales, HopperGemmOutputType const* biases, HopperGemmOutputType* C, + int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) { switch (activation_type) { case ActivationType::Relu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); break; case ActivationType::Gelu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); break; case ActivationType::Silu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); break; case ActivationType::Identity: runGemm(A, B, 
weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); break; case ActivationType::Swiglu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); break; case ActivationType::Geglu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); break; case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; default: TLLM_THROW("Invalid activation type."); break; @@ -637,13 +754,14 @@ void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* } template -void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, - int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig chosen_conf) +void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, HopperGemmOutputType const* weight_scales, + HopperGemmOutputType* C, int64_t const* total_rows_before_expert, HopperGroupedGemmInput hopper_input, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) { runGemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream, chosen_conf); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); } } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h index cc52e6363..959d0ea08 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h @@ -38,13 +38,7 @@ constexpr bool isValidHopperMOESpecialisation() template constexpr bool isValidAmpereMOESpecialisation() { -#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) and defined(ENABLE_FP8) - constexpr bool is_fp8 - = cutlass::platform::is_same::value || cutlass::platform::is_same::value; - return !is_fp8; -#else - return true; // Default to true -#endif + return true; // Default to true } } // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt index 83ba30e7e..1205e65e3 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt @@ 
-1,2 +1,2 @@
 957f7c6034dca28dff7afe65ed68aa4b libtensorrt_llm_nvrtc_wrapper.so
-315e9f5ccd286e906d4c0d402fefbf2f69a1febe commit
\ No newline at end of file
+8baf57c648b66a48dbe29f766c6fdff505045f24 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
index 2982fc7fc..906a3c305 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:73ea01f6014e5c11a263f342f8c19f3a1b8bfa824441accd3cb4b7fa699a9d9a
-size 1087488
+oid sha256:e5dc96df086943a4d3bb57cd825fc89f3774658a9c9a9935ee2de6f1e033edb2
+size 1089536
diff --git a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
index 6b88adbc8..8e3cfc693 100644
--- a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
+++ b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
@@ -67,6 +67,43 @@ namespace tensorrt_llm::kernels
 static constexpr int WARP_SIZE = 32;
 
+// ====================== Compute FP8 dequant scale only ===============================
+__global__ void computeFP8DequantScaleKernel(
+    float const** alpha_scale_ptr_array, int64_t const num_experts, float const* fp8_dequant)
+{
+    // First, compute the global tid. We only need 1 thread per expert.
+    int const expert = blockIdx.x * blockDim.x + threadIdx.x;
+    if (expert >= num_experts)
+    {
+        return;
+    }
+
+    if (fp8_dequant)
+    {
+        alpha_scale_ptr_array[expert] = fp8_dequant + expert;
+    }
+}
+
+float const** computeFP8DequantScale(
+    float const** alpha_scale_ptr_array, int const num_experts, float const* fp8_dequant, cudaStream_t stream)
+{
+
+    if (!fp8_dequant)
+    {
+        alpha_scale_ptr_array = nullptr;
+        return alpha_scale_ptr_array;
+    }
+    else
+    {
+        int const threads = std::min(1024, num_experts);
+        int const blocks = (num_experts + threads - 1) / threads;
+
+        computeFP8DequantScaleKernel<<<blocks, threads, 0, stream>>>(alpha_scale_ptr_array, num_experts, fp8_dequant);
+
+        return alpha_scale_ptr_array;
+    }
+}
+
 // ====================== Softmax things ===============================
 // We have our own implementation of softmax here so we can support transposing the output
 // in the softmax kernel when we extend this module to support expert-choice routing.
@@ -533,6 +570,13 @@ size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, int
     int* null_int = nullptr;
     cub::DeviceRadixSort::SortPairs(
         nullptr, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits);
+
+    // TODO: fix DeviceRadixSort
+    // when num_key_value_pairs, num_experts, num_bits, required_storage = 64, 4, 3, 0
+    if (required_storage == 0)
+    {
+        required_storage = 1;
+    }
     return required_storage;
 }
 
@@ -787,7 +831,7 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
 
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection.
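The `computeFP8DequantScale` helper added above builds a per-expert array of pointers into a contiguous block of FP8 dequantization scales, using one CUDA thread per expert (`threads = min(1024, num_experts)`, `blocks = ceil(num_experts / threads)`); the resulting pointer array is later handed to the grouped GEMM as `alpha_scale_ptr_array` so each expert can apply its own scale. A minimal host-side sketch of the same indexing, assuming one scale per expert stored contiguously; the function and variable names below are illustrative, not part of the library:

```
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side analogue of computeFP8DequantScaleKernel: expert i points at the i-th dequant scale.
// Returns an empty vector when no scales are provided, mirroring the nullptr early-out above.
std::vector<float const*> buildAlphaScalePtrs(float const* fp8_dequant, std::int64_t num_experts)
{
    if (fp8_dequant == nullptr)
    {
        return {};
    }
    std::vector<float const*> alpha_scale_ptrs(static_cast<std::size_t>(num_experts));
    for (std::int64_t expert = 0; expert < num_experts; ++expert)
    {
        alpha_scale_ptrs[static_cast<std::size_t>(expert)] = fp8_dequant + expert;
    }
    return alpha_scale_ptrs;
}
```

In the runner changes further down, the same `alpha_scale_ptr_array_` workspace slot is repointed at `fc1_fp8_dequant` before GEMM1 and at `fc2_fp8_dequant` before GEMM2, so only addresses are written, never the scale values themselves.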
template __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted_rows, - OutputType* reduced_unpermuted_output, T const* bias, float const* scales, + OutputType* reduced_unpermuted_output, GemmOutputType const* bias, float const* scales, int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) { @@ -806,7 +850,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted int64_t const stride = FINALIZE_THREADS_PER_BLOCK; int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; - using BiasElem = cutlass::Array; + using BiasElem = cutlass::Array; using InputElem = cutlass::Array; using OutputElem = cutlass::Array; using ComputeElem = cutlass::Array; @@ -877,7 +921,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted template void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows, - OutputType* reduced_unpermuted_output, T const* bias, float const* scales, + OutputType* reduced_unpermuted_output, GemmOutputType const* bias, float const* scales, int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const num_rows, int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) @@ -887,7 +931,7 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro // Only add bias on rank 0 for tensor parallelism bool const is_rank_0 = parallelism_config.tp_rank == 0; - T const* bias_ptr = is_rank_0 ? bias : nullptr; + GemmOutputType const* bias_ptr = is_rank_0 ? bias : nullptr; bool const check_finished = num_valid_ptr != nullptr; @@ -918,9 +962,9 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro // ============================== Gated Activation ================================= constexpr static int ACTIVATION_THREADS_PER_BLOCK = 256; -template class ActFn> +template class ActFn, class OutputType = T> __global__ void doGatedActivationKernel( - T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, int64_t inter_size) + T* output, OutputType const* gemm_result, int64_t const* num_valid_tokens_ptr, int64_t inter_size) { int64_t const tid = threadIdx.x; int64_t const token = blockIdx.x; @@ -955,25 +999,28 @@ __global__ void doGatedActivationKernel( } } -template -void doGatedActivation(T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, int64_t inter_size, - int64_t num_tokens, ActivationType activation_type, cudaStream_t stream) +template +void doGatedActivation(T* output, OutputType const* gemm_result, int64_t const* num_valid_tokens_ptr, + int64_t inter_size, int64_t num_tokens, ActivationType activation_type, cudaStream_t stream) { int64_t const blocks = num_tokens; int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; // TODO For some reason Volta fails on GELU_taylor here with Warp Illegal Instruction. - auto* fn = activation_type == ActivationType::Swiglu ? &doGatedActivationKernel - : &doGatedActivationKernel; + + auto* fn = activation_type == ActivationType::Swiglu + ? 
&doGatedActivationKernel + : &doGatedActivationKernel; fn<<>>(output, gemm_result, num_valid_tokens_ptr, inter_size); } // ============================== Activation ================================= -template class ActFn> -__global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputTypeAdaptor_t const* gemm_result, - float const* fp8_quant, T const* bias_ptr, int64_t const* total_rows_before_expert_, int num_experts, - int64_t inter_size, bool gated) +template class ActFn, class GemmOutputType> +__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, + GemmOutputType const* bias_ptr, int64_t const* total_rows_before_expert_, int num_experts, int64_t inter_size, + bool gated) + { int64_t const tid = threadIdx.x; int64_t const token = blockIdx.x; @@ -1007,7 +1054,7 @@ __global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputType / std::min(cutlass::sizeof_bits::value, cutlass::sizeof_bits>::value); - using BiasElem = cutlass::Array; + using BiasElem = cutlass::Array; using GemmResultElem = cutlass::Array, ACTIVATION_ELEM_PER_THREAD>; using OutputElem = cutlass::Array; using ComputeElem = cutlass::Array; @@ -1046,21 +1093,21 @@ __global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputType } } -template -void doActivation(T* output, HopperGroupedGemmInput::OutputTypeAdaptor_t const* gemm_result, float const* fp8_quant, - T const* bias, int64_t const* total_rows_before_expert_, int num_experts, int64_t inter_size, int64_t num_tokens, +template > +void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, GemmOutputType const* bias, + int64_t const* total_rows_before_expert_, int num_experts, int64_t inter_size, int64_t num_tokens, ActivationType activation_type, cudaStream_t stream) { int64_t const blocks = num_tokens; int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; auto fn_list = std::array{ - &doActivationKernel, // Gelu - &doActivationKernel, // Relu - &doActivationKernel, // Silu - &doActivationKernel, // Swiglu - &doActivationKernel, // Geglu - &doActivationKernel // Identity + &doActivationKernel, // Gelu + &doActivationKernel, // Relu + &doActivationKernel, // Silu + &doActivationKernel, // Swiglu + &doActivationKernel, // Geglu + &doActivationKernel // Identity }; auto fn = fn_list[static_cast(activation_type)]; fn<<>>(output, gemm_result, fp8_quant, bias, total_rows_before_expert_, num_experts, @@ -1090,7 +1137,8 @@ std::vector CutlassMoeFCRunner::getWo size_t num_softmax_outs = 0; bool using_hopper = moe_gemm_runner_.supportsHopperSpecialisation(); - size_t const gemm_output_dtype = using_hopper ? sizeof(HopperGemmOutputType) : sizeof(T); + + size_t const gemm_output_dtype = sizeof(HopperGemmOutputType); bool const is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) @@ -1108,7 +1156,9 @@ std::vector CutlassMoeFCRunner::getWo size_t const fc1_result_size = interbuf_elems * sizeof(T); // Acitvation quantizes so back to sizeof(T) size_t const sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts); size_t const fc2_result_size = permuted_elems * gemm_output_dtype; // May be an intermediate type for quantization + size_t const hopper_size = using_hopper ? HopperGroupedGemmInput::workspaceSize(num_experts_per_node) : 0; + size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node); // We do some overlapping of the large workspace buffers. 
Although we could overlap some of the other buffers, they @@ -1123,6 +1173,8 @@ std::vector CutlassMoeFCRunner::getWo overlapped_gemm1_gemm2_inputs = std::max(overlapped_gemm1_gemm2_inputs, fc1_result_size); } + size_t const alpha_scale_ptr_array_size = num_experts_per_node * sizeof(float**); + // if we have glu_inter we overlap it with fc2_result, otherwise we use fc1_result by itself size_t overlapped_gemm1_gemm2_outputs = fc1_result_size; if (glu_inter_elems > 0) @@ -1142,7 +1194,9 @@ std::vector CutlassMoeFCRunner::getWo overlapped_gemm1_gemm2_inputs, // overlapped_gemm1_gemm2_outputs, // hopper_size, // - gemm_workspace_size}; + gemm_workspace_size, // + alpha_scale_ptr_array_size}; + return workspace; } @@ -1200,7 +1254,7 @@ void CutlassMoeFCRunner::configureWsPtrs(char bool const hopper_has_glu = gemm1_using_hopper && (mayHaveDifferentGEMMOutputType() || is_gated_activation); // We always use fused path if we can bool const non_hopper_has_glu = !gemm1_using_fused_moe && is_gated_activation; - bool const has_glu_inter_result = hopper_has_glu || non_hopper_has_glu; + bool const has_glu_inter_result = hopper_has_glu || non_hopper_has_glu || use_fp8; // Always 7, ignored if not needed glu_inter_result_ = has_glu_inter_result ? (T*) ws_sliced[7] : nullptr; @@ -1215,16 +1269,19 @@ void CutlassMoeFCRunner::configureWsPtrs(char { hopper_grouped_gemm_input_.configureWorkspace(ws_sliced[8], num_experts_per_node, ws_sliced[9], ws_sizes[9]); } + + alpha_scale_ptr_array_ = reinterpret_cast(ws_sliced[10]); } template void CutlassMoeFCRunner::gemm1(MoeGemmRunner& gemm_runner, T const* const input, T* const output, void* const intermediate_result, int64_t const* const total_rows_before_expert, HopperGroupedGemmInput hopper_input_template, - WeightType const* const fc1_expert_weights, T const* const fc1_expert_biases, - int64_t const* const num_valid_tokens_ptr, T const* const fc1_int_scales, float const* const fc1_fp8_dequant, - float const* const fc2_fp8_quant, int64_t const expanded_num_rows, int64_t const hidden_size, - int64_t const inter_size, int const num_experts_per_node, ActivationType fc1_activation_type, cudaStream_t stream, + WeightType const* const fc1_expert_weights, HopperGemmOutputType const* const fc1_expert_biases, + int64_t const* const num_valid_tokens_ptr, HopperGemmOutputType const* const fc1_int_scales, + float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant, int64_t const expanded_num_rows, + int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, + ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) { bool const using_hopper_gemm1 = gemm_runner.isHopperSpecialised(config); @@ -1247,7 +1304,7 @@ void CutlassMoeFCRunner::gemm1(MoeGemmRunner< sync_check_cuda_error(); gemm_runner.moeGemm(input, nullptr, nullptr, nullptr, total_rows_before_expert, hopper_input, expanded_num_rows, - fc1_out_size, hidden_size, num_experts_per_node, false, stream, config); + fc1_out_size, hidden_size, num_experts_per_node, false, alpha_scale_ptr_array, stream, config); sync_check_cuda_error(); @@ -1256,13 +1313,30 @@ void CutlassMoeFCRunner::gemm1(MoeGemmRunner< sync_check_cuda_error(); } + else if (use_fp8) + { + alpha_scale_ptr_array + = computeFP8DequantScale(alpha_scale_ptr_array, num_experts_per_node, fc1_fp8_dequant, stream); + + gemm_runner.moeGemmBiasAct(input, fc1_expert_weights, nullptr, nullptr, + reinterpret_cast(intermediate_result), 
total_rows_before_expert, + HopperGroupedGemmInput{}, expanded_num_rows, fc1_out_size, hidden_size, num_experts_per_node, + ActivationType::Identity, false, alpha_scale_ptr_array, stream, config); + + doActivation(output, static_cast(intermediate_result), fc2_fp8_quant, + fc1_expert_biases, total_rows_before_expert, num_experts_per_node, inter_size, expanded_num_rows, + fc1_activation_type, stream); + + sync_check_cuda_error(); + } else if (!is_gated_activation) { TLLM_CHECK(!use_fused_moe); TLLM_CHECK(!config.is_sm90); - gemm_runner.moeGemmBiasAct(input, fc1_expert_weights, fc1_int_scales, fc1_expert_biases, output, - total_rows_before_expert, HopperGroupedGemmInput{}, expanded_num_rows, fc1_out_size, hidden_size, - num_experts_per_node, fc1_activation_type, false, stream, config); + gemm_runner.moeGemmBiasAct(input, fc1_expert_weights, fc1_int_scales, fc1_expert_biases, + reinterpret_cast(output), total_rows_before_expert, HopperGroupedGemmInput{}, + expanded_num_rows, fc1_out_size, hidden_size, num_experts_per_node, fc1_activation_type, false, + alpha_scale_ptr_array, stream, config); sync_check_cuda_error(); } @@ -1275,16 +1349,17 @@ void CutlassMoeFCRunner::gemm1(MoeGemmRunner< // Run the GEMM with activation function overridden with `Identity`, we do the activation separately ActivationType activation_type = use_fused_moe ? fc1_activation_type : ActivationType::Identity; T* gemm_result = use_fused_moe ? output : static_cast(intermediate_result); - gemm_runner.moeGemmBiasAct(input, fc1_expert_weights, fc1_int_scales, fc1_expert_biases, gemm_result, - total_rows_before_expert, HopperGroupedGemmInput{}, expanded_num_rows, fc1_out_size, hidden_size, - num_experts_per_node, activation_type, use_fused_moe, stream, config); + gemm_runner.moeGemmBiasAct(input, fc1_expert_weights, fc1_int_scales, fc1_expert_biases, + reinterpret_cast(gemm_result), total_rows_before_expert, HopperGroupedGemmInput{}, + expanded_num_rows, fc1_out_size, hidden_size, num_experts_per_node, activation_type, use_fused_moe, + alpha_scale_ptr_array, stream, config); sync_check_cuda_error(); if (!use_fused_moe) { - doGatedActivation(output, static_cast(intermediate_result), num_valid_tokens_ptr, inter_size, - expanded_num_rows, fc1_activation_type, stream); + doGatedActivation(output, static_cast(intermediate_result), + num_valid_tokens_ptr, inter_size, expanded_num_rows, fc1_activation_type, stream); sync_check_cuda_error(); } @@ -1295,8 +1370,9 @@ template void CutlassMoeFCRunner::gemm2(MoeGemmRunner& gemm_runner, T const* const input, void* const output, int64_t const* const total_rows_before_expert, HopperGroupedGemmInput const hopper_input_template, WeightType const* const fc2_expert_weights, - T const* const fc2_int_scales, float const* const fc2_fp8_dequant, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, cudaStream_t stream, + HopperGemmOutputType const* const fc2_int_scales, float const* const fc2_fp8_dequant, + int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, + int const num_experts_per_node, float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) { bool const using_hopper_gemm2 = gemm_runner.isHopperSpecialised(config); @@ -1308,9 +1384,15 @@ void CutlassMoeFCRunner::gemm2(MoeGemmRunner< static_cast(output), stream); sync_check_cuda_error(); } + else if (use_fp8) + { + alpha_scale_ptr_array + = computeFP8DequantScale(alpha_scale_ptr_array, 
num_experts_per_node, fc2_fp8_dequant, stream); + } - gemm_runner.moeGemm(input, fc2_expert_weights, fc2_int_scales, static_cast(output), total_rows_before_expert, - hopper_input, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, false, stream, config); + gemm_runner.moeGemm(input, fc2_expert_weights, fc2_int_scales, static_cast(output), + total_rows_before_expert, hopper_input, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, false, + alpha_scale_ptr_array, stream, config); } template @@ -1330,14 +1412,15 @@ void CutlassMoeFCRunner::runMoe(void const* i auto const* input_activations = static_cast(input_activations_void); auto const* fc1_expert_weights = static_cast(fc1_expert_weights_void); - auto const* fc1_expert_biases = static_cast(fc1_expert_biases_void); + auto const* fc1_expert_biases = reinterpret_cast(fc1_expert_biases_void); auto const* fc2_expert_weights = static_cast(fc2_expert_weights_void); - auto const* fc1_int_scales = static_cast(quant_params.fc1_weight_scales); - auto const* fc2_int_scales = static_cast(quant_params.fc2_weight_scales); + auto const* fc1_int_scales = reinterpret_cast(quant_params.fc1_weight_scales); + auto const* fc2_int_scales = reinterpret_cast(quant_params.fc2_weight_scales); + auto const* fc1_fp8_dequant = quant_params.dequant_fc1; auto const* fc2_fp8_quant = quant_params.quant_fc2; auto const* fc2_fp8_dequant = quant_params.dequant_fc2; - auto const* fc2_expert_biases = static_cast(fc2_expert_biases_void); + auto const* fc2_expert_biases = reinterpret_cast(fc2_expert_biases_void); auto* final_output = static_cast(final_output_void); auto* expert_scales = static_cast(expert_scales_void); @@ -1413,6 +1496,7 @@ void CutlassMoeFCRunner::runMoe(void const* i configureWsPtrs( workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, fc1_activation_type); + topkGatingSoftmaxKernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, source_rows_, num_rows, num_experts, k, start_expert, end_expert, stream); @@ -1442,30 +1526,22 @@ void CutlassMoeFCRunner::runMoe(void const* i Self::gemm1(moe_gemm_runner_, permuted_data_, fc1_result_, glu_inter_result_, total_rows_before_expert_, hopper_grouped_gemm_input_, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr, fc1_int_scales, fc1_fp8_dequant, fc2_fp8_quant, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, - fc1_activation_type, stream, *gemm1_config_); + fc1_activation_type, alpha_scale_ptr_array_, stream, *gemm1_config_); sync_check_cuda_error(); Self::gemm2(moe_gemm_runner_, fc1_result_, fc2_result_, total_rows_before_expert_, hopper_grouped_gemm_input_, fc2_expert_weights, fc2_int_scales, fc2_fp8_dequant, expanded_num_rows, hidden_size, inter_size, - num_experts_per_node, stream, *gemm2_config_); + num_experts_per_node, alpha_scale_ptr_array_, stream, *gemm2_config_); sync_check_cuda_error(); bool const using_hopper_gemm2 = moe_gemm_runner_.isHopperSpecialised(*gemm2_config_); - if (using_hopper_gemm2) - { - finalizeMoeRoutingKernelLauncher( - static_cast(fc2_result_), final_output, fc2_expert_biases, expert_scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, hidden_size, k, - num_valid_tokens_ptr, parallelism_config, normalization_mode, stream); - } - else - { - finalizeMoeRoutingKernelLauncher(static_cast(fc2_result_), final_output, - fc2_expert_biases, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, - 
hidden_size, k, num_valid_tokens_ptr, parallelism_config, normalization_mode, stream); - } + + finalizeMoeRoutingKernelLauncher( + static_cast(fc2_result_), final_output, fc2_expert_biases, expert_scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, hidden_size, k, num_valid_tokens_ptr, + parallelism_config, normalization_mode, stream); sync_check_cuda_error(); } @@ -1686,11 +1762,14 @@ std::vector GemmProfilerBackend::getProfilerWorkspaces(int maxM, bool is { hopper_workspace_size = HopperGroupedGemmInput::workspaceSize(num_experts_per_node); } + + size_t alpha_scale_ptr_array_size = num_experts_per_node * sizeof(float**); size_t gemm_workspace_size = mInterface->getGemmWorkspaceSize(num_experts_per_node); + // NOTICE: put gemm_workspace_size at last return {total_rows_before_expert_size, // Put this first because we initialise this but nothing else input_size, output_size, intermediate_size, weights, bias, quant_1, quant_2, quant_3, quant_4, - hopper_workspace_size, gemm_workspace_size}; + hopper_workspace_size, alpha_scale_ptr_array_size, gemm_workspace_size}; } size_t GemmProfilerBackend::getWorkspaceSize(int maxM) @@ -1702,6 +1781,7 @@ size_t GemmProfilerBackend::getWorkspaceSize(int maxM) void GemmProfilerBackend::runProfiler( int original_num_tokens, Config const& tactic, char* workspace_ptr_char, cudaStream_t const& stream) { + int8_t* workspace_ptr = reinterpret_cast(workspace_ptr_char); auto workspaces = getProfilerWorkspaces(original_num_tokens, tactic.is_sm90); auto ws_it = workspaces.begin(); @@ -1731,8 +1811,8 @@ void GemmProfilerBackend::runProfiler( void const* scale_3 = getNext(); void const* scale_4 = getNext(); void* hopper_workspace = getNext(); + float const** alpha_scale_ptr_array = reinterpret_cast(getNext()); void* gemm_workspace = getNext(); // NOTE we rely on this being last below (i.e. 
workspaces.back()) - int64_t expanded_num_tokens = original_num_tokens * mK; int64_t num_experts_per_node = mNumExpertsPerNode; @@ -1775,6 +1855,7 @@ void GemmProfilerBackend::runProfiler( mExpertInterSize, // num_experts_per_node, // mActivationType, // + alpha_scale_ptr_array, // stream, // tactic); } @@ -1792,6 +1873,7 @@ void GemmProfilerBackend::runProfiler( mExpertHiddenSize, // mExpertInterSize, // num_experts_per_node, // + alpha_scale_ptr_array, // stream, // tactic); } diff --git a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h index 837e5a885..f7df62266 100644 --- a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h @@ -173,14 +173,14 @@ class CutlassMoeFCRunnerInterface int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, ActivationType fc1_activation_type, - cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) = 0; virtual void gemm2(void const* const input, void* const output, int64_t const* const total_rows_before_expert, HopperGroupedGemmInput hopper_input_template, void const* const fc2_expert_weights, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig config) + int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) = 0; virtual size_t getGemmWorkspaceSize(int num_experts) const = 0; @@ -193,12 +193,12 @@ class CutlassMoeFCRunnerInterface // Avoid making several duplicates of this class. 
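Both workspace paths touched above use the same carving scheme: a helper returns a list of per-buffer sizes, and the consumer walks that list in order to slice one large allocation (`ws_sliced[...]` in `configureWsPtrs`, `getNext()` in `runProfiler`). Adding the FP8 alpha-scale pointer array therefore only appends `alpha_scale_ptr_array_size` plus one more read, and the `// NOTICE: put gemm_workspace_size at last` comment exists because `runProfiler` fetches the GEMM workspace as the final entry (`workspaces.back()`). A stripped-down sketch of the scheme, assuming no per-slice alignment or padding (the real code may add some); the class and member names are illustrative:

```
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// One allocation split into consecutive slices whose sizes come from a list,
// mirroring the "vector of sizes + sequential pointer bumps" scheme used above.
class WorkspaceSlices
{
public:
    explicit WorkspaceSlices(std::vector<std::size_t> sizes)
        : mSizes(std::move(sizes))
    {
    }

    std::size_t totalBytes() const
    {
        return std::accumulate(mSizes.begin(), mSizes.end(), std::size_t{0});
    }

    // Slice i starts after the sum of all earlier sizes; appending a new slice
    // (e.g. the alpha-scale pointer array) never moves the slices before it.
    void* slice(void* base, std::size_t index) const
    {
        char* ptr = static_cast<char*>(base);
        for (std::size_t i = 0; i < index; ++i)
        {
            ptr += mSizes[i];
        }
        return ptr;
    }

    // Callers above special-case the last entry (workspaces.back()),
    // which is why the GEMM workspace must stay at the end of the list.
    std::size_t lastSize() const
    {
        return mSizes.back();
    }

private:
    std::vector<std::size_t> mSizes;
};
```

Keeping every buffer inside one allocation this way makes the layout depend only on the order of the size list, which is what the ordering comments in the hunks above are protecting.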
template class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { using Self = CutlassMoeFCRunner; - + using HopperGemmOutputType = typename HopperGroupedGemmInput::OutputTypeAdaptor_t; /* Internal gemm output type*/ public: CutlassMoeFCRunner() = default; @@ -242,16 +242,18 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface static void gemm1(MoeGemmRunner& gemm_runner, T const* const input, T* const output, void* const intermediate_result, int64_t const* const total_rows_before_expert, HopperGroupedGemmInput const hopper_input_template, WeightType const* const fc1_expert_weights, - T const* const fc1_expert_biases, int64_t const* const num_valid_tokens_ptr, T const* const fc1_int_scales, - float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationType fc1_activation_type, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config); + HopperGemmOutputType const* const fc1_expert_biases, int64_t const* const num_valid_tokens_ptr, + HopperGemmOutputType const* const fc1_int_scales, float const* const fc1_fp8_dequant, + float const* const fc2_fp8_quant, int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts_per_node, ActivationType fc1_activation_type, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config); static void gemm2(MoeGemmRunner& gemm_runner, T const* const input, void* const output, int64_t const* const total_rows_before_expert, HopperGroupedGemmInput hopper_input_template, - WeightType const* const fc2_expert_weights, T const* const fc2_int_scales, float const* const fc2_fp8_dequant, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config); + WeightType const* const fc2_expert_weights, HopperGemmOutputType const* const fc2_int_scales, + float const* const fc2_fp8_dequant, int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts_per_node, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config); // Overrides to allow us to forward on to the internal functions with the pointers using the correct type void gemm1(void const* const input, void* const output, void* const intermediate_result, @@ -260,25 +262,25 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, ActivationType fc1_activation_type, - cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) override + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) override { return Self::gemm1(moe_gemm_runner_, static_cast(input), static_cast(output), intermediate_result, total_rows_before_expert, hopper_input_template, static_cast(fc1_expert_weights), - static_cast(fc1_expert_weights), num_valid_tokens_ptr, static_cast(fc1_int_scales), - fc1_fp8_dequant, fc2_fp8_quant, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, - fc1_activation_type, stream, config); + 
static_cast(fc1_expert_weights), num_valid_tokens_ptr, + static_cast(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant, expanded_num_rows, + hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, stream, config); } void gemm2(void const* const input, void* const output, int64_t const* const total_rows_before_expert, HopperGroupedGemmInput hopper_input_template, void const* const fc2_expert_weights, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig config) override + int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config) override { - return Self::gemm2(moe_gemm_runner_, static_cast(input), output, total_rows_before_expert, - hopper_input_template, static_cast(fc2_expert_weights), - static_cast(fc2_int_scales), fc2_fp8_dequant, expanded_num_rows, hidden_size, inter_size, - num_experts_per_node, stream, config); + return Self::gemm2(moe_gemm_runner_, static_cast(input), static_cast(output), + total_rows_before_expert, hopper_input_template, static_cast(fc2_expert_weights), + static_cast(fc2_int_scales), fc2_fp8_dequant, expanded_num_rows, hidden_size, + inter_size, num_experts_per_node, alpha_scale_ptr_array, stream, config); } virtual size_t getGemmWorkspaceSize(int num_experts) const override @@ -287,7 +289,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface } private: - using HopperGemmOutputType = typename HopperGroupedGemmInput::OutputTypeAdaptor_t; + static constexpr bool use_fp8 = std::is_same::value || std::is_same::value; static void computeTotalRowsBeforeExpert(int const* sorted_indices, int const total_indices, int const num_experts, int64_t* total_rows_before_expert, cudaStream_t stream); @@ -305,7 +307,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface bool mayHaveDifferentGEMMOutputType() const { // We just check if its supported because we need to know when calculating workspace size - return moe_gemm_runner_.supportsHopperSpecialisation() && !std::is_same_v; + return ( + (moe_gemm_runner_.supportsHopperSpecialisation() && !std::is_same_v) || use_fp8); } CubKeyValueSorter sorter_; @@ -327,6 +330,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface void* glu_inter_result_{}; void* fc2_result_{}; T* fc1_result_{}; + float const** alpha_scale_ptr_array_ = nullptr; HopperGroupedGemmInput hopper_grouped_gemm_input_; }; diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h index 1e1e296b3..43272e240 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h @@ -247,7 +247,7 @@ class GemmPluginProfiler protected: virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; - virtual void computeTmpSize(int maxM, int n, int k) = 0; + virtual void computeTmpSize(size_t maxM, size_t n, size_t k) = 0; virtual bool checkTactic(int m, int n, int k, Config const& tactic) const { diff --git a/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.cpp b/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.cpp index e5b0f985a..743ec93e5 100644 --- 
a/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.cpp @@ -93,7 +93,7 @@ int Fp8RowwiseGemmPluginProfiler::getMaxProfileM() const return 16384; } -void Fp8RowwiseGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) +void Fp8RowwiseGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { std::vector workspaces = { maxM * k * getBytePerElement(nvinfer1::DataType::kFP8), // A diff --git a/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.h b/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.h index 187c2053c..f69cb81ac 100644 --- a/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.h +++ b/cpp/tensorrt_llm/plugins/fp8RowwiseGemmPlugin/fp8RowwiseGemmPlugin.h @@ -44,7 +44,7 @@ class Fp8RowwiseGemmPluginProfiler : public GemmPluginProfiler getTactics(int m, int n, int k) const override; diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp index ac197199b..5d97a8412 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp @@ -112,7 +112,7 @@ bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, Config const& return checkResult; } -void CublasLtGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) +void CublasLtGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { size_t dataSize = typeSize(mType); size_t outputDataSize = typeSize(mOutputType); diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h index 6b552b635..bf2b5540f 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h @@ -56,7 +56,7 @@ class CublasLtGemmPluginProfiler protected: void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; - void computeTmpSize(int maxM, int n, int k) override; + void computeTmpSize(size_t maxM, size_t n, size_t k) override; bool checkTactic(int m, int n, int k, Config const& tactic) const override; diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp index 92ff693ca..3b2142c84 100644 --- a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp @@ -87,7 +87,7 @@ int GemmSwigluPluginProfiler::getMaxProfileM() const return 32768; } -void GemmSwigluPluginProfiler::computeTmpSize(int maxM, int n, int k) +void GemmSwigluPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { std::vector workspaces = { maxM * k * getBytePerElement(mType), // A diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h index 744f1fc04..766e59aad 100644 --- a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h @@ -44,7 +44,7 @@ class GemmSwigluPluginProfiler : public GemmPluginProfiler= 90, - "MoE FP8 is not supported for architectures less than SM90"); + TLLM_CHECK_WITH_INFO(mType != DataType::kFP8 || tensorrt_llm::common::getSMVersion() >= 89, + "MoE FP8 is not supported for architectures less than SM89"); if (mWeightType == nvinfer1::DataType::kINT8 && mQuantMode.hasInt4Weights()) { @@ -226,6 +226,7 @@ void 
MixtureOfExpertsPlugin::init() mType, mWeightType, mQuantMode}; mGemmId2 = GemmIDMoe{2, mNumExperts, mK, mParallelismConfig, mExpertHiddenSize, mExpertInterSize, mActivationType, mType, mWeightType, mQuantMode}; + mGemmProfiler->setMaxProfileM(16384 * mNumExperts / mK); } // IPluginV2DynamicExt Methods @@ -689,7 +690,7 @@ char const* MixtureOfExpertsPluginCreator::getPluginNamespace() const noexcept return mNamespace.c_str(); } -void MixtureOfExpertsGemmProfiler::computeTmpSize(int maxM, int n, int k) +void MixtureOfExpertsGemmProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { checkInit(); size_t bytes = backend.getWorkspaceSize(maxM); diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h index fbf6fc231..5f493e887 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h @@ -334,10 +334,20 @@ class MixtureOfExpertsGemmProfiler init_backend = false; } + void setMaxProfileM(int maxProfileM) + { + mMaxProfileM = maxProfileM; + } + + virtual int getMaxProfileM() const override + { + return mMaxProfileM; + } + protected: using Config = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; - void computeTmpSize(int maxM, int n, int k) override; + void computeTmpSize(size_t maxM, size_t n, size_t k) override; std::vector getTactics(int m, int n, int k) const override; void initTmpData(int maxM, int n, int k, char* workspace, size_t size, cudaStream_t stream) override; @@ -345,6 +355,9 @@ class MixtureOfExpertsGemmProfiler bool init_backend = false; tensorrt_llm::kernels::GemmProfilerBackend backend{}; + +private: + int mMaxProfileM = 0; }; class MixtureOfExpertsPluginCreator : public nvinfer1::IPluginCreator diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp index b709cf207..110a07000 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp @@ -42,7 +42,6 @@ AllreducePlugin::AllreducePlugin(std::set group, nvinfer1::DataType type, A , mStrategy(strategy) , mConfig(config) , mOp(op) - , mCounter(counter) , mEps(eps) , mAffine(affine) , mBias(bias) @@ -65,7 +64,6 @@ AllreducePlugin::AllreducePlugin(void const* data, size_t length) } read(d, mConfig); read(d, mOp); - read(d, mCounter); read(d, mEps); read(d, mAffine); read(d, mBias); @@ -256,17 +254,17 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe { case AllReduceStrategyType::NCCL: { - TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d layer %d: NCCL", rank, mCounter); + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank); break; } case AllReduceStrategyType::ONESHOT: { - TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d layer %d: ONESHOT", rank, mCounter); + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: ONESHOT", rank); break; } case AllReduceStrategyType::TWOSHOT: { - TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d layer %d: TWOSHOT", rank, mCounter); + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: TWOSHOT", rank); break; } default: break; @@ -305,10 +303,16 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe else { auto const tpSize = mGroup.size(); - auto const tpRank = rank % 
tpSize; + int tpRank = 0; + for (auto const& currentRank : mGroup) + { + if (rank == currentRank) + break; + ++tpRank; + } auto params = tensorrt_llm::kernels::AllReduceParams::deserialize( - reinterpret_cast(inputs[1]), tpSize, tpRank, mCounter); + reinterpret_cast(const_cast(inputs[1])), tpSize, tpRank); params.local_output_buffer_ptr = outputs[0]; params.local_input_buffer_ptr = inputs[0]; @@ -593,7 +597,7 @@ void AllreducePlugin::terminate() noexcept {} size_t AllreducePlugin::getSerializationSize() const noexcept { return sizeof(int) * mGroup.size() + sizeof(mType) + sizeof(mStrategy) + sizeof(mConfig) + sizeof(mOp) - + sizeof(mCounter) + sizeof(mEps) + sizeof(mAffine) + sizeof(mBias); + + sizeof(mEps) + sizeof(mAffine) + sizeof(mBias); } void AllreducePlugin::serialize(void* buffer) const noexcept @@ -603,7 +607,6 @@ void AllreducePlugin::serialize(void* buffer) const noexcept write(d, mStrategy); write(d, mConfig); write(d, mOp); - write(d, mCounter); write(d, mEps); write(d, mAffine); write(d, mBias); diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h index f293c4e7c..b4fc36b96 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h @@ -83,7 +83,6 @@ class AllreducePlugin : public BasePlugin kernels::AllReduceStrategyConfig mConfig; kernels::AllReduceFusionOp mOp; float mEps; - int32_t mCounter; std::shared_ptr mNcclComm; int8_t mAffine; int8_t mBias; diff --git a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp index ce58f9981..fe7d49056 100644 --- a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp @@ -51,7 +51,7 @@ void SmoothQuantGemmPluginProfiler::runTactic(int m, int n, int k, SmoothQuantGe aTmp, bTmp, mQuantMode, alphaColTmp, alphaRowTmp, cTmp, m, n, k, tactic, workspaceTmp, wsSize, stream); } -void SmoothQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) +void SmoothQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { std::vector workspaces = { maxM * k * sizeof(int8_t), // A diff --git a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h index 702f1effb..3cabf5580 100644 --- a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h +++ b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h @@ -46,7 +46,7 @@ class SmoothQuantGemmPluginProfiler : public GemmPluginProfiler getTactics(int m, int n, int k) const override; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp index db4fbeb7f..ad4e3973f 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp @@ -74,7 +74,7 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k, tactic, workspacePtr, wsSize, stream); } -void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) +void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t 
maxM, size_t n, size_t k) { // Quantized weights are packed in FP16 format (INT4*4 -> FP16) int const originalN = n * FP16_INT4_RATIO; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h index 529c3e606..a38a589c4 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h @@ -63,7 +63,7 @@ class WeightOnlyGroupwiseQuantGemmPluginProfiler protected: void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; - void computeTmpSize(int maxM, int n, int k) override; + void computeTmpSize(size_t maxM, size_t n, size_t k) override; std::vector getTactics(int m, int n, int k) const override; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp index eafea714f..74835a52b 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp @@ -59,7 +59,7 @@ void WeightOnlyQuantGemmPluginProfiler::runTactic(int m, int n, int k, } } -void WeightOnlyQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) +void WeightOnlyQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { int const originalN = n * getWeightTypeMultiplier(mWeightTypeId); std::vector workspaces = { diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h index 2683846b3..7c65e6623 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h @@ -69,7 +69,7 @@ class WeightOnlyQuantGemmPluginProfiler : public GemmPluginProfiler getTactics(int m, int n, int k) const override; diff --git a/cpp/tensorrt_llm/runtime/ipcUtils.cpp b/cpp/tensorrt_llm/runtime/ipcUtils.cpp index b15bb92e7..914a7211c 100644 --- a/cpp/tensorrt_llm/runtime/ipcUtils.cpp +++ b/cpp/tensorrt_llm/runtime/ipcUtils.cpp @@ -154,9 +154,12 @@ AllReduceBuffers::AllReduceBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWi mIpcMemoryHandles.emplace_back(size, manager, worldConfig); } - mAllReduceCommPtrs = BufferManager::cpu( - ITensor::makeShape({static_cast(mIpcMemoryHandles.size()) * tpSize}), nvinfer1::DataType::kINT64); + mAllReduceCommPtrs + = BufferManager::cpu(ITensor::makeShape({static_cast(mIpcMemoryHandles.size()) * tpSize + 1}), + nvinfer1::DataType::kINT64); auto commPtrs = BufferRange(*mAllReduceCommPtrs); + auto const flagPtr = static_cast(mAllReduceCommPtrs->data(mAllReduceCommPtrs->getSize() - 1)); + *flagPtr = 0; for (std::size_t memIdx = 0; memIdx < mIpcMemoryHandles.size(); memIdx++) { diff --git a/cpp/tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/kernels/mixtureOfExpertsTest.cu index ecb426402..0a0e574bc 100644 --- a/cpp/tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/kernels/mixtureOfExpertsTest.cu @@ -135,7 +135,7 @@ protected: static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version"); #endif bool should_skip_no_device = mDeviceCount <= 0; - 
bool should_skip_unsupported_fp8 = getSMVersion() < 90 && FP8; + bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8; return should_skip_no_device || should_skip_unsupported_fp8; } diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index 6eebf39af..33ed71610 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -192,10 +192,35 @@ def run_tests(build_dir: _pl.Path, timeout=test_timeout) if run_gpt: - run_benchmarks(python_exe=python_exe, + run_benchmarks(model_name="gpt", + python_exe=python_exe, root_dir=root_dir, build_dir=build_dir, - resources_dir=resources_dir) + resources_dir=resources_dir, + model_cache=model_cache, + test_gpt_session_benchmark=True, + batching_types=["IFB", "V1"], + api_types=["gptManager", "executor"]) + elif run_t5: + run_benchmarks(model_name="t5", + python_exe=python_exe, + root_dir=root_dir, + build_dir=build_dir, + resources_dir=resources_dir, + model_cache=model_cache, + test_gpt_session_benchmark=False, + batching_types=["IFB"], + api_types=["executor"]) + elif run_bart: + run_benchmarks(model_name="bart", + python_exe=python_exe, + root_dir=root_dir, + build_dir=build_dir, + resources_dir=resources_dir, + model_cache=model_cache, + test_gpt_session_benchmark=False, + batching_types=["IFB"], + api_types=["executor"]) else: _log.info("Skipping benchmarks") @@ -562,8 +587,10 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500): run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) -def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, - resources_dir: _pl.Path): +def run_benchmarks(model_name: str, python_exe: str, root_dir: _pl.Path, + build_dir: _pl.Path, resources_dir: _pl.Path, + model_cache: str, test_gpt_session_benchmark: bool, + batching_types: list[str], api_types: list[str]): # At this moment, CI env might not installed tensorrt_llm before, so tensorrt_llm module might not be available. import pathlib @@ -583,19 +610,45 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, run_command(make_benchmarks, cwd=build_dir, timeout=300) benchmark_exe_dir = build_dir / "benchmarks" - gpt_engine_dir = resources_dir / "models" / "rt_engine" / "gpt2" - - input_file = 'input_tokens.npy' - model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) - model_spec_obj.use_gpt_plugin() + if model_name == "gpt": + model_engine_dir = resources_dir / "models" / "rt_engine" / "gpt2" + tokenizer_dir = resources_dir / "models" / "gpt2" + elif model_name in ('bart', 't5'): + if model_name == "t5": + hf_repo_name = "t5-small" + elif model_name == "bart": + hf_repo_name = "bart-large-cnn" + model_engine_dir = resources_dir / "models" / "enc_dec" / "trt_engines" / hf_repo_name + tokenizer_dir = model_cache + "/" + hf_repo_name + model_engine_path = model_engine_dir / "1-gpu" / "float16" / "decoder" + encoder_model_engine_path = model_engine_dir / "1-gpu" / "float16" / "encoder" + model_name = "enc_dec" + else: + _log.info( + f"run_benchmark test does not support {model_name}. 
Skipping benchmarks" + ) + return NotImplementedError + + if test_gpt_session_benchmark: + if model_name == "gpt": + input_file = 'input_tokens.npy' + model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) + model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.use_gpt_plugin() + model_engine_path = model_engine_dir / model_spec_obj.get_model_path( + ) / "tp1-pp1-gpu" + else: + _log.info( + f"gptSessionBenchmark test does not support {model_name}. Skipping benchmarks" + ) + return NotImplementedError - benchmark = [ - str(benchmark_exe_dir / "gptSessionBenchmark"), "--engine_dir", - str(gpt_engine_dir / model_spec_obj.get_model_path() / "tp1-pp1-gpu"), - "--batch_size", "8", "--input_output_len", "10,20", "--duration", "10" - ] - run_command(benchmark, cwd=root_dir, timeout=600) + benchmark = [ + str(benchmark_exe_dir / "gptSessionBenchmark"), "--engine_dir", + str(model_engine_path), "--batch_size", "8", "--input_output_len", + "10,20", "--duration", "10" + ] + run_command(benchmark, cwd=root_dir, timeout=600) prompt_datasets_args = [{ '--dataset-name': "cnn_dailymail", @@ -618,8 +671,14 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, max_input_lens = ["256", "20"] num_reqs = ["50", "10"] - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) - model_spec_obj.use_packed_input() + if model_name == "gpt": + input_file = 'input_tokens.npy' + model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) + model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.use_gpt_plugin() + model_spec_obj.use_packed_input() + model_engine_path = model_engine_dir / model_spec_obj.get_model_path( + ) / "tp1-pp1-gpu" for prompt_ds_args, tokens_f, len, num_req in zip(prompt_datasets_args, token_files, @@ -630,7 +689,7 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, prepare_dataset = [ python_exe, str(benchmark_src_dir / "prepare_dataset.py"), "--tokenizer", - str(resources_dir / "models" / "gpt2"), "--output", + str(tokenizer_dir), "--output", str(data_dir / tokens_f), "dataset", "--max-input-len", len, "--num-requests", num_req ] @@ -643,50 +702,69 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, timeout=300, env={'HF_DATASETS_OFFLINE': '0'}) - batching_types = ["IFB", "V1"] - api_types = ["gptManager", "executor"] - for batching_type in batching_types: for api_type in api_types: benchmark = [ str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", - str(gpt_engine_dir / model_spec_obj.get_model_path() / - "tp1-pp1-gpu"), "--type", + str(model_engine_path), "--type", str(batching_type), "--api", str(api_type), "--dataset", str(data_dir / tokens_f) ] + if model_name == "enc_dec": + benchmark += [ + "--encoder_engine_dir", + str(encoder_model_engine_path) + ] + run_command(benchmark, cwd=root_dir, timeout=600) req_rate_benchmark = benchmark + ["--request_rate", "100"] run_command(req_rate_benchmark, cwd=root_dir, timeout=600) concurrency_benchmark = benchmark + ["--concurrency", "30"] run_command(concurrency_benchmark, cwd=root_dir, timeout=600) - benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", - str(gpt_engine_dir / model_spec_obj.get_model_path() / - "tp1-pp1-gpu"), "--type", "IFB", "--dataset", - str(data_dir / tokens_f), "--api", "executor", "--streaming" - ] - run_command(benchmark, cwd=root_dir, timeout=600) - - benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), 
"--engine_dir", - str(gpt_engine_dir / model_spec_obj.get_model_path() / - "tp1-pp1-gpu"), "--type", "IFB", "--dataset", - str(data_dir / tokens_f), "--api", "gptManager", "--streaming" - ] - run_command(benchmark, cwd=root_dir, timeout=600) - - benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", - str(gpt_engine_dir / model_spec_obj.get_model_path() / - "tp1-pp1-gpu"), "--type", "IFB", "--dataset", - str(data_dir / tokens_f), "--api", "gptManager", "--streaming", - "request_rate", "100", "--enable_exp_delays" - ] - run_command(benchmark, cwd=root_dir, timeout=600) + if "IFB" in batching_type and "executor" in api_types: + # executor streaming test + benchmark = [ + str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", + str(model_engine_path), "--type", "IFB", "--dataset", + str(data_dir / tokens_f), "--api", "executor", "--streaming" + ] + if model_name == "enc_dec": + benchmark += [ + "--encoder_engine_dir", + str(encoder_model_engine_path) + ] + run_command(benchmark, cwd=root_dir, timeout=600) + + if "IFB" in batching_type and "gptManager" in api_type: + # gptManager streaming test + benchmark = [ + str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", + str(model_engine_path), "--type", "IFB", "--dataset", + str(data_dir / tokens_f), "--api", "gptManager", "--streaming" + ] + if model_name == "enc_dec": + benchmark += [ + "--encoder_engine_dir", + str(encoder_model_engine_path) + ] + run_command(benchmark, cwd=root_dir, timeout=600) + + # gptManager streaming test with delay + benchmark = [ + str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", + str(model_engine_path), "--type", "IFB", "--dataset", + str(data_dir / tokens_f), "--api", "gptManager", "--streaming", + "request_rate", "100", "--enable_exp_delays" + ] + if model_name == "enc_dec": + benchmark += [ + "--encoder_engine_dir", + str(encoder_model_engine_path) + ] + run_command(benchmark, cwd=root_dir, timeout=600) if __name__ == "__main__": diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index 1921dc595..4a89f5f91 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index 500f51b8b..84dd1f5cc 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/convert_checkpoint.py b/examples/chatglm/convert_checkpoint.py index 9f0eae6da..0bfc7125c 100644 --- a/examples/chatglm/convert_checkpoint.py +++ b/examples/chatglm/convert_checkpoint.py @@ -1034,6 +1034,17 @@ def convert_hf_chatglm(hf_model: AutoModel, elif chatglm_version in ['chatglm2', 'chatglm3']: position_embedding_type = 'rope_gptj' + rotary_base = 10000.0 + rotary_embedding_scaling = None + if chatglm_version == 'chatglm2': + if hf_config.rope_ratio > 1: + rotary_embedding_scaling = { + 'type': 'linear', + 'factor': hf_config.rope_ratio + } + elif chatglm_version == 'chatglm3': + rotary_base *= hf_config.rope_ratio + config = { 'architecture': hf_config.architectures[0], 'dtype': args.dtype, @@ -1047,6 +1058,9 @@ def convert_hf_chatglm(hf_model: 
AutoModel, 'vocab_size': hf_config.vocab_size, 'position_embedding_type': position_embedding_type, 'max_position_embeddings': hf_config.max_position_embeddings, + 'rotary_pct': 0.5, + 'rotary_base': rotary_base, + 'rotary_scaling': rotary_embedding_scaling, 'hidden_act': hf_config.hidden_act, 'use_parallel_embedding': args.use_parallel_embedding, 'embedding_sharding_dim': args.embedding_sharding_dim, diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index d60d9e599..85b334c73 100644 --- a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/dbrx/requirements.txt b/examples/dbrx/requirements.txt index b512ec930..79542d96c 100644 --- a/examples/dbrx/requirements.txt +++ b/examples/dbrx/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/dit/README.md b/examples/dit/README.md index 00bd7f306..e68bc3e03 100644 --- a/examples/dit/README.md +++ b/examples/dit/README.md @@ -13,7 +13,7 @@ The TensorRT-LLM DiT implementation can be found in [tensorrt_llm/models/dit/mod - [x] FP16 - [x] TP - [ ] FP8 -- [ ] CP +- [x] CP ## Usage @@ -69,3 +69,31 @@ python vae_decoder_trt.py --max_batch_size 8 # run mpirun -n 4 --allow-run-as-root python sample.py ``` + +### Context Parallel + +Context parallel can also be used to reduce latency and memory consumption on each GPU. + +``` +# build dit engine +python convert_checkpoint.py --cp_size 4 +trtllm-build --checkpoint_dir ./tllm_checkpoint/ --max_batch_size 8 --remove_input_padding disable --bert_attention_plugin disable +# build vae engine +python vae_decoder_trt.py --max_batch_size 8 +# run +mpirun -n 4 --allow-run-as-root python sample.py +``` + +### Combine Tensor Parallel and Context Parallel + +Tensor Parallel and Context Parallel can be used together to better balance latency and memory consumption. 
+ +``` +# build dit engine +python convert_checkpoint.py --cp_size 2 --tp_size 2 +trtllm-build --checkpoint_dir ./tllm_checkpoint/ --max_batch_size 8 --remove_input_padding disable --bert_attention_plugin disable +# build vae engine +python vae_decoder_trt.py --max_batch_size 8 +# run +mpirun -n 4 --allow-run-as-root python sample.py +``` diff --git a/examples/dit/convert_checkpoint.py b/examples/dit/convert_checkpoint.py index b5aa26b25..a6dd3cd47 100644 --- a/examples/dit/convert_checkpoint.py +++ b/examples/dit/convert_checkpoint.py @@ -71,6 +71,10 @@ def parse_arguments(): type=int, default=1, help='N-way tensor parallelism size') + parser.add_argument('--cp_size', + type=int, + default=1, + help='Context parallelism size') parser.add_argument('--pp_size', type=int, default=1, @@ -221,7 +225,8 @@ def save_config(args): 'learn_sigma': args.learn_sigma, 'cfg_scale': args.cfg_scale, 'mapping': { - 'world_size': args.tp_size * args.pp_size, + 'world_size': args.cp_size * args.tp_size * args.pp_size, + 'cp_size': args.cp_size, 'tp_size': args.tp_size, 'pp_size': args.pp_size, } @@ -235,8 +240,9 @@ def covert_and_save(args, rank): if rank == 0: save_config(args) - mapping = Mapping(world_size=args.tp_size * args.pp_size, + mapping = Mapping(world_size=args.cp_size * args.tp_size * args.pp_size, rank=rank, + cp_size=args.cp_size, tp_size=args.tp_size, pp_size=args.pp_size) @@ -268,7 +274,7 @@ def execute(workers, func, args): def main(): print(tensorrt_llm.__version__) args = parse_arguments() - world_size = args.tp_size * args.pp_size + world_size = args.cp_size * args.tp_size * args.pp_size assert args.pp_size == 1, "PP is not supported yet." diff --git a/examples/dit/sample.py b/examples/dit/sample.py index 70cefd7ab..898a30966 100644 --- a/examples/dit/sample.py +++ b/examples/dit/sample.py @@ -37,11 +37,15 @@ def __init__(self, self.dtype = config['pretrained_config']['dtype'] rank = tensorrt_llm.mpi_rank() - world_size = tp = config['pretrained_config']['mapping'][ - 'world_size'] # Only support TP + world_size = config['pretrained_config']['mapping']['world_size'] + cp_size = config['pretrained_config']['mapping']['cp_size'] + tp_size = config['pretrained_config']['mapping']['tp_size'] + pp_size = config['pretrained_config']['mapping']['pp_size'] + assert pp_size == 1 self.mapping = tensorrt_llm.Mapping(world_size=world_size, rank=rank, - tp_size=tp, + cp_size=cp_size, + tp_size=tp_size, pp_size=1, gpus_per_node=args.gpus_per_node) diff --git a/examples/enc_dec/convert_checkpoint.py b/examples/enc_dec/convert_checkpoint.py index 7a0d64abe..d01058f30 100755 --- a/examples/enc_dec/convert_checkpoint.py +++ b/examples/enc_dec/convert_checkpoint.py @@ -136,6 +136,14 @@ def parse_t5_config_by_component(config, component, args): 'encoder', 'd_kv') component_config.decoder_start_token_id = config.getint( 'decoder', 'decoder_start_token_id') + component_config.eos_token_id = config.getint( + 'decoder', 'eos_token_id') + bos_token_id = config.get('decoder', 'bos_token_id') + # T5 does not have bos_token_id + component_config.bos_token_id = int( + bos_token_id) if bos_token_id != "None" else None + component_config.pad_token_id = config.getint( + 'decoder', 'pad_token_id') else: assert False, 'Unsupported component!' 
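The `parse_t5_config_by_component` hunk above reads the decoder's special-token ids out of the converted INI config; T5 has no BOS token, so a missing id is stored as the literal string "None" and has to be mapped back to Python `None` before it reaches the engine config. A minimal, self-contained sketch of that parsing pattern — the INI snippet below is illustrative only, not a real converted checkpoint config:

```python
import configparser

# Hypothetical config text for illustration; a real converted checkpoint has many more fields.
ini_text = """
[decoder]
decoder_start_token_id = 0
eos_token_id = 1
bos_token_id = None
pad_token_id = 0
"""

config = configparser.ConfigParser()
config.read_string(ini_text)

# Ids that are always present can be read directly as integers.
eos_token_id = config.getint('decoder', 'eos_token_id')
pad_token_id = config.getint('decoder', 'pad_token_id')

# Optional ids are serialized as the string "None" when absent (e.g. T5 has no BOS token),
# so read them as text first and convert only when a real value is stored.
bos_token_id = config.get('decoder', 'bos_token_id')
bos_token_id = int(bos_token_id) if bos_token_id != "None" else None

print(eos_token_id, bos_token_id, pad_token_id)  # -> 1 None 0
```

Keeping the sentinel comparison explicit avoids the `ValueError` that `config.getint` would raise when the stored value is the string "None".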
@@ -334,9 +342,9 @@ def parse_nmt_config(args, model): config["decoder"][key] = f"{val}" config["decoder"]["q_scaling"] = '1' config["decoder"]["rescale_before_lm_head"] = 'false' - config['decoder']['has_model_final_layernorm'] = config['decoder'][ - 'decoder_normalize_before'] and not config['decoder'].getboolean( - 'no_decoder_final_norm', False) + config['decoder']['has_model_final_layernorm'] = str( + config['decoder'].getboolean('decoder_normalize_before', False) + and not config['decoder'].getboolean('no_decoder_final_norm', False)) config['decoder']['vocab_size'] = str(len(model.tgt_dict)) # fairseq naming config["structure"] = dict() @@ -426,8 +434,10 @@ def parse_nmt_config_by_component(config, component, args): 'd_kv', fallback=component_config.encoder_hidden_size // component_config.encoder_num_heads) - component_config.decoder_start_token_id = config.getint( - 'decoder', 'decoder_start_token_id') + component_config.decoder_start_token_id = None + component_config.eos_token_id = None + component_config.bos_token_id = None + component_config.pad_token_id = None return component_config @@ -733,6 +743,12 @@ def parse_bart_config_by_component(config, component, args): component_config.decoder_start_token_id = int( decoder_start_token_id ) if decoder_start_token_id != "None" else None + component_config.eos_token_id = config.getint( + 'decoder', 'eos_token_id') + component_config.bos_token_id = config.getint( + 'decoder', 'bos_token_id') + component_config.pad_token_id = config.getint( + 'decoder', 'pad_token_id') return component_config @@ -1025,6 +1041,12 @@ def parse_pix2struct_config_by_component(config, component, args): 'structure', 'position_embedding_type') args.decoder_start_token_id = config.getint( 'decoder', 'decoder_start_token_id') + args.eos_token_id = config.getint('decoder', 'eos_token_id') + bos_token_id = config.get('decoder', 'bos_token_id') + # pix2struct does not have bos_token_id + args.bos_token_id = int( + bos_token_id) if bos_token_id != "None" else None + args.pad_token_id = config.getint('decoder', 'pad_token_id') else: assert False, 'Unsupported component!' @@ -1330,6 +1352,9 @@ def convert_checkpoint(args): 'skip_cross_qkv': args.skip_cross_qkv, 'use_implicit_relative_attention': args.use_implicit_relative_attention, 'decoder_start_token_id': decoder_config.decoder_start_token_id, + 'eos_token_id': decoder_config.eos_token_id, + 'bos_token_id': decoder_config.bos_token_id, + 'pad_token_id': decoder_config.pad_token_id, } for additional_setting in additional_settings: if hasattr(decoder_config, additional_setting): diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index 9ff7f2431..51b11206e 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index 0feae1db6..1c18ec949 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -3,7 +3,7 @@ # WAR the new posting of "nvidia-cudnn-cu12~=9.0". # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9". 
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64" -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 flax~=0.8.0 # jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index 1dedcca2d..cd7058518 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptj/requirements.txt b/examples/gptj/requirements.txt index 747ac0491..65fa0cf26 100644 --- a/examples/gptj/requirements.txt +++ b/examples/gptj/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index 045707f77..b61a51e9d 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/grok/requirements.txt b/examples/grok/requirements.txt index f92d4318b..ff9d53e4b 100644 --- a/examples/grok/requirements.txt +++ b/examples/grok/requirements.txt @@ -1,6 +1,6 @@ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/high-level-api/requirements.txt b/examples/high-level-api/requirements.txt index bbae209ac..ea8238a8b 100644 --- a/examples/high-level-api/requirements.txt +++ b/examples/high-level-api/requirements.txt @@ -1,2 +1,2 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index afea54243..6fcc646c9 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/jais/requirements.txt b/examples/jais/requirements.txt index 1dedcca2d..cd7058518 100644 --- a/examples/jais/requirements.txt +++ b/examples/jais/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/llama/README.md b/examples/llama/README.md index 847927af6..2073579a6 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -1217,7 +1217,8 @@ Users can run the LLaMA-3.1 model with higher precision (bf16/fp16) or fp8. Here To use the fp8 quantization, please add the `--use_fp8_rowwise` flag during the checkpoint conversion. In this demonstration, we convert the Meta checkpoint to bfloat16 with TP8-PP2 and the HF checkpoint to FP8 with TP8. -Note that you may need to update your transformers installation via `pip install --upgrade transformers`. 
+Note: You may need to update your transformers installation via `pip install --upgrade transformers`. +Note: For the 405B HF model, there are duplicated kv head weights. Users can use `--remove_duplicated_kv_heads` to remove them. ```bash # Run BF16 model by BF16 @@ -1227,7 +1228,7 @@ python examples/llama/convert_checkpoint.py --meta_ckpt_dir llama_3.1_405B_meta_ --tp_size 8 \ --pp_size 2 \ --load_by_shard \ - --workers 8 + --workers 2 # Run BF16 model by FP8 python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_model/ \ @@ -1237,7 +1238,8 @@ python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_model/ --tp_size 8 \ --pp_size 1 \ --load_by_shard \ - --workers 8 + --workers 8 \ + --remove_duplicated_kv_heads # Run FP8 model by FP8 python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_FP8_model/ \ @@ -1246,7 +1248,8 @@ python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_FP8_mo --tp_size 8 \ --pp_size 1 \ --load_by_shard \ - --workers 8 + --workers 8 \ + --remove_duplicated_kv_heads ``` ### Build Engine diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py index 33a3df180..7c7be1383 100644 --- a/examples/llama/convert_checkpoint.py +++ b/examples/llama/convert_checkpoint.py @@ -224,6 +224,13 @@ def parse_arguments(): help= 'Only save the model config w/o read and converting weights, be careful, this is for debug only' ) + parser.add_argument( + '--remove_duplicated_kv_heads', + action="store_true", + default=False, + help= + 'Only used to remove the duplicated kv heads of llama-3.1 405B HF model.' + ) parser.add_argument('--log_level', type=str, default='info') args = parser.parse_args() @@ -312,7 +319,8 @@ def args_to_build_options(args): 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, 'disable_weight_only_quant_plugin': args.disable_weight_only_quant_plugin + args.disable_weight_only_quant_plugin, + 'remove_duplicated_kv_heads': args.remove_duplicated_kv_heads, } diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index ad2910aa3..2a4cf5b2c 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index 0b1ad6b6f..9d04be51b 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 transformers>=4.39.0 datasets~=2.14.5 evaluate diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index aa8f33325..152819075 100644 --- a/examples/medusa/requirements.txt +++ b/examples/medusa/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index a681790b7..5254831a1 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 transformers==4.38.2
accelerate==0.25.0 diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index 747ac0491..65fa0cf26 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index f5cfdd100..52f9e5341 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -55,7 +55,7 @@ This BLIP section covers both BLIP2-OPT and BLIP2-T5, with minor changes needed ```bash trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ --max_beam_width 1 \ --max_batch_size 8 \ @@ -106,7 +106,7 @@ This BLIP section covers both BLIP2-OPT and BLIP2-T5, with minor changes needed python build_visual_engine.py --model_type blip2 --model_path tmp/hf_models/${MODEL_NAME} --max_batch_size 8 ``` - The built engines are located in `./visual_engines/${MODEL_NAME}`. + The built engines are located in `tmp/trt_engines/${MODEL_NAME}/vision_encoder`. To run the BLIP2 pipeline with batch size > 1, change `--max_batch_size` argument to `build_visual_engine.py` accordingly. @@ -118,8 +118,8 @@ This BLIP section covers both BLIP2-OPT and BLIP2-T5, with minor changes needed --max_new_tokens 30 \ --input_text "Question: which city is this? Answer:" \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu ``` For BLIP2-T5 family, @@ -128,7 +128,7 @@ This BLIP section covers both BLIP2-OPT and BLIP2-T5, with minor changes needed --max_new_tokens 30 \ --input_text "Question: which city is this? Answer:" \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/bfloat16 ``` @@ -143,7 +143,7 @@ This BLIP section covers both BLIP2-OPT and BLIP2-T5, with minor changes needed trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu \ --gemm_plugin float16 \ --max_beam_width 1 \ --max_batch_size 8 \ @@ -152,7 +152,7 @@ This BLIP section covers both BLIP2-OPT and BLIP2-T5, with minor changes needed --max_seq_len 1024 ``` - The built OPT engines lie in `trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu`. + The built OPT engines lie in `tmp/trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu`. You should use this directory as `--llm_engine_dir` argument to `run.py` **NOTE:** INT8/INT4 option is not supported for BLIP2-T5, because quantization support has not been @@ -182,10 +182,10 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in CogVLM uses a Vit encoder as LLM encoder and a modified Llama as decoder. 
```bash - python ../cogvlm/convert_checkpoint.py --model_dir tmp/hf_models/${MODEL_NAME} --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16 --use_prompt_tuning + python ../cogvlm/convert_checkpoint.py --model_dir tmp/hf_models/${MODEL_NAME} --output_dir tmp/trt_models/${MODEL_NAME} --dtype bfloat16 --use_prompt_tuning - trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_bf16 \ - --output_dir ./tmp/cogvlm/trt_engines/bf16/1-gpu \ + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME} \ + --output_dir tmp/trt_engines/${MODEL_NAME}/bf16/1-gpu \ --gemm_plugin bfloat16 \ --gpt_attention_plugin bfloat16 \ --context_fmha_fp32_acc enable \ @@ -209,13 +209,13 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --max_new_tokens 1000 \ --input_text " [INST] please describe this image in detail [/INST] " \ --hf_model_dir tmp/hf_models/${TOKENIZER_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir tmp/cogvlm/trt_engines/bf16/1-gpu \ - --batch_size 1 \ - --top_p 0.4 \ - --top_k 1 \ - --temperature 0.2 \ - --repetition_penalty 1.2 + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/bf16/1-gpu \ + --batch_size 1 \ + --top_p 0.4 \ + --top_k 1 \ + --temperature 0.2 \ + --repetition_penalty 1.2 ``` ## Deplot @@ -255,7 +255,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --max_input_len 1 ``` - The built deplot engines are located in `./tmp/trt_engines/${MODEL_NAME}/1-gpu/float16`. + The built deplot engines are located in `tmp/trt_engines/${MODEL_NAME}/1-gpu/float16`. 3. Build TensorRT engines for visual components @@ -263,7 +263,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python build_visual_engine.py --model_type pix2struct --model_path tmp/hf_models/${MODEL_NAME} --max_batch_size 8 ``` - The built engines are located in `./visual_engines/${MODEL_NAME}`. + The built visual engines are located in `tmp/trt_engines/${MODEL_NAME}/vision_encoder`. To run the deplot pipeline with batch size > 1, change `--max_batch_size` argument to `build_visual_engine.py` accordingly. 
@@ -274,7 +274,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --max_new_tokens 100 \ --input_text "" \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/float16 ``` @@ -298,7 +298,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ --use_fused_mlp \ --max_batch_size 1 \ @@ -314,8 +314,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/1-gpu/bfloat16 + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu ``` ## Kosmos-2 @@ -337,7 +337,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gpt_attention_plugin float16 \ --gemm_plugin float16 \ --max_batch_size 1 \ @@ -353,8 +353,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu ``` ## LLaVA, LLaVa-NeXT and VILA @@ -400,28 +400,31 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --output_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ --dtype float16 + # for LLaVA trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ --use_fused_mlp \ --max_batch_size 1 \ --max_input_len 2048 \ --max_seq_len 2560 \ - --max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) for LLaVA + --max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) + # for LLaVA-NeXT trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gpt_attention_plugin float16 \ --gemm_plugin float16 \ + --use_fused_mlp \ --max_batch_size 1 \ --max_input_len 4096 \ --max_seq_len 5120 \ --max_num_tokens 4096 \ # 1 (max_batch_size) * 4096 (max_input_len) - --max_multimodal_len 4096 \ # 1 (max_batch_size) * 4096 (max_input_len) - --use_fused_mlp + --max_multimodal_len 4096 # 1 (max_batch_size) * 4096 (max_input_len) + # for VILA trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ @@ -430,11 +433,9 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --max_batch_size 1 \ --max_input_len 2048 \ --max_seq_len 2560 \ - --max_multimodal_len 4096 # 1 (max_batch_size) * 4096 (num_visual_features) for 
VILA + --max_multimodal_len 4096 # 1 (max_batch_size) * 4096 (num_visual_features) ``` - Note: do not use `--use_fused_mlp` flag in quantization mode. - 3. Build TensorRT engines for visual components ```bash @@ -449,8 +450,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --max_new_tokens 30 \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --input_text "Question: which city is this? Answer:" # for LLaVA and for LLaVA-NeXT ``` @@ -462,8 +463,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --max_new_tokens 100 \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --image_path=av.png,https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png \ --input_text="\n\n Please elaborate what you see in the images?" \ --batch_size=1 # for VILA mode 1 @@ -471,8 +472,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --max_new_tokens 100 \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --image_path=av.png,https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png \ --input_text="\n Please elaborate what you see in the images?" \ --batch_size=2 # for VILA mode 2 @@ -484,71 +485,30 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in Note: use `--run_profiling` for performance measurement, use `--check_accuracy` for accuracy check. -4. (Optional) INT8/INT4 weight-only quantization for LLaMA can be enabled as follows (take `INT4` as an example, while `INT8` is the default precision for weight-only quantization): - ```bash - python ../llama/convert_checkpoint.py \ +4. (Optional) Different quantization methods supported in LLaMA can be applied to LLaVA/VILA as well, such as INT4/INT8 weight-only, SmoothQuant, and INT4 Activation-Aware Quantization (AWQ). Detailed instructions can be found in LLaMA [README](../llama/README.md). 
+ + For example, + + ```bash + # INT4 weight only + python ../llama/convert_checkpoint.py \ --model_dir tmp/hf_models/${MODEL_NAME} \ --dtype float16 \ --output_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \ --use_weight_only \ --weight_only_precision int4 - trtllm-build \ - --checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_weightonly/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu \ - --gemm_plugin float16 \ - --max_batch_size 1 \ - --max_input_len 1024 \ - --max_seq_len 1124 \ - --max_multimodal_len 576 # for LLaVA - - trtllm-build \ - --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ - --gemm_plugin float16 \ - --use_fused_mlp \ - --max_batch_size 1 \ - --max_input_len 1024 \ - --max_seq_len 1124 \ - --max_multimodal_len 4096 # for VILA - ``` - - The built engines lie in `trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu`. - You should use this directory as `--llm_engine_dir` argument to `run.py` - -5. (Optional) One can also use LLaVA/VILA with other quantization options, like SmoothQuant and INT4 AWQ, that are supported by LLaMA. - Instructions in LLaMA [README](../llama/README.md) to enable SmoothQuant and INT4 AWQ can be re-used to generate - quantized TRT engines for LLM component of LLaVA/VILA. - - For example, - - ```bash + # INT4 AWQ python ../quantization/quantize.py \ --model_dir tmp/hf_models/${MODEL_NAME} \ --output_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \ --dtype float16 \ --qformat int4_awq \ --calib_size 32 - - trtllm-build \ - --checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/int4_awq/1-gpu \ - --gemm_plugin float16 \ - --max_batch_size 1 \ - --max_input_len 1024 \ - --max_seq_len 1124 \ - --max_multimodal_len 576 # for LLaVA - - trtllm-build \ - --checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/int4_awq/1-gpu \ - --gemm_plugin float16 \ - --max_batch_size 1 \ - --max_input_len 2048 \ - --max_seq_len 2560 \ - --max_multimodal_len 4096 # for VILA ``` + Then follow the same `trtllm-build` and `run.py` steps as before. NOTE: for `trtllm-build` command, do not use `--use_fused_mlp` in these quantization modes. + ## NeVA [NeVA](https://docs.nvidia.com/nemo-framework/user-guide/latest/multimodalmodels/neva/index.html) is a groundbreaking addition to the NeMo Multimodal ecosystem. This model seamlessly integrates large language-centric models with a vision encoder, that can be deployed in TensorRT-LLM. @@ -574,7 +534,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME} \ - --output_dir trt_engines/${MODEL_NAME}/bf16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/bf16/1-gpu \ --gpt_attention_plugin bfloat16 \ --gemm_plugin bfloat16 \ --max_batch_size 1 \ @@ -593,8 +553,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --max_new_tokens 30 \ --hf_model_dir tmp/trt_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/bf16/1-gpu \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/bf16/1-gpu \ --input_text "Question: which city is this? 
Answer:" ``` @@ -647,7 +607,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 ``` @@ -672,7 +632,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gpt_attention_plugin float16 \ --gemm_plugin float16 \ --max_batch_size 1 \ @@ -688,8 +648,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu/ \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu/ \ --image_path=https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png ``` @@ -712,7 +672,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in trtllm-build \ --checkpoint_dir nemotron-3/trt_ckpt/bf16/1-gpu \ - --output_dir trt_engines/nemotron-3/bf16/1-gpu \ + --output_dir tmp/trt_engines/nemotron-3/bf16/1-gpu \ --gpt_attention_plugin bfloat16 \ --gemm_plugin bfloat16 \ --max_batch_size 1 \ @@ -724,15 +684,15 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in 2. Build TensorRT engines for visual components ```bash - python build_visual_engine.py --model_path /path/to/video/neva/projector.nemo --model_type video-neva + python build_visual_engine.py --model_path /path/to/video/neva/projector.nemo --model_type video-neva --output_dir tmp/trt_engines/nemotron-3/visual_encoder ``` ```bash python run.py \ --max_new_tokens 30 \ --hf_model_dir nemotron-3/trt_ckpt/bf16/1-gpu \ - --visual_engine_dir visual_engines/video_neva_engine \ - --llm_engine_dir trt_engines/nemotron-3/bf16/1-gpu \ + --visual_engine_dir tmp/trt_engines/nemotron-3/visual_encoder \ + --llm_engine_dir tmp/trt_engines/nemotron-3/bf16/1-gpu \ --input_text "Question: what is in the video? 
Answer:" \ --video_path /path/to/your/local/video/file ``` @@ -760,7 +720,7 @@ The full set of commands to enable 2-way tensor parallelism for LLaVA is: trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/2-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/2-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/2-gpu \ --gemm_plugin float16 \ --max_batch_size 1 \ --max_input_len 2048 \ @@ -773,6 +733,6 @@ The full set of commands to enable 2-way tensor parallelism for LLaVA is: python run.py \ --max_new_tokens 30 \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ - --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/2-gpu \ + --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/2-gpu \ ``` diff --git a/examples/multimodal/build_visual_engine.py b/examples/multimodal/build_visual_engine.py index 7d2e4f65c..67b01c291 100644 --- a/examples/multimodal/build_visual_engine.py +++ b/examples/multimodal/build_visual_engine.py @@ -66,11 +66,11 @@ def __init__(self, args): args.device = torch.device( "cuda") if torch.cuda.is_available() else "cpu" if args.output_dir is None: - args.output_dir = 'visual_engines/%s' % ( - args.model_path.split('/')[-1] if args.vila_path is not None - else args.model_path.split('/')[-1]) - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) + # default path to save the engines + model_name = args.model_path.split('/')[-1] + args.output_dir = f'tmp/trt_engines/{model_name}/vision_encoder' + + os.makedirs(args.output_dir, exist_ok=True) self.args = args @@ -103,35 +103,42 @@ def build(self): raise RuntimeError(f"Invalid model type {args.model_type}") -def export_visual_wrapper_onnx(visual_wrapper, - input, - output_dir, - input_names=['input'], - dynamic_axes={'input': { - 0: 'batch' - }}): - logger.log(trt.Logger.INFO, "Exporting onnx") - os.makedirs(f'{output_dir}/onnx', exist_ok=True) - torch.onnx.export(visual_wrapper, +def export_onnx(model, + input, + onnx_dir, + onnx_name='model.onnx', + input_names=['input'], + output_names=['output'], + dynamic_axes={'input': { + 0: 'batch' + }}, + logger=trt.Logger(trt.Logger.INFO)): + logger.log(trt.Logger.INFO, f"Exporting onnx to {onnx_dir}/{onnx_name}") + os.makedirs(onnx_dir, exist_ok=True) + torch.onnx.export(model, input, - f'{output_dir}/onnx/visual_encoder.onnx', + f'{onnx_dir}/{onnx_name}', opset_version=17, input_names=input_names, - output_names=['output'], + output_names=output_names, dynamic_axes=dynamic_axes) def build_trt_engine(model_type, input_sizes, - output_dir, + onnx_dir, + engine_dir, max_batch_size, dtype=torch.float16, - num_frames=None): - part_name = 'visual_encoder' - onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) - engine_file = '%s/%s.engine' % (output_dir, part_name) - config_file = '%s/%s' % (output_dir, "config.json") - logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name) + num_frames=None, + onnx_name='model.onnx', + engine_name='model.engine', + delete_onnx=True, + logger=trt.Logger(trt.Logger.INFO)): + onnx_file = f'{onnx_dir}/{onnx_name}' + engine_file = f'{engine_dir}/{engine_name}' + config_file = f'{engine_dir}/config.json' + logger.log(trt.Logger.INFO, f"Building TRT engine to {engine_file}") builder = trt.Builder(logger) network = builder.create_network( @@ -157,9 +164,6 @@ def build_trt_engine(model_type, logger.log(trt.Logger.ERROR, parser.get_error(error)) logger.log(trt.Logger.INFO, "Succeeded parsing %s" % 
onnx_file) - # Delete onnx files since we don't need them now - shutil.rmtree(f'{output_dir}/onnx') - nBS = -1 nMinBS = 1 nOptBS = max(nMinBS, int(max_batch_size / 2)) @@ -167,8 +171,9 @@ def build_trt_engine(model_type, inputT = network.get_input(0) - # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images, - # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]). + # input sizes can be: + # - integer list, when inputs are constant size images. e.g. [3, H, W] + # - list of integer lists, when inputs are dynamic size images. e.g. [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]] assert isinstance(input_sizes, list), "input_sizes must be a list" if isinstance(input_sizes[0], int): logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}") @@ -200,9 +205,14 @@ def build_trt_engine(model_type, else: logger.log(trt.Logger.INFO, "Succeeded building %s in %d s" % (engine_file, t1 - t0)) + os.makedirs(engine_dir, exist_ok=True) with open(engine_file, 'wb') as f: f.write(engine_string) + # Clear onnx files since we no longer need them after a successful engine build + if delete_onnx: + shutil.rmtree(onnx_dir) + Builder.save_config(config_wrapper, config_file) @@ -245,10 +255,11 @@ def forward(self, image): model.language_projection, model.query_tokens) wrapper.to(args.device) - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type + "-" + blip2_llm, # blip2-t5 or blip2-opt [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size) @@ -284,23 +295,24 @@ def forward(self, image, attention_mask): # falls within a relatively narrow range. To improve performance, we can avoid using # dynamic axis for the input patches and instead use a fixed number of patches along # with an attention mask. 
- export_visual_wrapper_onnx(wrapper, (image, attention_mask), - args.output_dir, - input_names=['input', 'attention_mask'], - dynamic_axes={ - 'input': { - 0: 'batch' - }, - 'attention_mask': { - 0: 'batch' - } - }) + export_onnx(wrapper, (image, attention_mask), + f'{args.output_dir}/onnx', + input_names=['input', 'attention_mask'], + dynamic_axes={ + 'input': { + 0: 'batch' + }, + 'attention_mask': { + 0: 'batch' + } + }) build_trt_engine( args.model_type, [image.shape[1], image.shape[2]], # Number of Patches, Hidden Dimension + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size, - torch.bfloat16) + dtype=torch.bfloat16) def build_llava_engine(args): @@ -359,10 +371,11 @@ def forward(self, pixel_values): model.multi_modal_projector.to(args.device), ) - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size) if args.model_type == "llava_next": @@ -408,10 +421,11 @@ def forward(self, image): ) wrapper = VilaVisionWrapper(model.get_vision_tower().to(args.device), model.mm_projector.to(args.device)) - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size) @@ -436,10 +450,11 @@ def forward(self, image): swin_encoder = model.get_encoder().to(args.device) wrapper = SwinEncoderWrapper(swin_encoder) - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size) @@ -471,13 +486,14 @@ def forward(self, image): vit_encoder = cogvlm.model.vision.to(args.device).eval() wrapper = CogVlmVisionWrapper(vit_encoder) - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size, - dtype) + dtype=dtype) def build_fuyu_engine(args): @@ -502,13 +518,13 @@ def forward(self, patches): vision_encoder = model.vision_embed_tokens wrapper = FuyuEncoderWrapper(vision_encoder).to(args.device) - export_visual_wrapper_onnx(wrapper, - image, - args.output_dir, - dynamic_axes={'input': { - 0: 'batch', - 2: 'patch' - }}) + export_onnx(wrapper, + image, + f'{args.output_dir}/onnx', + dynamic_axes={'input': { + 0: 'batch', + 2: 'patch' + }}) build_trt_engine( args.model_type, # [nImgs, nImgPatches, nDims] @@ -516,6 +532,7 @@ def forward(self, patches): # nImgPatches depends on image size (patch size: 30x30) # nDims is 30x30x3=2700 (patch size x color channels) [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]], + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size) @@ -589,13 +606,14 @@ def forward(self, images): dummy_image = torch.empty( 1, 3, image_size, image_size, dtype=dtype, device=args.device) # dummy image shape [B, C, H, W] - export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir) + export_onnx(wrapper, dummy_image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [3, image_size, image_size], # 
[3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size, - dtype) + dtype=dtype) def build_video_neva_engine(args): @@ -669,13 +687,14 @@ def forward(self, images): image_size, dtype=dtype, device=args.device) # dummy image - export_visual_wrapper_onnx(wrapper, dummy_video, args.output_dir) + export_onnx(wrapper, dummy_video, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [num_frames, 3, image_size, image_size], # [num_frames, 3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size, - dtype, + dtype=dtype, num_frames=num_frames) @@ -707,10 +726,11 @@ def forward(self, images): model.vision_model.to(args.device), model.image_to_text_projection.to(args.device)) - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size) @@ -783,10 +803,10 @@ def apply_img_projection(self, input): model.model.vision_embed_tokens.sub_GN) tensors = {"glb_GN": glb_GN, "sub_GN": sub_GN} save_file(tensors, args.output_dir + "/image_newlines.safetensors") - export_visual_wrapper_onnx(wrapper, image, args.output_dir) + export_onnx(wrapper, image, f'{args.output_dir}/onnx') build_trt_engine( - args.model_type, - [image.shape[1], image.shape[2], image.shape[3]], args.output_dir, + args.model_type, [image.shape[1], image.shape[2], image.shape[3]], + f'{args.output_dir}/onnx', args.output_dir, args.max_batch_size * (num_crops + 1)) #TODO: Take input from config diff --git a/examples/multimodal/run.py b/examples/multimodal/run.py index 9caeb3637..fbd8e361c 100644 --- a/examples/multimodal/run.py +++ b/examples/multimodal/run.py @@ -40,6 +40,10 @@ def parse_arguments(): type=str, default=None, help='Directory containing visual TRT engines') + parser.add_argument('--visual_engine_name', + type=str, + default='model.engine', + help='Name of visual TRT engine') parser.add_argument('--llm_engine_dir', type=str, default=None, @@ -314,7 +318,7 @@ def batch_decode(self, x, **kwargs): def init_image_encoder(self): vision_encoder_path = os.path.join(self.args.visual_engine_dir, - 'visual_encoder.engine') + self.args.visual_engine_name) logger.info(f'Loading engine from {vision_encoder_path}') with open(vision_encoder_path, 'rb') as f: engine_buffer = f.read() diff --git a/examples/nemotron/requirements.txt b/examples/nemotron/requirements.txt index 5bd99f191..95e7aa69a 100644 --- a/examples/nemotron/requirements.txt +++ b/examples/nemotron/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 transformers==4.40.2 # https://github.com/NVIDIA/NeMo/issues/9793 huggingface_hub==0.23.5 diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index 747ac0491..65fa0cf26 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/convert_checkpoint.py b/examples/phi/convert_checkpoint.py index 6f562952f..cddb110b0 100644 --- a/examples/phi/convert_checkpoint.py +++ b/examples/phi/convert_checkpoint.py @@ -15,11 +15,16 @@ import argparse import os import time +import traceback +from concurrent.futures import ThreadPoolExecutor, 
as_completed from transformers import AutoConfig import tensorrt_llm +from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import Phi3ForCausalLM, PhiForCausalLM +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization import QuantAlgo def parse_arguments(): @@ -68,6 +73,38 @@ def parse_arguments(): return args +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + +def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: + '''return config dict with quantization info based on the command line args + ''' + quant_config = QuantConfig() + if args.use_weight_only: + if args.weight_only_precision == 'int8': + quant_config.quant_algo = QuantAlgo.W8A16 + elif args.weight_only_precision == 'int4': + quant_config.quant_algo = QuantAlgo.W4A16 + + return quant_config + + if __name__ == '__main__': print(tensorrt_llm.__version__) args = parse_arguments() @@ -84,15 +121,37 @@ def parse_arguments(): 'PhiForCausalLM', 'Phi3ForCausalLM', 'Phi3VForCausalLM', 'Phi3SmallForCausalLM' ] - modelForCausalLM = None + if model_type not in supported_models: assert False, "Invalid model type" - modelForCausalLM = PhiForCausalLM if model_type == 'PhiForCausalLM' else Phi3ForCausalLM - modelForCausalLM.convert_hf_checkpoint(args.model_dir, - dtype=args.dtype, - output_dir=args.output_dir, - args=args) + phi_model = Phi3ForCausalLM if model_type.find( + 'Phi3') != -1 else PhiForCausalLM + + hf_model = None + + override_fields = {} + # override_fields.update(args_to_build_options(args)) + quant_config = args_to_quant_config(args) + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=args.tp_size * args.pp_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + + phi = phi_model.from_hugging_face( + args.model_dir if hf_model is None else hf_model, + args.dtype, + mapping=mapping, + quant_config=quant_config, + **override_fields, + ) + phi.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del phi + + execute(args.workers, [convert_and_save_rank] * args.tp_size * args.pp_size, + args) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index 9fbd2480f..e19af8a47 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index b94f7a09a..63b3ee83d 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index f02dd342a..258fbb79b 100644 --- a/examples/qwen/requirements.txt +++ 
b/examples/qwen/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index c303a4226..1c49cfff4 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/recurrentgemma/requirements.txt b/examples/recurrentgemma/requirements.txt index e60f95abc..b60499270 100644 --- a/examples/recurrentgemma/requirements.txt +++ b/examples/recurrentgemma/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 git+https://github.com/google-deepmind/recurrentgemma.git flax>=0.8.2 jax~=0.4.23 diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt index 376beeadd..7c562c26f 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/smaug/requirements.txt b/examples/smaug/requirements.txt index ad2910aa3..2a4cf5b2c 100644 --- a/examples/smaug/requirements.txt +++ b/examples/smaug/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index 8dcca25f0..c0ec8018e 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024072302 +tensorrt_llm==0.12.0.dev2024073000 tiktoken datasets kaldialign diff --git a/requirements-windows.txt b/requirements-windows.txt index 04e6d6917..c07e60595 100644 --- a/requirements-windows.txt +++ b/requirements-windows.txt @@ -19,7 +19,7 @@ tensorrt-cu12==10.1.0 tokenizers>=0.14 # Default torch is CPU-only on Windows, so need to specify a torch version with GPU support torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-win_amd64.whl -nvidia-modelopt~=0.13,<0.14 +nvidia-modelopt~=0.15.0 transformers>=4.38.2 wheel optimum diff --git a/requirements.txt b/requirements.txt index 550c90aff..6b392d135 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ tensorrt~=10.2.0 # https://github.com/pytorch/pytorch/blob/v2.3.1/version.txt uses 2.3.0a0. # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-05.html#rel-24-05 uses 2.4.0a0. torch>=2.3.0a0,<=2.4.0a0 -nvidia-modelopt~=0.13,<0.14 +nvidia-modelopt~=0.15.0 transformers>=4.38.2 pillow==10.3.0 wheel diff --git a/tensorrt_llm/_ipc_utils.py b/tensorrt_llm/_ipc_utils.py index 84c04c947..760e2a13a 100644 --- a/tensorrt_llm/_ipc_utils.py +++ b/tensorrt_llm/_ipc_utils.py @@ -102,7 +102,9 @@ def open_ipc_memory(mapping: Mapping, Returns a list of buffer pointers, buffers[i] is a handle to the corresponding buffer residing on GPU #i. Call close_ipc_handle with the *buffer*. 
""" - comm = mpi_comm().Split(mapping.pp_rank, mapping.tp_rank) + comm = mpi_comm().Split( + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, + mapping.tp_rank) error, local_ptr = cudart.cudaMalloc(size) _raise_if_error(error) diff --git a/tensorrt_llm/auto_parallel/parallelization.py b/tensorrt_llm/auto_parallel/parallelization.py index 1ef18daf7..c55cad54e 100644 --- a/tensorrt_llm/auto_parallel/parallelization.py +++ b/tensorrt_llm/auto_parallel/parallelization.py @@ -21,8 +21,7 @@ from tensorrt_llm.network import (PluginInfo, delete_plugin_info, get_np_weight, get_plugin_info, set_plugin_info) from tensorrt_llm.plugin import TRT_LLM_PLUGIN_NAMESPACE, init_all_reduce_helper -from tensorrt_llm.plugin.plugin import (CustomAllReduceHelper, - current_all_reduce_helper) +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper from tensorrt_llm.version import __version__ from .config import AutoParallelConfig @@ -1564,7 +1563,6 @@ def add_reduce_scatter(self, context: GraphContext, input_name, output_name, def add_all_reduce_layer(self, context: GraphContext, input_name, output_name, device_ids, to_reduce_tensors): - counter = current_all_reduce_helper().gen_id() for device_id, to_reduce_tensor in zip(np.nditer(device_ids), to_reduce_tensors): device_id = device_id.item() @@ -1583,7 +1581,6 @@ def add_all_reduce_layer(self, context: GraphContext, input_name, strategy=strategy, dtype=to_reduce_tensor.dtype, config=AllReduceConfig(0), - counter=counter, reduce_fusion_params=AllReduceFusionParams(), ) plugin_info = PluginInfo(allreduce_plg_creator, "allreduce", pfc) @@ -2198,7 +2195,7 @@ def parallelize( if not debug_mode: init_all_reduce_helper() tp_size = phy_mesh.size // config.graph_config.num_stages - shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size, ) + shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size + 1, ) workspace = graph.as_trt().add_input( name="all_reduce_workspace", dtype=trt.int64, diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index 0bc067b3b..18d14c70e 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -105,6 +105,7 @@ def parse_arguments(): help='It equals to max_batch_size*max_beam_width by default, set this ' 'value as close as possible to the actual number of tokens on your workload. 
' 'Note that this argument might be removed in the future.') + parser.add_argument('--cp_size', type=int, default=1) parser.add_argument('--tp_size', type=int, default=1) parser.add_argument('--pp_size', type=int, default=1) parser.add_argument( @@ -269,6 +270,7 @@ def build_model( bool = False, # return the modified BuildConfig without actually building the engine **kwargs ) -> Union[Engine, BuildConfig]: + model_config = copy.deepcopy(model_config) logits_dtype = kwargs.get('logits_dtype') @@ -412,6 +414,7 @@ def main(): kwargs = { 'logits_dtype': args.logits_dtype, 'use_fused_mlp': args.use_fused_mlp, + 'cp_size': args.cp_size, 'tp_size': args.tp_size, 'pp_size': args.pp_size, 'lora_dir': args.lora_dir, diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 1a897a7cf..97eea4a49 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -3706,7 +3706,6 @@ def create_allreduce_plugin( strategy: AllReduceStrategy, dtype: trt.DataType, config: AllReduceConfig, - counter: int, reduce_fusion_params: AllReduceFusionParams, ): allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( @@ -3727,9 +3726,6 @@ def create_allreduce_plugin( "fusion_op", np.array([int(reduce_fusion_params.fusion_op)], np.int8), trt.PluginFieldType.INT8) pfc.append(p_fusion_op) - p_counter = trt.PluginField("counter", np.array([counter], np.int32), - trt.PluginFieldType.INT32) - pfc.append(p_counter) p_eps = trt.PluginField( "eps", np.array([float(reduce_fusion_params.eps)], np.float32), trt.PluginFieldType.FLOAT32) @@ -3805,10 +3801,8 @@ def allreduce( strategy = AllReduceStrategy.NCCL workspace = None - counter = 0 if strategy != AllReduceStrategy.NCCL: workspace = current_all_reduce_helper().workspace.trt_tensor - counter = current_all_reduce_helper().gen_id() if reduce_fusion_params is None: reduce_fusion_params = AllReduceFusionParams() @@ -3822,7 +3816,6 @@ def allreduce( strategy=strategy, dtype=str_dtype_to_trt(dtype), config=config, - counter=counter, reduce_fusion_params=reduce_fusion_params, ) _add_plugin_info(layer, allreduce_plg_creator, "allreduce", pfc) diff --git a/tensorrt_llm/hlapi/llm_utils.py b/tensorrt_llm/hlapi/llm_utils.py index c7645146f..a243f1a06 100644 --- a/tensorrt_llm/hlapi/llm_utils.py +++ b/tensorrt_llm/hlapi/llm_utils.py @@ -1017,7 +1017,8 @@ def _load_model_from_hf(self): assert self._model_dir is not None import transformers - hf_config = transformers.AutoConfig.from_pretrained(self._model_dir) + hf_config = transformers.AutoConfig.from_pretrained( + self._model_dir, trust_remote_code=True) architecture = hf_config.architectures[0] if architecture not in MODEL_MAP: @@ -1262,7 +1263,8 @@ def get_final_build_config(llm_args: LlmArgs, # dtype. That's why the model will be downloaded from HF if necessary to get the accurate dtype. 
import transformers - hf_config = transformers.AutoConfig.from_pretrained(model_dir) + hf_config = transformers.AutoConfig.from_pretrained( + model_dir, trust_remote_code=True) architecture = hf_config.architectures[0] if architecture not in MODEL_MAP: diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index 66f34aec0..d33a6e414 100644 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -23,9 +23,9 @@ trt_dtype_to_str) from ..functional import (ACT2FN, AllReduceFusionParams, AttentionMaskType, Conditional, LayerNormType, PositionEmbeddingType, - RopeEmbeddingUtils, RotaryScalingType, Tensor, arange, - bert_attention, cast, clip, concat, constant, - embedding, expand, expand_dims, expand_mask, + RopeEmbeddingUtils, RotaryScalingType, Tensor, + allgather, arange, bert_attention, cast, clip, concat, + constant, embedding, expand, expand_dims, expand_mask, generate_alibi_biases, generate_alibi_slopes, gpt_attention, matmul) from ..functional import max as fmax @@ -155,6 +155,44 @@ def __init__(self, self.host_runtime_perf_knobs = host_runtime_perf_knobs + # const parameters that will be reused by all layers. + self.embed_positions = None + self.rotary_inv_freq = None + self.embed_positions_for_gpt_attention = None + self.embed_positions_short_factors = None + self.embed_positions_long_factors = None + self.embed_positions_short_factors_for_attention_plugin = None + self.embed_positions_long_factors_for_attention_plugin = None + self.short_mscale = 1.0 + self.long_mscale = 1.0 + self.short_inv_freq = None + self.long_inv_freq = None + + def fill_attention_const_params_for_rope( + self, + embed_positions: Tensor = None, + rotary_inv_freq: Tensor = None, + embed_positions_for_gpt_attention: Tensor = None): + self.embed_positions = embed_positions + self.rotary_inv_freq = rotary_inv_freq + self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention + return self + + def fill_attention_const_params_for_long_rope( + self, embed_positions_short_factors, embed_positions_long_factors, + embed_positions_short_factors_for_attention_plugin, + embed_positions_long_factors_for_attention_plugin, short_mscale, + long_mscale, short_inv_freq, long_inv_freq): + self.embed_positions_short_factors = embed_positions_short_factors + self.embed_positions_long_factors = embed_positions_long_factors + self.embed_positions_short_factors_for_attention_plugin = embed_positions_short_factors_for_attention_plugin + self.embed_positions_long_factors_for_attention_plugin = embed_positions_long_factors_for_attention_plugin + self.short_mscale = short_mscale + self.long_mscale = long_mscale + self.short_inv_freq = short_inv_freq + self.long_inv_freq = long_inv_freq + return self + def is_valid_cross_attn(self, do_cross_attention): if do_cross_attention: if self.encoder_input_lengths is None: @@ -373,79 +411,6 @@ def __init__(self, if self.position_embedding_type.is_rope(): self.rotary_embedding_dim = int(self.attention_head_size * rotary_embedding_percentage) - - if self.position_embedding_type == PositionEmbeddingType.long_rope: - embed_positions_short_factors, embed_positions_long_factors, \ - (short_inv_freq, embed_positions_short_factors_for_attention_plugin), \ - (long_inv_freq, embed_positions_long_factors_for_attention_plugin), mscale \ - = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope( - self.max_position_embeddings, - original_max_position_embeddings, self.rotary_embedding_dim, - self.rotary_embedding_base, rope_scaling_short_factors, 
- rope_scaling_long_factors, rope_scaling_short_mscale, rope_scaling_long_mscale) - - if rope_scaling_short_mscale is not None: - assert rope_scaling_long_mscale is not None - short_mscale = rope_scaling_short_mscale - long_mscale = rope_scaling_long_mscale - else: - short_mscale = long_mscale = mscale - - short_inv_freq = short_inv_freq.reshape(1, -1) - long_inv_freq = long_inv_freq.reshape(1, -1) - - self.register_parameter( - 'embed_positions_short_factors', - Parameter(embed_positions_short_factors, - dtype='float32', - is_buffer=True)) - self.register_parameter( - 'embed_positions_long_factors', - Parameter(embed_positions_long_factors, - dtype='float32', - is_buffer=True)) - self.register_parameter( - 'embed_positions_short_factors_for_attention_plugin', - Parameter( - embed_positions_short_factors_for_attention_plugin, - dtype='float32', - is_buffer=True)) - self.register_parameter( - 'embed_positions_long_factors_for_attention_plugin', - Parameter(embed_positions_long_factors_for_attention_plugin, - dtype='float32', - is_buffer=True)) - self.short_mscale = short_mscale - self.long_mscale = long_mscale - self.register_parameter( - 'short_inv_freq', - Parameter(short_inv_freq, dtype='float32', is_buffer=True)) - self.register_parameter( - 'long_inv_freq', - Parameter(long_inv_freq, dtype='float32', is_buffer=True)) - else: - # Rotary cos/sin cache. - embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions( - self.max_position_embeddings, - self.rotary_embedding_dim, - ) - self.register_parameter( - 'embed_positions', - Parameter(embed_positions, dtype='float32', is_buffer=True)) - rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin( - self.max_position_embeddings, self.rotary_embedding_dim, - self.rotary_embedding_base, self.rotary_embedding_scale, - self.rotary_embedding_scale_type, - self.rotary_embedding_scaling) - self.register_parameter( - 'rotary_inv_freq', - Parameter(rotary_inv_freq, dtype='float32', is_buffer=True)) - self.register_parameter( - 'embed_positions_for_gpt_attention', - Parameter(embed_positions_for_gpt_attention, - dtype='float32', - is_buffer=True)) - elif self.position_embedding_type.is_alibi(): alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1. alibi_slopes = generate_alibi_slopes( @@ -518,6 +483,155 @@ def __init__(self, self.skip_cross_qkv = skip_cross_qkv + @staticmethod + def create_attention_const_params(model_cls, config): + # get rotary parameters. + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + attention_head_size = config.head_size + max_position_embeddings = config.max_position_embeddings + position_embedding_type = config.position_embedding_type + rotary_embedding_base = getattr(config, 'rotary_base', 10000.0) + rotary_embedding_scaling = getattr(config, 'rotary_scaling', None) + rotary_embedding_percentage = getattr(config, 'rotary_pct', 1.0) + # only rope need the const parameters. + if not position_embedding_type.is_rope(): + return + # attention head size + attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size + # rotary embedding dim. + rotary_embedding_dim = getattr( + config, 'rotary_dim', + int(attention_head_size * rotary_embedding_percentage)) + # rotary scaling. 
+ rotary_embedding_scale_type = RotaryScalingType.none + rotary_embedding_scale = 1.0 + if rotary_embedding_scaling is not None: + rotary_scaling_type = rotary_embedding_scaling.get( + "type", rotary_embedding_scaling.get("rope_type")) + rotary_embedding_scale_type = RotaryScalingType.from_string( + rotary_scaling_type) + rotary_embedding_scale = rotary_embedding_scaling.get("factor", 1.0) + + if position_embedding_type == PositionEmbeddingType.long_rope: + rope_scaling_short_factors, rope_scaling_long_factors = None, None + rope_scaling_short_mscale, rope_scaling_long_mscale = None, None + original_max_position_embeddings = max_position_embeddings + + if hasattr(config, "longrope_scaling_short_factors"): + rope_scaling_short_factors = np.asarray( + config.longrope_scaling_short_factors).astype(np.float32) + rope_scaling_long_factors = np.asarray( + config.longrope_scaling_long_factors).astype(np.float32) + + original_max_position_embeddings = config.original_max_position_embeddings + + if config.architecture == "Phi3SmallForCausalLM": + rope_scaling_short_mscale = config.longrope_short_mscale + rope_scaling_long_mscale = config.longrope_long_mscale + embed_positions_short_factors, embed_positions_long_factors, \ + (short_inv_freq, embed_positions_short_factors_for_attention_plugin), \ + (long_inv_freq, embed_positions_long_factors_for_attention_plugin), mscale \ + = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope( + max_position_embeddings, + original_max_position_embeddings, rotary_embedding_dim, + rotary_embedding_base, rope_scaling_short_factors, + rope_scaling_long_factors, rope_scaling_short_mscale, rope_scaling_long_mscale) + + if rope_scaling_short_mscale is not None: + assert rope_scaling_long_mscale is not None + short_mscale = rope_scaling_short_mscale + long_mscale = rope_scaling_long_mscale + else: + short_mscale = long_mscale = mscale + + short_inv_freq = short_inv_freq.reshape(1, -1) + long_inv_freq = long_inv_freq.reshape(1, -1) + + model_cls.register_parameter( + 'embed_positions_short_factors', + Parameter(embed_positions_short_factors, + dtype='float32', + is_buffer=True)) + model_cls.register_parameter( + 'embed_positions_long_factors', + Parameter(embed_positions_long_factors, + dtype='float32', + is_buffer=True)) + model_cls.register_parameter( + 'embed_positions_short_factors_for_attention_plugin', + Parameter( + embed_positions_short_factors_for_attention_plugin, + dtype='float32', + is_buffer=True)) + model_cls.register_parameter( + 'embed_positions_long_factors_for_attention_plugin', + Parameter(embed_positions_long_factors_for_attention_plugin, + dtype='float32', + is_buffer=True)) + model_cls.short_mscale = short_mscale + model_cls.long_mscale = long_mscale + model_cls.register_parameter( + 'short_inv_freq', + Parameter(short_inv_freq, dtype='float32', is_buffer=True)) + model_cls.register_parameter( + 'long_inv_freq', + Parameter(long_inv_freq, dtype='float32', is_buffer=True)) + else: + # Rotary const weights. + embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions( + max_position_embeddings, + rotary_embedding_dim, + ) + # cogvlm attention. 
+ if hasattr(config, 'vision_start') and hasattr( + config, 'vision_length'): + rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin( + max_position_embeddings, rotary_embedding_dim, + rotary_embedding_base, rotary_embedding_scale, + rotary_embedding_scale_type, config.vision_start, + config.vision_length) + else: + rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin( + max_position_embeddings, rotary_embedding_dim, + rotary_embedding_base, rotary_embedding_scale, + rotary_embedding_scale_type, rotary_embedding_scaling) + model_cls.register_parameter( + 'embed_positions', + Parameter(embed_positions, dtype='float32', is_buffer=True)) + model_cls.register_parameter( + 'rotary_inv_freq', + Parameter(rotary_inv_freq, dtype='float32', is_buffer=True)) + model_cls.register_parameter( + 'embed_positions_for_gpt_attention', + Parameter(embed_positions_for_gpt_attention, + dtype='float32', + is_buffer=True)) + + @staticmethod + def fill_attention_params(model_cls, attention_params): + if model_cls.position_embedding_type.is_rope(): + if attention_params is None: + attention_params = AttentionParams() + if model_cls.position_embedding_type == PositionEmbeddingType.long_rope: + if hasattr(model_cls, "embed_positions_short_factors"): + return attention_params.fill_attention_const_params_for_long_rope( + model_cls.embed_positions_short_factors.value, + model_cls.embed_positions_long_factors.value, model_cls. + embed_positions_short_factors_for_attention_plugin. + value, model_cls. + embed_positions_long_factors_for_attention_plugin.value, + model_cls.short_mscale, model_cls.long_mscale, + model_cls.short_inv_freq.value, + model_cls.long_inv_freq.value) + else: + return attention_params.fill_attention_const_params_for_rope( + model_cls.embed_positions.value, + model_cls.rotary_inv_freq.value, + model_cls.embed_positions_for_gpt_attention.value) + # Fill nothing. + return attention_params + def forward(self, hidden_states: Tensor, attention_mask=None, @@ -743,16 +857,18 @@ def compute_cross_qkv(encoder_output): if self.position_embedding_type == PositionEmbeddingType.long_rope: short = slice( - self.embed_positions_short_factors_for_attention_plugin. - value, concat([0, 0, 0]), + attention_params. + embed_positions_short_factors_for_attention_plugin, + concat([0, 0, 0]), concat([ max(attention_params.sequence_length, self.original_max_position_embeddings), self.rotary_embedding_dim // 2, 2 ])) long = slice( - self.embed_positions_long_factors_for_attention_plugin. - value, concat([0, 0, 0]), + attention_params. + embed_positions_long_factors_for_attention_plugin, + concat([0, 0, 0]), concat([ max(attention_params.sequence_length, self.original_max_position_embeddings), @@ -767,8 +883,8 @@ def compute_cross_qkv(encoder_output): rotary_cos_sin = slice(embed_positions, concat([select, 0]), sizes=concat([1, shape(long, 1)])) - short_inv_freq = self.short_inv_freq.value - long_inv_freq = self.long_inv_freq.value + short_inv_freq = attention_params.short_inv_freq + long_inv_freq = attention_params.long_inv_freq concat_inv_freq = concat([short_inv_freq, long_inv_freq], dim=0) rotary_inv_freq = slice(concat_inv_freq, concat([select, 0]), @@ -777,11 +893,17 @@ def compute_cross_qkv(encoder_output): rotary_inv_freq = rotary_inv_freq.view((-1, )) else: # The rotary inv freq can be pre-computed. 
- rotary_inv_freq = self.rotary_inv_freq.value if self.position_embedding_type.is_rope( - ) else None + rotary_inv_freq = getattr(attention_params, "rotary_inv_freq", + None) # Rotary cos/sin cache. - rotary_cos_sin = self.embed_positions_for_gpt_attention.value if self.position_embedding_type.is_rope( - ) else None + rotary_cos_sin = getattr(attention_params, + "embed_positions_for_gpt_attention", + None) + # check if the cache is provided. + if self.position_embedding_type.is_rope(): + assert (rotary_inv_freq is not None) and ( + rotary_cos_sin is not None + ), "rotary_inv_freq and embed_positions_for_gpt_attention must be provided." context, past_key_value = gpt_attention( qkv=qkv, @@ -803,8 +925,8 @@ def compute_cross_qkv(encoder_output): rotary_embedding_dim=self.rotary_embedding_dim, rotary_embedding_base=self.rotary_embedding_base, rotary_embedding_scale_type=self.rotary_embedding_scale_type, - rotary_embedding_short_m_scale=self.short_mscale, - rotary_embedding_long_m_scale=self.long_mscale, + rotary_embedding_short_m_scale=attention_params.short_mscale, + rotary_embedding_long_m_scale=attention_params.long_mscale, rotary_embedding_scale=self.rotary_embedding_scale, rotary_embedding_max_positions=self.max_position_embeddings, rotary_embedding_original_max_positions=self. @@ -908,7 +1030,7 @@ def transpose_for_scores(x, if self.position_embedding_type == PositionEmbeddingType.long_rope: sequence_length = shape(hidden_states, 1) short = slice( - self.embed_positions_short_factors.value, + attention_params.embed_positions_short_factors, concat([0, 0, 0]), concat([ 1, @@ -917,7 +1039,7 @@ def transpose_for_scores(x, self.rotary_embedding_dim ])) long = slice( - self.embed_positions_long_factors.value, + attention_params.embed_positions_long_factors, concat([0, 0, 0]), concat([ 1, @@ -934,10 +1056,10 @@ def transpose_for_scores(x, sizes=shape(short)) embed_positions = cast(embed_positions, self.dtype) elif is_same_dtype(self.dtype, trt.bfloat16): - embed_positions = cast(self.embed_positions.value, + embed_positions = cast(attention_params.embed_positions, trt.bfloat16) else: - embed_positions = cast(self.embed_positions.value, + embed_positions = cast(attention_params.embed_positions, query.dtype) if self.rotary_embedding_dim is not None: @@ -1236,6 +1358,8 @@ def __init__(self, tp_group=None, tp_size=1, tp_rank=0, + cp_group=None, + cp_size=1, relative_attention=False, max_distance=0, num_buckets=0): @@ -1253,6 +1377,8 @@ def __init__(self, self.tp_group = tp_group self.tp_size = tp_size self.tp_rank = tp_rank + self.cp_group = cp_group + self.cp_size = cp_size self.num_layers = num_layers self.apply_query_key_layer_scaling = apply_query_key_layer_scaling @@ -1348,6 +1474,7 @@ def forward(self, if default_net().plugin_config.bert_attention_plugin: # TRT plugin mode assert input_lengths is not None + assert self.cp_size == 1 context = bert_attention( qkv, input_lengths, @@ -1372,6 +1499,9 @@ def transpose_for_scores(x): kv_size = self.attention_head_size * self.num_attention_kv_heads query, key, value = split( qkv, [self.attention_hidden_size, kv_size, kv_size], dim=2) + if self.cp_size > 1 and self.cp_group is not None: + key = allgather(key, self.cp_group, gather_dim=1) + value = allgather(value, self.cp_group, gather_dim=1) query = transpose_for_scores(query) key = transpose_for_scores(key) value = transpose_for_scores(value) @@ -1433,6 +1563,7 @@ def __init__( attention_mask_type=AttentionMaskType.causal, bias=True, dtype=None, + 
position_embedding_type=PositionEmbeddingType.rope_gpt_neox, rotary_embedding_base=10000.0, rotary_embedding_scaling=None, tp_group=None, @@ -1443,22 +1574,21 @@ def __init__( quant_mode: QuantMode = QuantMode(0), dense_bias=None, ): - super().__init__( - local_layer_idx=local_layer_idx, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_heads=num_kv_heads, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - attention_mask_type=attention_mask_type, - bias=bias, - position_embedding_type=PositionEmbeddingType.rope_gpt_neox, - rotary_embedding_base=rotary_embedding_base, - rotary_embedding_scaling=rotary_embedding_scaling, - tp_group=tp_group, - tp_size=tp_size, - tp_rank=tp_rank, - quant_mode=quant_mode) + super().__init__(local_layer_idx=local_layer_idx, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + dtype=dtype, + attention_mask_type=attention_mask_type, + bias=bias, + position_embedding_type=position_embedding_type, + rotary_embedding_base=rotary_embedding_base, + rotary_embedding_scaling=rotary_embedding_scaling, + tp_group=tp_group, + tp_size=tp_size, + tp_rank=tp_rank, + quant_mode=quant_mode) self.vision_length = vision_length self.vision_start = vision_start @@ -1480,16 +1610,6 @@ def __init__( dtype=dtype, tp_group=tp_group, tp_size=tp_size) - rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin( - self.max_position_embeddings, self.rotary_embedding_dim, - self.rotary_embedding_base, self.rotary_embedding_scale, - self.rotary_embedding_scale_type, self.vision_start, - self.vision_length) - self.register_parameter('rotary_inv_freq', - Parameter(rotary_inv_freq, dtype='float32')) - self.register_parameter( - 'embed_positions_for_gpt_attention', - Parameter(embed_positions_for_gpt_attention, dtype='float32')) def forward(self, hidden_states: Tensor, @@ -1550,8 +1670,8 @@ def forward(self, ) or self.quant_mode.has_fp8_qdq( ), "FP8 Context FMHA must be used together with the fp8 quantization workflow." 
- rotary_inv_freq = self.rotary_inv_freq.value - rotary_cos_sin = self.embed_positions_for_gpt_attention.value + rotary_inv_freq = attention_params.rotary_inv_freq + rotary_cos_sin = attention_params.embed_positions_for_gpt_attention attention_output_orig_quant_scale = self.attention_output_orig_quant_scale.value if self.attention_output_orig_quant_scale is not None else None context, past_key_value = gpt_attention( qkv=qkv, diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index 98e79e0fe..f457337cf 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -17,7 +17,7 @@ class Mapping(object): ''' - A node with 8 GPUs, tp_size = 4, pp_size = 2 + A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2 2 tp groups: @@ -31,6 +31,20 @@ class Mapping(object): - [2, 6] - [3, 7] + A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1 + + 2 tp groups: + + - [0, 1, 2, 3] + - [4, 5, 6, 7] + + 4 cp groups: + + - [0, 4] + - [1, 5] + - [2, 6] + - [3, 7] + A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4 4 moe_tp groups: @@ -82,6 +96,7 @@ def __init__( world_size=1, rank=0, gpus_per_node=8, + cp_size=1, tp_size=1, pp_size=1, moe_tp_size=-1, # -1 means no moe @@ -91,9 +106,9 @@ def __init__( moe_tp_size = tp_size moe_ep_size = 1 - if pp_size * tp_size != world_size: + if pp_size * cp_size * tp_size != world_size: raise ValueError( - f"world_size must equal to pp_size * tp_size, but got {world_size} != {pp_size} * {tp_size}" + f"world_size must equal to pp_size * cp_size * tp_size, but got {world_size} != {pp_size} * {cp_size} * {tp_size}" ) moe_tp_ep_size = moe_tp_size * moe_ep_size @@ -102,7 +117,11 @@ def __init__( f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}" ) + if moe_ep_size != 1 and cp_size > 1: + raise NotImplementedError("CP don't support MoE tp/ep yet") + self.tp_size = tp_size + self.cp_size = cp_size self.pp_size = pp_size self.moe_tp_size = moe_tp_size self.moe_ep_size = moe_ep_size @@ -111,19 +130,29 @@ def __init__( self.gpus_per_node = gpus_per_node self.pp_groups = [] + self.cp_groups = [] self.tp_groups = [] self.moe_tp_groups = [] self.moe_ep_groups = [] # init pp group - for i in range(tp_size): - ranks = range(i, world_size, tp_size) + for i in range(tp_size * cp_size): + ranks = range(i, world_size, tp_size * cp_size) self.pp_groups.append(list(ranks)) + # init cp group + for i in range(pp_size): + for j in range(tp_size): + ranks = range(i * tp_size * cp_size + j, + (i + 1) * tp_size * cp_size + j, tp_size) + self.cp_groups.append(list(ranks)) + # init tp group for i in range(pp_size): - ranks = range(i * tp_size, (i + 1) * tp_size) - self.tp_groups.append(list(ranks)) + for j in range(cp_size): + ranks = range(i * tp_size * cp_size + j * tp_size, + i * tp_size * cp_size + (j + 1) * tp_size) + self.tp_groups.append(list(ranks)) # init moe tp group for i in range(pp_size): @@ -139,13 +168,18 @@ def __init__( i * moe_tp_ep_size + (j + 1) * moe_ep_size) self.moe_ep_groups.append(list(ranks)) - self.pp_rank = self.rank // self.tp_size - self.tp_rank = self.rank % self.tp_size + self.pp_rank = self.rank // (self.tp_size * self.cp_size) + self.cp_rank = self.rank % (self.tp_size * self.cp_size) // self.tp_size + self.tp_rank = self.rank % (self.tp_size * self.cp_size) % self.tp_size self.moe_tp_rank = self.tp_rank // self.moe_ep_size self.moe_ep_rank = self.tp_rank % self.moe_ep_size - self.tp_group = self.tp_groups[self.pp_rank] - self.pp_group = self.pp_groups[self.tp_rank] + self.tp_group = 
self.tp_groups[self.pp_rank * self.cp_size + + self.cp_rank] + self.cp_group = self.cp_groups[self.pp_rank * self.tp_size + + self.tp_rank] + self.pp_group = self.pp_groups[self.cp_rank * self.tp_size + + self.tp_rank] self.moe_tp_group = self.moe_tp_groups[self.pp_rank * moe_ep_size + self.moe_ep_rank] self.moe_ep_group = self.moe_ep_groups[self.pp_rank * moe_tp_size + @@ -154,6 +188,9 @@ def __init__( self.node_rank = self.rank // self.gpus_per_node self.local_rank = self.rank % self.gpus_per_node + def has_cp(self): + return self.cp_size > 1 + def get_node_rank(self, rank: int): return rank // self.gpus_per_node @@ -173,13 +210,13 @@ def has_pp(self): return self.pp_size > 1 def prev_pp_rank(self): - p = self.rank - self.tp_size + p = self.rank - self.tp_size * self.cp_size if p < 0: p = p + self.world_size return p def next_pp_rank(self): - p = self.rank + self.tp_size + p = self.rank + self.tp_size * self.cp_size if p >= self.world_size: p = p - self.world_size return p @@ -197,6 +234,7 @@ def pp_layers(self, num_layers: int) -> List[int]: return list(layers_range) def ep_experts(self, num_experts: int) -> List[int]: + assert self.cp_size == 1 experts_per_rank = num_experts // self.moe_ep_size experts_range = range(self.moe_ep_rank * experts_per_rank, (self.moe_ep_rank + 1) * experts_per_rank) @@ -211,6 +249,7 @@ def to_dict(self): 'world_size': self.world_size, 'rank': self.rank, 'gpus_per_node': self.gpus_per_node, + 'cp_size': self.cp_size, 'tp_size': self.tp_size, 'pp_size': self.pp_size, 'moe_tp_size': self.moe_tp_size, diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 13820cafa..81c856525 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -72,7 +72,9 @@ 'GPTNeoXModel', 'GPTNeoXForCausalLM', 'PhiModel', + 'PhiConfig', 'Phi3Model', + 'Phi3Config', 'PhiForCausalLM', 'Phi3ForCausalLM', 'ChatGLMForCausalLM', @@ -114,6 +116,7 @@ 'FalconForCausalLM': FalconForCausalLM, 'PhiForCausalLM': PhiForCausalLM, 'Phi3ForCausalLM': Phi3ForCausalLM, + 'Phi3VForCausalLM': Phi3ForCausalLM, 'Phi3SmallForCausalLM': Phi3ForCausalLM, 'MambaForCausalLM': MambaForCausalLM, 'GPTNeoXForCausalLM': GPTNeoXForCausalLM, diff --git a/tensorrt_llm/models/chatglm/model.py b/tensorrt_llm/models/chatglm/model.py index 531b2f9f6..b819ddd92 100644 --- a/tensorrt_llm/models/chatglm/model.py +++ b/tensorrt_llm/models/chatglm/model.py @@ -39,8 +39,6 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): tp_rank = config.mapping.tp_rank layernorm_epsilon = config.norm_epsilon - rope_base = 10000.0 - rotary_embedding_scaling = None self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm self.alpha = (2 * config.num_hidden_layers)**0.5 norm_cls = RmsNorm if config.rmsnorm else LayerNorm @@ -51,14 +49,8 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): attention_mask_type = AttentionMaskType.bidirectional elif config.chatglm_version == 'chatglm2': attention_mask_type = AttentionMaskType.causal - if config.rope_ratio > 1: - rotary_embedding_scaling = { - 'type': 'linear', - 'factor': config.rope_ratio - } elif config.chatglm_version == 'chatglm3': attention_mask_type = AttentionMaskType.causal - rope_base *= config.rope_ratio self.input_layernorm = norm_cls( normalized_shape=hidden_size, @@ -82,9 +74,9 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): dense_bias=config.add_bias_linear, dtype=config.dtype, position_embedding_type=config.position_embedding_type, - 
rotary_embedding_base=rope_base, - rotary_embedding_scaling=rotary_embedding_scaling, - rotary_embedding_percentage=0.5, + rotary_embedding_base=config.rotary_base, + rotary_embedding_scaling=config.rotary_scaling, + rotary_embedding_percentage=config.rotary_pct, tp_group=tp_group, tp_size=tp_size, tp_rank=tp_rank, diff --git a/tensorrt_llm/models/cogvlm/config.py b/tensorrt_llm/models/cogvlm/config.py index dbeac68e3..e54d60fa7 100644 --- a/tensorrt_llm/models/cogvlm/config.py +++ b/tensorrt_llm/models/cogvlm/config.py @@ -35,6 +35,7 @@ def __init__(self, self.rotary_scaling = rotary_scaling self.vision_start = vision_start self.vision_length = vision_length + self.position_embedding_type = 'rope_gpt_neox' super().__init__(**kwargs) def to_dict(self): diff --git a/tensorrt_llm/models/cogvlm/model.py b/tensorrt_llm/models/cogvlm/model.py index db2856fe8..b05853ffb 100644 --- a/tensorrt_llm/models/cogvlm/model.py +++ b/tensorrt_llm/models/cogvlm/model.py @@ -51,6 +51,7 @@ def __init__(self, config: CogVLMConfig, layer_idx: int): dtype=config.dtype, attention_mask_type=AttentionMaskType.causal, bias=config.attn_bias, + position_embedding_type=config.position_embedding_type, rotary_embedding_base=config.rotary_base, rotary_embedding_scaling=config.rotary_scaling, tp_group=config.mapping.tp_group, diff --git a/tensorrt_llm/models/dit/model.py b/tensorrt_llm/models/dit/model.py index c82524bb5..68b009955 100644 --- a/tensorrt_llm/models/dit/model.py +++ b/tensorrt_llm/models/dit/model.py @@ -19,9 +19,10 @@ import numpy as np import tensorrt as trt -from ..._utils import str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_str -from ...functional import (Tensor, arange, chunk, concat, constant, cos, exp, - expand, shape, silu, sin, slice, split, unsqueeze) +from ..._utils import str_dtype_to_trt, trt_dtype_to_str +from ...functional import (Tensor, allgather, arange, chunk, concat, constant, + cos, exp, expand, shape, silu, sin, slice, split, + unsqueeze) from ...layers import MLP, BertAttention, Conv2d, Embedding, LayerNorm, Linear from ...mapping import Mapping from ...module import Module, ModuleList @@ -33,7 +34,7 @@ def modulate(x, shift, scale, dtype): ones = 1.0 if dtype is not None: - ones = constant(np.ones(1, dtype=trt_dtype_to_np(dtype))) + ones = constant(np.ones(1, dtype=np.float32)).cast(dtype) return x * (ones + unsqueeze(scale, 1)) + unsqueeze(shift, 1) @@ -129,6 +130,8 @@ def __init__(self, tp_group=mapping.tp_group, tp_size=mapping.tp_size, tp_rank=mapping.tp_rank, + cp_group=mapping.cp_group, + cp_size=mapping.cp_size, dtype=dtype) self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.mlp = MLP(hidden_size=hidden_size, @@ -203,6 +206,7 @@ def __init__(self, config: PretrainedConfig): self.num_heads = config.num_attention_heads self.dtype = str_dtype_to_trt(config.dtype) self.cfg_scale = config.cfg_scale + self.mapping = config.mapping self.x_embedder = PatchEmbed(config.input_size, config.patch_size, @@ -284,11 +288,19 @@ def forward_without_cfg(self, x, t, y): c = t + y input_length = constant(np.array([x.shape[1]], dtype=np.int32)) input_lengths = expand(input_length, unsqueeze(shape(x, 0), 0)) + # Split squeence for CP here + if self.mapping.cp_size > 1: + assert x.shape[1] % self.mapping.cp_size == 0 + x = chunk(x, self.mapping.cp_size, dim=1)[self.mapping.cp_rank] for block in self.blocks: x = block(x, c, input_lengths) # (N, T, D) self.register_network_output('before_final_layer', x) x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 
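For readers following the context-parallel path added to `forward_without_cfg`, here is a shape-level sketch of the split-then-gather round trip, illustrative only, with plain PyTorch standing in for the TRT-LLM `chunk`/`allgather` graph ops:

```
import torch

cp_size = 2
x = torch.randn(2, 8, 16)                        # [batch, seq_len, hidden], seq_len % cp_size == 0
shards = torch.chunk(x, cp_size, dim=1)          # the sequence slice each cp_rank keeps
assert torch.equal(torch.cat(shards, dim=1), x)  # gathering along dim=1 restores the full sequence
```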
self.register_network_output('final_layer', x) + + # All gather after CP + if self.mapping.cp_size > 1: + x = allgather(x, self.mapping.cp_group, gather_dim=1) x = self.unpatchify(x) # (N, out_channels, H, W) self.register_network_output('unpatchify', x) return x diff --git a/tensorrt_llm/models/llama/config.py b/tensorrt_llm/models/llama/config.py index e7454110a..4d2ad70cf 100644 --- a/tensorrt_llm/models/llama/config.py +++ b/tensorrt_llm/models/llama/config.py @@ -36,6 +36,7 @@ def __init__(self, residual_mlp: bool = False, disable_weight_only_quant_plugin: bool = False, moe: Optional[Union[MoeConfig, dict]] = None, + remove_duplicated_kv_heads: bool = False, **kwargs): self.mlp_bias = mlp_bias self.attn_bias = attn_bias @@ -55,6 +56,7 @@ def __init__(self, moe = MoeConfig.from_dict(moe) assert isinstance(moe, MoeConfig) self.moe = moe.validate() + self.remove_duplicated_kv_heads = remove_duplicated_kv_heads super().__init__(**kwargs) @@ -122,6 +124,8 @@ def from_hugging_face( residual_mlp = getattr(hf_config, "parallel_attn_mlp_res", False) disable_weight_only_quant_plugin = kwargs.pop( 'disable_weight_only_quant_plugin', False) + remove_duplicated_kv_heads = kwargs.pop('remove_duplicated_kv_heads', + False) if hf_config.model_type == "mixtral" or hf_config.model_type == "arctic": # HF LLaMA-type models are implicitly using gated activation. @@ -168,6 +172,7 @@ def from_hugging_face( moe=moe_config, mapping=mapping, quantization=quant_config, + remove_duplicated_kv_heads=remove_duplicated_kv_heads, **kwargs) @classmethod diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py index 72f5f9ed2..360467214 100644 --- a/tensorrt_llm/models/llama/convert.py +++ b/tensorrt_llm/models/llama/convert.py @@ -1378,6 +1378,9 @@ def __init__(self, config: PretrainedConfig): self.head_size = None if not hasattr(config, "head_size") else config.head_size self._qkv_weights = {} + self.remove_duplicated_kv_heads = getattr(config, + 'remove_duplicated_kv_heads', + False) @staticmethod def is_qkv_weight(name): @@ -1411,6 +1414,17 @@ def split_qkv_weights(self, layer_idx): weights = self._qkv_weights.pop(layer_idx) # to prevent memory leak. q, k, v = (torch.tensor(weights[t]) for t in ['q', 'k', 'v']) + if self.remove_duplicated_kv_heads: + head_size = self.hidden_size // self.num_heads if self.head_size is None else self.head_size + k = k.reshape( + [k.shape[0] // head_size // 2, 2, head_size, self.hidden_size]) + v = v.reshape( + [v.shape[0] // head_size // 2, 2, head_size, self.hidden_size]) + assert (k[:, 0] == k[:, 1]).all() + assert (v[:, 0] == v[:, 1]).all() + k = k[:, 0].reshape([-1, self.hidden_size]) + v = v[:, 0].reshape([-1, self.hidden_size]) + if not self.is_mha: head_size = self.hidden_size // self.num_heads if self.head_size is None else self.head_size if self.num_kv_heads < self.tp_size: @@ -2088,6 +2102,29 @@ def process_and_assign_weight(v: List[torch.Tensor], return weights +def load_torch_meta_ckpt(meta_ckpt_path: Path): + ''' + meta_ckpt_path: The format of meta_ckpt_path is like /consolidated.xx There are two possible cases: + 1. A file like /consolidated.xx.pth, loading it by torch.load directly + 2. A folder like /consolidated.xx/, need to load all weights in the folder. 
+ ''' + file_path = meta_ckpt_path.parent / (meta_ckpt_path.name + ".pth") + if file_path.exists() and file_path.is_file(): + return torch.load(file_path, map_location="cpu") + else: + folder_path = meta_ckpt_path + assert folder_path.exists() and folder_path.is_dir() + + ckpts = list(Path(folder_path).glob("consolidated-*.pth")) + + all_weights = {} + for ckpt in ckpts: + _weight = torch.load(ckpt, map_location="cpu") + all_weights = all_weights | _weight + del _weight + return all_weights + + def load_weights_from_meta_ckpt(meta_ckpt_dir: str, config: LLaMAConfig): torch_dtype = str_dtype_to_torch(config.dtype) mapping = config.mapping @@ -2145,9 +2182,8 @@ def get_current_weights(num_ckpts): file_ids = list(range(fs, fs + nf)) ckpts = [] for f in file_ids: - ckpt = torch.load(Path(meta_ckpt_dir, - f"consolidated.{f:02d}.pth"), - map_location="cpu") + ckpt = load_torch_meta_ckpt( + Path(meta_ckpt_dir, f"consolidated.{f:02d}")) ckpts.append(ckpt) return gather_ckpts(ckpts) elif num_ckpts < mapping.tp_size: @@ -2158,15 +2194,13 @@ def get_current_weights(num_ckpts): ckpt_rank = mapping.tp_rank % ranks_per_ckpt nH_per_ckpt = config.num_attention_heads // num_ckpts assert (nH_per_ckpt % ranks_per_ckpt) == 0 - ckpt = torch.load(Path(meta_ckpt_dir, - f"consolidated.{ckpt_fid:02d}.pth"), - map_location="cpu") + ckpt = load_torch_meta_ckpt( + Path(meta_ckpt_dir, f"consolidated.{ckpt_fid:02d}")) return split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank) # num_ckpts == tensor_parallel, 1:1 mapping from files to TP - return torch.load(Path(meta_ckpt_dir, - f"consolidated.{mapping.tp_rank:02d}.pth"), - map_location="cpu") + return load_torch_meta_ckpt( + Path(meta_ckpt_dir, f"consolidated.{mapping.tp_rank:02d}")) def permute(w, nH, d, dH): # due to MQA's wk, nH*dH != d could be true @@ -2205,9 +2239,8 @@ def gather_embedding(cur_embed, name: str, num_ckpts): if load_weights_from_meta_ckpt.saved_embed is None: embeds = [None] * num_ckpts for i in range(num_ckpts): - ckpt = torch.load(Path(meta_ckpt_dir, - f"consolidated.{i:02d}.pth"), - map_location="cpu") + ckpt = load_torch_meta_ckpt( + Path(meta_ckpt_dir, f"consolidated.{i:02d}")) embeds[i] = ckpt[name] embed = combine_embeddings(embeds, num_ckpts).to(torch_dtype) load_weights_from_meta_ckpt.saved_embed = embed @@ -2220,14 +2253,19 @@ def gather_embedding(cur_embed, name: str, num_ckpts): num_kv_heads = config.num_key_value_heads mha_mode = (num_kv_heads == config.num_attention_heads) - ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth")) + ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*")) num_ckpts = len(ckpts) # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it. assert num_kv_heads > 1 or num_kv_heads >= num_ckpts, \ f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints." - head_size = config.hidden_size // config.num_attention_heads + tik = time.time() ckpt = get_current_weights(num_ckpts) + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'[{mapping.rank}] get_current_weights. Total time: {t}') + + head_size = config.hidden_size // config.num_attention_heads layers_range = mapping.pp_layers(config.num_hidden_layers) for l in layers_range: @@ -2281,12 +2319,12 @@ def gather_embedding(cur_embed, name: str, num_ckpts): else: # layer specific weights layer_idx = extract_layer_idx(k) - # Meta's recipe of not using fp8 rowwise for the first and last layer. 
- use_fp8_rowwise_in_layer = use_fp8_rowwise and ( - layer_idx not in exclude_layers_id) - if layer_idx is None or int(layer_idx) not in layers_range: continue + + # Meta's recipe of not using fp8 rowwise for the first and last layer. + use_fp8_rowwise_in_layer = use_fp8_rowwise and ( + int(layer_idx) not in exclude_layers_id) idx = int(layer_idx) - layers_range[0] tllm_prex = f'transformer.layers.{idx}.' diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index c852362c0..f5cc445c4 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -311,7 +311,8 @@ def from_hugging_face( mapping=mapping, quant_config=quant_config, **kwargs) - + if config.remove_duplicated_kv_heads: + config.num_key_value_heads /= 2 if use_preloading: assert not load_by_shard weights = load_weights_from_hf_model(hf_model, config) diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 2aa5bf3ba..a3f8e8590 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -210,6 +210,10 @@ def __init__(self, raise NotImplementedError( "Embedding table cannot be shared for pipeline parallelism") + if share_embedding_table and mapping.cp_size > 1: + raise NotImplementedError( + "Embedding table cannot be shared for context parallelism") + if head_size is None: head_size = hidden_size // num_attention_heads self.head_size = head_size @@ -276,6 +280,7 @@ def quant_mode(self): def set_rank(self, rank): self.mapping = Mapping(self.mapping.world_size, rank=rank, + cp_size=self.mapping.cp_size, tp_size=self.mapping.tp_size, pp_size=self.mapping.pp_size, moe_tp_size=self.mapping.moe_tp_size, @@ -678,6 +683,9 @@ def __init__(self, config: PretrainedConfig, transformer, lm_head): self.lm_head = lm_head self.mup_width_multiplier = getattr(config, 'mup_width_multiplier', None) + # Create constant attention parameters to be reused by all layers. + Attention.create_attention_const_params(self, config) + self.position_embedding_type = config.position_embedding_type def forward(self, input_ids: Tensor, @@ -693,6 +701,11 @@ def forward(self, prompt_vocab_size: Optional[Tensor] = None, lora_params=None, spec_decoding_params=None): + + # fill attention params. + attention_params = Attention.fill_attention_params( + self, attention_params) + kwargs = { 'input_ids': input_ids, 'position_ids': position_ids, diff --git a/tensorrt_llm/models/phi/config.py b/tensorrt_llm/models/phi/config.py new file mode 100644 index 000000000..b8bf4dc95 --- /dev/null +++ b/tensorrt_llm/models/phi/config.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
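The `Mapping` docstring earlier in this patch lists the tp/cp/pp groups for an 8-GPU node; the rank decomposition behind those groups is plain integer arithmetic, reproduced here as a worked example (the `decompose` helper is illustrative, not part of the patch):

```
def decompose(rank: int, tp_size: int, cp_size: int):
    # Mirrors the expressions added to mapping.py.
    pp_rank = rank // (tp_size * cp_size)
    cp_rank = rank % (tp_size * cp_size) // tp_size
    tp_rank = rank % (tp_size * cp_size) % tp_size
    return pp_rank, cp_rank, tp_rank

# 8 GPUs, tp_size=4, cp_size=2, pp_size=1 (the docstring example):
# rank 5 -> (pp_rank=0, cp_rank=1, tp_rank=1), i.e. tp group [4, 5, 6, 7]
# and cp group [1, 5], matching the groups listed in the docstring.
assert decompose(5, tp_size=4, cp_size=2) == (0, 1, 1)
```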
+ +from typing import Optional, Union + +import torch + +from ..._utils import torch_dtype_to_str +from ...logger import logger +from ...mapping import Mapping +from ..modeling_utils import PretrainedConfig, QuantConfig + + +class PhiConfig(PretrainedConfig): + + def __init__(self, + *, + rotary_base: float = 10000.0, + rotary_scaling: Optional[dict] = None, + **kwargs): + + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in PhiConfig + + output['rotary_base'] = self.rotary_base + output['rotary_scaling'] = self.rotary_scaling + + return output + + @classmethod + def from_hugging_face( + cls, + hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + import transformers + + if isinstance(hf_config_or_dir, transformers.PretrainedConfig): + hf_config = hf_config_or_dir + else: + hf_config_dir = str(hf_config_or_dir) + + hf_config = transformers.AutoConfig.from_pretrained( + hf_config_dir, trust_remote_code=True) + + num_key_value_heads = getattr(hf_config, "num_key_value_heads", + hf_config.num_attention_heads) + rotary_scaling = getattr(hf_config, "rope_scaling", None) + rotary_base = getattr(hf_config, "rope_theta", 10000.0) + if dtype == 'auto': + dtype = getattr(hf_config, 'torch_dtype', None) + if dtype is None: + dtype = 'float16' + if isinstance(dtype, torch.dtype): + dtype = torch_dtype_to_str(dtype) + if dtype == 'float32': + dtype = 'float16' + if dtype == 'bfloat16' and torch.cuda.get_device_properties( + 0).major < 8: + logger.warning( + "Pre SM 80 GPUs do not support bfloat16, fallback to float16") + dtype = 'float16' + + return cls(architecture=hf_config.architectures[0], + dtype=dtype, + num_hidden_layers=hf_config.num_hidden_layers, + num_attention_heads=hf_config.num_attention_heads, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + num_key_value_heads=num_key_value_heads, + vocab_size=hf_config.vocab_size, + position_embedding_type='rope_gpt_neox', + max_position_embeddings=hf_config.max_position_embeddings, + hidden_act=hf_config.hidden_act, + rotary_base=rotary_base, + rotary_scaling=rotary_scaling, + rotary_pct=hf_config.partial_rotary_factor, + mapping=mapping, + quantization=quant_config, + **kwargs) diff --git a/tensorrt_llm/models/phi/convert.py b/tensorrt_llm/models/phi/convert.py index e72837a8d..6eb8563bb 100644 --- a/tensorrt_llm/models/phi/convert.py +++ b/tensorrt_llm/models/phi/convert.py @@ -3,11 +3,15 @@ from ..._utils import str_dtype_to_torch -def convert_hf_weights(hf_model, dtype, args=None): - torch_dtype = str_dtype_to_torch(dtype) +def load_weights_from_hf_model(hf_model, config): + torch_dtype = str_dtype_to_torch(config.dtype) hf_state_dict = hf_model.state_dict() weights = {} - + is_weight_only = config.quant_mode.is_weight_only() + if config.quant_mode.is_int8_weight_only(): + plugin_weight_only_quant_type = torch.int8 + elif config.quant_mode.is_int4_weight_only(): + plugin_weight_only_quant_type = torch.quint4x2 # replace key name for key, value in hf_state_dict.items(): # Decoder Layers @@ -21,6 +25,7 @@ def convert_hf_weights(hf_model, dtype, args=None): "transformer.vocab_embedding.weight") # Final Layer norm key = key.replace("model.final_layernorm.", "transformer.ln_f.") + weights[key] = value.to(torch_dtype).cpu() # merge qkv weights @@ -41,6 
+46,22 @@ def convert_hf_weights(hf_model, dtype, args=None): weights[f"{prefix}attention.qkv.weight"] = torch.cat(qkv_weights, dim=0) weights[f"{prefix}attention.qkv.bias"] = torch.cat(qkv_bias, dim=0) + if is_weight_only: + kw_list = [ + 'attention.dense.weight', 'attention.qkv.weight', 'mlp.fc.weight', + 'mlp.proj.weight' + ] + for key in [ + weight_name for kw in kw_list for weight_name in weights + if kw in weight_name + ]: + v = weights[key].t().contiguous().cpu() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + v, plugin_weight_only_quant_type) + weights[key] = processed_torch_weights + weights[key.replace('.weight', + '.per_channel_scale')] = torch_weight_scales return weights @@ -51,11 +72,12 @@ def convert_hf_config(hf_config, dtype, args): 'dtype': dtype, 'num_hidden_layers': hf_config.num_hidden_layers, 'num_attention_heads': hf_config.num_key_value_heads, - 'partial_rotary_factor': hf_config.partial_rotary_factor, + 'rotary_pct': hf_config.partial_rotary_factor, 'rope_theta': hf_config.rope_theta, 'hidden_size': hf_config.hidden_size, 'intermediate_size': hf_config.intermediate_size, 'vocab_size': hf_config.vocab_size, + 'position_embedding_type': 'rope_gpt_neox', 'max_position_embeddings': hf_config.max_position_embeddings, 'hidden_act': hf_config.hidden_act, 'share_embedding_table': False, diff --git a/tensorrt_llm/models/phi/model.py b/tensorrt_llm/models/phi/model.py index 7f26488fd..830c1428d 100644 --- a/tensorrt_llm/models/phi/model.py +++ b/tensorrt_llm/models/phi/model.py @@ -12,18 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +import copy +import os +from typing import Optional, Union +import safetensors from transformers import AutoModelForCausalLM from ..._utils import pad_vocab_size -from ...functional import PositionEmbeddingType, Tensor +from ...functional import Tensor from ...layers import (MLP, Attention, AttentionMaskType, Embedding, LayerNorm, ParallelLMHead) +from ...mapping import Mapping from ...module import Module +from ...quantization import QuantAlgo from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig, save_checkpoint) -from .convert import convert_hf_config, convert_hf_weights + PretrainedConfig, QuantConfig) +from .config import PhiConfig +from .convert import load_weights_from_hf_model class PhiDecoderLayer(Module): @@ -44,8 +50,8 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): local_layer_idx=local_layer_idx, hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, - rotary_embedding_percentage=config.partial_rotary_factor, - position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + rotary_embedding_percentage=config.rotary_pct, + position_embedding_type=config.position_embedding_type, rotary_embedding_base=config.rotary_base, max_position_embeddings=config.max_position_embeddings, dtype=config.dtype, @@ -162,21 +168,95 @@ def check_config(self, config): config.set_if_not_exist('rotary_base', 10000.0) @classmethod - def convert_hf_checkpoint(cls, - hf_model_dir: str, - dtype: Optional[str] = "float16", - output_dir: Optional[str] = None, - args=None): - ''' - Convert Huggingface checkpoint to TRT-LLM checkpoint - ''' - hf_model = AutoModelForCausalLM.from_pretrained(hf_model_dir, - torch_dtype="auto", - trust_remote_code=True) - config = convert_hf_config(hf_model.config, dtype, args) - weights = convert_hf_weights(hf_model, dtype, args) - - if output_dir: - save_checkpoint(output_dir, config=config, weights=weights) - - return {"weights": weights, "config": config} + def from_hugging_face( + cls, + hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + import transformers + + assert hf_model_or_dir is not None + use_preloading = isinstance(hf_model_or_dir, + transformers.PreTrainedModel) + if use_preloading: + hf_model = hf_model_or_dir + hf_config_or_dir = hf_model.config + else: + hf_model_dir = hf_model_or_dir + hf_config_or_dir = hf_model_or_dir + config = PhiConfig.from_hugging_face(hf_config_or_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + if not use_preloading: + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_dir, torch_dtype="auto", trust_remote_code=True) + + assert isinstance(hf_model, transformers.PreTrainedModel) + + weights = load_weights_from_hf_model(hf_model, config) + + model = cls(config) + model.load(weights) + return model + + @classmethod + def quantize( + cls, + hf_model_dir: str, + output_dir: str, + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + *, + device: str = 'cuda', + calib_dataset: str = 'cnn_dailymail', + calib_batches: int = 512, + calib_batch_size: int = 1, + calib_max_seq_length: int = 512, + random_seed: int = 1234, + tokenizer_max_seq_length: int = 2048, + **kwargs, + ): + DEFAULT_MODELOPT_FLOW = [ + QuantAlgo.W4A16_AWQ, + QuantAlgo.FP8, + QuantAlgo.W8A8_SQ_PER_CHANNEL, + ] + NATIVE_QUANT_FLOW = 
[QuantAlgo.W4A16, QuantAlgo.W8A16, None] + + config = PhiConfig.from_hugging_face(hf_model_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + + if quant_config.quant_algo in DEFAULT_MODELOPT_FLOW: + super().quantize(hf_model_dir, + output_dir, + dtype=config.dtype, + mapping=config.mapping, + quant_config=config.quantization, + device=device, + calib_dataset=calib_dataset, + calib_batches=calib_batches, + calib_batch_size=calib_batch_size, + calib_max_seq_length=calib_max_seq_length, + random_seed=random_seed, + tokenizer_max_seq_length=tokenizer_max_seq_length) + else: + assert quant_config.quant_algo in NATIVE_QUANT_FLOW, f"Internal error: shall call Modelopt for this quantization {quant_config}" + + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_dir, torch_dtype="auto", trust_remote_code=True) + + for rank in range(mapping.world_size): + weights = load_weights_from_hf_model(hf_model, config) + config = copy.deepcopy(config) + config.set_rank(rank) + safetensors.torch.save_file( + weights, os.path.join(output_dir, + f'rank{rank}.safetensors')) diff --git a/tensorrt_llm/models/phi3/config.py b/tensorrt_llm/models/phi3/config.py new file mode 100644 index 000000000..ce1cfda98 --- /dev/null +++ b/tensorrt_llm/models/phi3/config.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Union + +import torch + +from ..._utils import torch_dtype_to_str +from ...logger import logger +from ...mapping import Mapping +from ..modeling_utils import PretrainedConfig, QuantConfig + + +class Phi3Config(PretrainedConfig): + + def __init__(self, + *, + rotary_base: float = 10000.0, + rotary_scaling: Optional[dict] = None, + **kwargs): + + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in PhiConfig + + output['rotary_base'] = self.rotary_base + output['rotary_scaling'] = self.rotary_scaling + + return output + + @classmethod + def from_hugging_face( + cls, + hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + import transformers + + if isinstance(hf_config_or_dir, transformers.PretrainedConfig): + hf_config = hf_config_or_dir + else: + hf_config_dir = str(hf_config_or_dir) + + hf_config = transformers.AutoConfig.from_pretrained( + hf_config_dir, trust_remote_code=True) + + num_key_value_heads = getattr(hf_config, "num_key_value_heads", + hf_config.num_attention_heads) + if dtype == 'auto': + dtype = getattr(hf_config, 'torch_dtype', None) + if dtype is None: + dtype = 'float16' + if isinstance(dtype, torch.dtype): + dtype = torch_dtype_to_str(dtype) + if dtype == 'float32': + dtype = 'float16' + if dtype == 'bfloat16' and torch.cuda.get_device_properties( + 0).major < 8: + logger.warning( + "Pre SM 80 GPUs do not support bfloat16, fallback to float16") + dtype = 'float16' + + small_variant = hf_config.architectures[0] == "Phi3SmallForCausalLM" + if small_variant: + kwargs['gegelu_limit'] = getattr(hf_config, "gegelu_limit", None) + kwargs['rotary_base'] = hf_config.rope_embedding_base + kwargs['mup_attn_multiplier'] = getattr(hf_config, + "mup_attn_multiplier", None) + kwargs['mup_embedding_multiplier'] = getattr( + hf_config, "mup_embedding_multiplier", None) + kwargs['mup_use_scaling'] = getattr(hf_config, "mup_use_scaling", + None) + kwargs['mup_width_multiplier'] = getattr(hf_config, + "mup_width_multiplier", + None) + kwargs['blocksparse_block_size'] = getattr( + hf_config, "blocksparse_block_size", None) + kwargs['blocksparse_homo_head_pattern'] = getattr( + hf_config, "blocksparse_homo_head_pattern", None) + kwargs['blocksparse_num_local_blocks'] = getattr( + hf_config, "blocksparse_num_local_blocks", None) + kwargs['blocksparse_vertical_stride'] = getattr( + hf_config, "blocksparse_vert_stride", None) + kwargs['dense_attention_every_n_layers'] = getattr( + hf_config, "dense_attention_every_n_layers", None) + else: + kwargs['rotary_base'] = hf_config.rope_theta + kwargs['norm_epsilon'] = hf_config.rms_norm_eps + kwargs['position_embedding_type'] = 'rope_gpt_neox' + if hf_config.max_position_embeddings >= 128000: + kwargs[ + 'original_max_position_embeddings'] = hf_config.original_max_position_embeddings + kwargs['position_embedding_type'] = "long_rope" + kwargs['longrope_scaling_short_factors'] = hf_config.rope_scaling[ + "short_factor"] + kwargs['longrope_scaling_long_factors'] = hf_config.rope_scaling[ + "long_factor"] + if small_variant: + kwargs['longrope_long_mscale'] = hf_config.rope_scaling[ + "long_mscale"] + kwargs['longrope_short_mscale'] = hf_config.rope_scaling[ + "short_mscale"] + + return cls(architecture=hf_config.architectures[0], + dtype=dtype, + 
num_hidden_layers=hf_config.num_hidden_layers, + num_attention_heads=hf_config.num_attention_heads, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + num_key_value_heads=num_key_value_heads, + vocab_size=hf_config.vocab_size, + max_position_embeddings=hf_config.max_position_embeddings, + hidden_act="swiglu" + if hf_config.hidden_act == 'silu' else hf_config.hidden_act, + mapping=mapping, + quantization=quant_config, + **kwargs) diff --git a/tensorrt_llm/models/phi3/convert.py b/tensorrt_llm/models/phi3/convert.py index da2940178..9ee6821db 100644 --- a/tensorrt_llm/models/phi3/convert.py +++ b/tensorrt_llm/models/phi3/convert.py @@ -1,15 +1,19 @@ import torch -from tensorrt_llm.quantization import QuantAlgo - from ..._utils import str_dtype_to_torch from .split_weights import shuffle_qkv_weights, split_weights_tp -def convert_hf_weights(hf_model, dtype, config, small_variant, args, rank): - torch_dtype = str_dtype_to_torch(dtype) +def load_weights_from_hf_model(hf_model, config): + torch_dtype = str_dtype_to_torch(config.dtype) hf_state_dict = hf_model.state_dict() weights = {} + + config.quant_mode.is_weight_only() + if config.quant_mode.is_int8_weight_only(): + torch.int8 + elif config.quant_mode.is_int4_weight_only(): + torch.quint4x2 # replace key name for key, value in hf_state_dict.items(): # Decoder Layers @@ -25,9 +29,8 @@ def convert_hf_weights(hf_model, dtype, config, small_variant, args, rank): key = key.replace("mlp.fc1.", "mlp.fc.") key = key.replace("mlp.fc2.", "mlp.proj.") key = key.replace("mlp.gate_up_proj.", "mlp.fc.") - key = key.replace( - "mlp.up_proj.", - "mlp.fc." if small_variant else "mlp.gate.") #128k + key = key.replace("mlp.up_proj.", "mlp.fc." if config.architecture + == 'Phi3SmallForCausalLM' else "mlp.gate.") #128k key = key.replace("mlp.down_proj.", "mlp.proj.") #128k key = key.replace("mlp.gate_proj.", "mlp.fc.") #128k key = key.replace("o_proj.", "dense.") #128k @@ -62,7 +65,7 @@ def convert_hf_weights(hf_model, dtype, config, small_variant, args, rank): weights[key] = value.to(torch_dtype).cpu() - if small_variant: + if config.architecture == 'Phi3SmallForCausalLM': weights['lm_head.weight'] = weights[ 'transformer.vocab_embedding.weight'].clone() @@ -71,89 +74,6 @@ def convert_hf_weights(hf_model, dtype, config, small_variant, args, rank): if "qkv." 
in key: weights[key] = shuffle_qkv_weights(weights[key], config) - weights = split_weights_tp(config, weights, args, rank, torch_dtype) + weights = split_weights_tp(config, weights, torch_dtype) return weights - - -def convert_small_hf_config(hf_config): - return { - 'architecture': "Phi3SmallForCausalLM", - 'rotary_base': hf_config.rope_embedding_base, - 'gegelu_limit': hf_config.gegelu_limit, - 'mup_attn_multiplier': hf_config.mup_attn_multiplier, - 'mup_embedding_multiplier': hf_config.mup_embedding_multiplier, - 'mup_use_scaling': hf_config.mup_use_scaling, - 'mup_width_multiplier': hf_config.mup_width_multiplier, - 'blocksparse_block_size': hf_config.blocksparse_block_size, - 'blocksparse_homo_head_pattern': - hf_config.blocksparse_homo_head_pattern, - 'blocksparse_num_local_blocks': hf_config.blocksparse_num_local_blocks, - 'blocksparse_vertical_stride': hf_config.blocksparse_vert_stride, - 'dense_attention_every_n_layers': - hf_config.dense_attention_every_n_layers, - } - - -def convert_hf_config(hf_config, dtype, args): - config = { - 'architecture': "Phi3ForCausalLM", - 'dtype': dtype, - 'num_hidden_layers': hf_config.num_hidden_layers, - 'num_attention_heads': hf_config.num_attention_heads, - 'num_key_value_heads': hf_config.num_key_value_heads, - 'hidden_size': hf_config.hidden_size, - 'intermediate_size': hf_config.intermediate_size, - 'vocab_size': hf_config.vocab_size, - 'max_position_embeddings': hf_config.max_position_embeddings, - 'hidden_act': hf_config.hidden_act, - 'share_embedding_table': False, - } - - small_variant = hf_config.architectures[0] == "Phi3SmallForCausalLM" - if small_variant: - config.update(convert_small_hf_config(hf_config)) - else: - config.update({ - 'rotary_base': hf_config.rope_theta, - 'norm_epsilon': hf_config.rms_norm_eps, - }) - - # Long-context variants - if hf_config.max_position_embeddings >= 128000: - config.update({ - 'original_max_position_embeddings': - hf_config.original_max_position_embeddings, - 'longrope_scaling_short_factors': - hf_config.rope_scaling["short_factor"], - 'longrope_scaling_long_factors': - hf_config.rope_scaling["long_factor"] - }) - - if small_variant: - config.update({ - 'longrope_long_mscale': - hf_config.rope_scaling["long_mscale"], - 'longrope_short_mscale': - hf_config.rope_scaling["short_mscale"] - }) - - if config["hidden_act"] == "silu": - config["hidden_act"] = "swiglu" - - # Tensor parallelism and weight-only quantization - if args is not None: - config.update({ - 'mapping': { - 'world_size': args.tp_size * args.pp_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - } - }) - - if args.use_weight_only and args.weight_only_precision == 'int8': - config.update({'quantization': {'quant_algo': QuantAlgo.W8A16}}) - elif args.use_weight_only and args.weight_only_precision == 'int4': - config.update({'quantization': {'quant_algo': QuantAlgo.W4A16}}) - - return config diff --git a/tensorrt_llm/models/phi3/model.py b/tensorrt_llm/models/phi3/model.py index f416f6ef4..fe81557ff 100644 --- a/tensorrt_llm/models/phi3/model.py +++ b/tensorrt_llm/models/phi3/model.py @@ -1,8 +1,6 @@ -import json +import copy import os -import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Optional +from typing import Optional, Union import numpy as np import safetensors @@ -13,10 +11,13 @@ from ...layers import (MLP, Attention, AttentionMaskType, BlockSparseAttnParams, Embedding, LayerNorm, ParallelLMHead, RmsNorm) from ...lora_manager import LoraConfig, use_lora +from ...mapping 
import Mapping from ...module import Module +from ...quantization import QuantAlgo from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig) -from .convert import convert_hf_config, convert_hf_weights + PretrainedConfig, QuantConfig) +from .config import Phi3Config +from .convert import load_weights_from_hf_model class Phi3DecoderLayer(Module): @@ -234,50 +235,99 @@ def __init__(self, config: PretrainedConfig): super().__init__(config, transformer, lm_head) @classmethod - def convert_hf_checkpoint(cls, - hf_model_dir: str, - dtype: Optional[str] = "float16", - output_dir: Optional[str] = None, - args=None): - ''' - Convert Huggingface checkpoint to TRT-LLM checkpoint - ''' - - hf_model = AutoModelForCausalLM.from_pretrained(hf_model_dir, - torch_dtype="auto", - trust_remote_code=True) - config = convert_hf_config(hf_model.config, dtype, args) - with open(os.path.join(output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - - small_variant = config['architecture'] == "Phi3SmallForCausalLM" - - def covert_and_save(rank): - weights = convert_hf_weights(hf_model, dtype, config, small_variant, - args, rank) - safetensors.torch.save_file( - weights, os.path.join(output_dir, f'rank{rank}.safetensors')) - - world_size = args.tp_size * args.pp_size - if args.workers == 1: - for rank in range(world_size): - covert_and_save(rank) + def from_hugging_face( + cls, + hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + import transformers + + assert hf_model_or_dir is not None + use_preloading = isinstance(hf_model_or_dir, + transformers.PreTrainedModel) + if use_preloading: + hf_model = hf_model_or_dir + hf_config_or_dir = hf_model.config else: - with ThreadPoolExecutor(max_workers=args.workers) as p: - futures = [ - p.submit(covert_and_save, rank) - for rank in range(world_size) - ] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." 
+ hf_model_dir = hf_model_or_dir + hf_config_or_dir = hf_model_or_dir + config = Phi3Config.from_hugging_face(hf_config_or_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + + if not use_preloading: + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_dir, torch_dtype="auto", trust_remote_code=True) + + assert isinstance(hf_model, transformers.PreTrainedModel) + + weights = load_weights_from_hf_model(hf_model, config) + + model = cls(config) + model.load(weights) + return model + + @classmethod + def quantize( + cls, + hf_model_dir: str, + output_dir: str, + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + *, + device: str = 'cuda', + calib_dataset: str = 'cnn_dailymail', + calib_batches: int = 512, + calib_batch_size: int = 1, + calib_max_seq_length: int = 512, + random_seed: int = 1234, + tokenizer_max_seq_length: int = 2048, + **kwargs, + ): + DEFAULT_MODELOPT_FLOW = [ + QuantAlgo.W4A16_AWQ, + QuantAlgo.FP8, + QuantAlgo.W8A8_SQ_PER_CHANNEL, + ] + NATIVE_QUANT_FLOW = [QuantAlgo.W4A16, QuantAlgo.W8A16, None] + + config = Phi3Config.from_hugging_face(hf_model_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + + if quant_config.quant_algo in DEFAULT_MODELOPT_FLOW: + super().quantize(hf_model_dir, + output_dir, + dtype=config.dtype, + mapping=config.mapping, + quant_config=config.quantization, + device=device, + calib_dataset=calib_dataset, + calib_batches=calib_batches, + calib_batch_size=calib_batch_size, + calib_max_seq_length=calib_max_seq_length, + random_seed=random_seed, + tokenizer_max_seq_length=tokenizer_max_seq_length) + else: + assert quant_config.quant_algo in NATIVE_QUANT_FLOW, f"Internal error: shall call Modelopt for this quantization {quant_config}" + + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_dir, torch_dtype="auto", trust_remote_code=True) + + for rank in range(mapping.world_size): + weights = load_weights_from_hf_model(hf_model, config) + config = copy.deepcopy(config) + config.set_rank(rank) + safetensors.torch.save_file( + weights, os.path.join(output_dir, + f'rank{rank}.safetensors')) def use_lora(self, lora_config: LoraConfig): use_lora(self, lora_config, self.trtllm_modules_to_hf_modules) diff --git a/tensorrt_llm/models/phi3/split_weights.py b/tensorrt_llm/models/phi3/split_weights.py index 72ccf0117..5e1605504 100644 --- a/tensorrt_llm/models/phi3/split_weights.py +++ b/tensorrt_llm/models/phi3/split_weights.py @@ -15,6 +15,8 @@ import torch +from ..._utils import pad_vocab_size + def shuffle_qkv_weights(weights, config): # Input weights are organized as @@ -24,11 +26,11 @@ def shuffle_qkv_weights(weights, config): # Output weights will be organized as # (q00, q01, ..., qnm), (k0, k1, .., kn), (v0, v1, .., vn) - num_heads = config['num_attention_heads'] - num_kv_heads = config['num_key_value_heads'] + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads num_q_per_kv = num_heads // num_kv_heads - hidden_size = config['hidden_size'] + hidden_size = config.hidden_size head_dim = hidden_size // num_heads input_shape = weights.shape @@ -145,19 +147,19 @@ def get_tllm_linear_weight(weight, return results -def split_weights_tp(config, weights, args, rank, dtype): - num_heads = config['num_attention_heads'] - num_kv_heads = config['num_key_value_heads'] - hidden_size = config['hidden_size'] +def split_weights_tp(config, weights, dtype): + num_heads = config.num_attention_heads + num_kv_heads = 
config.num_key_value_heads
+    hidden_size = config.hidden_size
 
     mha_mode = num_heads == num_kv_heads
-    tp_size = args.tp_size
-
-    use_weight_only = args.use_weight_only
+    tp_size = config.mapping.tp_size
+    rank = config.mapping.tp_rank
+    use_weight_only = config.quant_mode.is_weight_only()
     plugin_weight_only_quant_type = None
-    if use_weight_only and args.weight_only_precision == 'int8':
+    if use_weight_only and config.quant_mode.is_int8_weight_only():
         plugin_weight_only_quant_type = torch.int8
-    elif use_weight_only and args.weight_only_precision == 'int4':
+    elif use_weight_only and config.quant_mode.is_int4_weight_only():
         plugin_weight_only_quant_type = torch.quint4x2
 
     # Helper
@@ -165,7 +167,7 @@ def get_weight(weight, prefix, bias):
         return get_tllm_linear_weight(weight, prefix, bias, use_weight_only,
                                       plugin_weight_only_quant_type)
 
-    for layer_id in range(config['num_hidden_layers']):
+    for layer_id in range(config.num_hidden_layers):
         layer_prefix = f"transformer.layers.{layer_id}."
 
         prefix = layer_prefix + 'attention.qkv'
diff --git a/tensorrt_llm/models/recurrentgemma/model.py b/tensorrt_llm/models/recurrentgemma/model.py
index 25f0f61e2..f7d91491d 100644
--- a/tensorrt_llm/models/recurrentgemma/model.py
+++ b/tensorrt_llm/models/recurrentgemma/model.py
@@ -250,6 +250,10 @@ def __init__(self, config: PretrainedConfig):
         self.gather_context_logits = False
         self.logits_soft_cap = config.logits_soft_cap
 
+        # Create constant attention parameters to be reused by all layers.
+        Attention.create_attention_const_params(self, config)
+        self.position_embedding_type = config.position_embedding_type
+
         if isinstance(logits_dtype, str):
             self._logits_dtype = str_dtype_to_trt(logits_dtype)
         else:
@@ -279,6 +283,11 @@ def forward(self,
                 last_token_ids_for_logits=None,
                 host_context_lengths=None,
                 slot_mapping=None):
+
+        # fill attention params.
+        attention_params = Attention.fill_attention_params(
+            self, attention_params)
+
         hidden_states, present_kvs, present_convs, present_rnns = self.transformer(
             input_ids, use_cache, attention_mask, kv_cache_params,
             attention_params, conv_states, rnn_states, host_request_types,
diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py
index 72ab69ba6..d3ecc7da0 100644
--- a/tensorrt_llm/plugin/plugin.py
+++ b/tensorrt_llm/plugin/plugin.py
@@ -336,12 +336,6 @@ class CustomAllReduceHelper:
     Globally visible class to help usage of custom_all_reduce plugin.
     Provides the following utilities:
 
-    gen_id: int
-        Used for synchronization with custom kernels. Plugins instances MUST have the same
-        id across GPUs. I.e.: GPU#0's allreduce after MLP at layer i must have the same id as
-        GPU#1, GPU#2... Also, ids MUST be unique per model. There should not be two allreduce instances
-        in GPU#0 that have the same id.
-
     workspace: Tensor
         When using CUSTOM or AUTO mode, a tensor containing pointers to memory visible to all GPUs.
         It should be 3 pointers per TP rank -
@@ -349,26 +343,19 @@ class CustomAllReduceHelper:
         It must be initialized using IpcMemory class.
 
     Usage:
-        - Use `init_all_reduce_helper` to reset the id counter. This must be done in main model class.
         - Set custom_all_reduce_helper.workspace with the required tensor.
           Then, each instance of allreduce will reference that tensor automatically.
""" POINTERS_PER_RANK = 4 def __init__(self) -> None: - self.current_id: int = 1 self.workspace: Optional[Tensor] = None - def gen_id(self) -> int: - result = self.current_id - self.current_id += 1 - return result - def set_workspace_tensor(self, mapping: Mapping, num_profiles: Optional[int] = None): from ..functional import Tensor - workspace_size = self.POINTERS_PER_RANK * mapping.tp_size + workspace_size = self.POINTERS_PER_RANK * mapping.tp_size + 1 dim_range = None if num_profiles is not None: @@ -407,7 +394,7 @@ def allocate_workspace(mapping: Mapping, return buffers, torch.tensor( ipc_buffers_ping.serialize() + ipc_buffers_pong.serialize() + - ipc_barriers_in.serialize() + ipc_barriers_out.serialize(), + ipc_barriers_in.serialize() + ipc_barriers_out.serialize() + [0], dtype=torch.int64, device="cpu") diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index 9eb29e2e3..a01ae37a6 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -22,10 +22,10 @@ from .._utils import fp32_array, is_same_dtype from ..functional import (ACT2FN, AllReduceFusionOp, AllReduceFusionParams, AttentionMaskType, PositionEmbeddingType, - RopeEmbeddingUtils, RotaryScalingType, Tensor, - allgather, allreduce, cast, concat, constant, - embedding, generate_alibi_slopes, gpt_attention, - matmul, mul, shape, slice, softmax, split, where) + RotaryScalingType, Tensor, allgather, allreduce, cast, + concat, constant, embedding, generate_alibi_slopes, + gpt_attention, matmul, mul, shape, slice, softmax, + split, where) from ..layers import SpecDecodingParams from ..layers.embedding import Embedding from ..layers.linear import Linear, RowLinear @@ -1461,18 +1461,6 @@ def __init__( if self.position_embedding_type.is_rope(): self.rotary_embedding_dim = int(self.attention_head_size * rotary_embedding_percentage) - rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin( - self.max_position_embeddings, self.rotary_embedding_dim, - self.rotary_embedding_base, self.rotary_embedding_scale, - self.rotary_embedding_scale_type, rotary_embedding_scaling) - self.register_parameter( - 'rotary_inv_freq', - Parameter(rotary_inv_freq, dtype='float32', is_buffer=True)) - self.register_parameter( - 'embed_positions_for_gpt_attention', - Parameter(embed_positions_for_gpt_attention, - dtype='float32', - is_buffer=True)) elif self.position_embedding_type.is_alibi(): alibi_scale = 1. / self.norm_factor if self.scale_alibi_bias else 1. alibi_slopes = generate_alibi_slopes(self.num_attention_heads * @@ -1576,8 +1564,8 @@ def forward( kv_orig_quant_scale = None kv_quant_orig_scale = None if self.position_embedding_type.is_rope(): - rotary_inv_freq = self.rotary_inv_freq.value - rotary_cos_sin = self.embed_positions_for_gpt_attention.value + rotary_inv_freq = attention_params.rotary_inv_freq + rotary_cos_sin = attention_params.embed_positions_for_gpt_attention else: rotary_inv_freq = None rotary_cos_sin = None diff --git a/tensorrt_llm/quantization/quantize_by_modelopt.py b/tensorrt_llm/quantization/quantize_by_modelopt.py index 06ee19492..19df9da4d 100644 --- a/tensorrt_llm/quantization/quantize_by_modelopt.py +++ b/tensorrt_llm/quantization/quantize_by_modelopt.py @@ -513,6 +513,31 @@ def quantize_and_export(*, with open(f"{export_path}/config.json", "w") as f: json.dump(tensorrt_llm_config, f, indent=4) + # Set rotary parameters correctly for chatglm. 
+ if model_type == 'chatglm': + rotary_base = 10000.0 + rotary_embedding_scaling = None + chatglm_config = AutoConfig.from_pretrained(model_dir, + trust_remote_code=True) + chatglm_version = tensorrt_llm_config['chatglm_version'] + rope_ratio = tensorrt_llm_config.get('rope_ratio', 1.0) + if chatglm_version == 'chatglm2': + if rope_ratio > 1: + rotary_embedding_scaling = { + 'type': 'linear', + 'factor': rope_ratio + } + elif chatglm_version == 'chatglm3': + rotary_base *= rope_ratio + + with open(f"{export_path}/config.json", "r") as f: + tensorrt_llm_config = json.load(f) + tensorrt_llm_config['rotary_base'] = rotary_base + tensorrt_llm_config['rotary_scaling'] = rotary_embedding_scaling + tensorrt_llm_config['rotary_pct'] = 0.5 + with open(f"{export_path}/config.json", "w") as f: + json.dump(tensorrt_llm_config, f, indent=4) + torch.cuda.empty_cache( ) # otherwise torch is keeping using GPU, other routine like build engine has less free GPU to use diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 897f024d1..3baf90723 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -2420,6 +2420,8 @@ def _prepare_context_inputs(self, batch_size, context_lengths, if self.is_redrafter_mode: self.buffer['position_ids_base'] = context_lengths.clone() + # NOTE: Generate random tensors using torch + redrafter_prepare_random_tensors(self, batch_size, initialize=True) return ret @@ -2459,24 +2461,25 @@ def _prepare_generation_inputs(self, batch_size, context_lengths, torch.cuda.nvtx.range_pop() ret = {'last_token_ids': last_token_ids} - if self.is_redrafter_mode: - torch.cuda.nvtx.range_push("position_ids_update") - # set position_ids - # buffers are swapped but sequence_length is not updated at this point - - if step != 0: - self.buffer['position_ids_base'] += self.buffer[ - 'num_accepted_tokens'] - position_ids = self.buffer['packed_position_ids'].view( - -1)[:self.host_total_gen_token] - if step == 0: - position_ids -= 1 + if use_gpt_attention_plugin: + if self.is_redrafter_mode: + torch.cuda.nvtx.range_push("position_ids_update") + # set position_ids + # buffers are swapped but sequence_length is not updated at this point + + if step != 0: + self.buffer['position_ids_base'] += self.buffer[ + 'num_accepted_tokens'] + position_ids = self.buffer['packed_position_ids'].view( + -1)[:self.host_total_gen_token] + if step == 0: + position_ids -= 1 - torch.cuda.nvtx.range_pop() - elif use_gpt_attention_plugin: - position_ids = context_lengths + step - if not remove_input_padding: - position_ids = torch.unsqueeze(position_ids, 1) + torch.cuda.nvtx.range_pop() + else: + position_ids = context_lengths + step + if not remove_input_padding: + position_ids = torch.unsqueeze(position_ids, 1) perf_knob_tensor_size = 16 gen_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, @@ -2505,19 +2508,7 @@ def _prepare_generation_inputs(self, batch_size, context_lengths, redrafter_convert_spec_decoding_mask_to_packed_mask( self, self.buffer['spec_decoding_generation_lengths']) # NOTE: Generate random tensors using torch - torch.cuda.nvtx.range_push("torch_rand") - # NOTE: Tried a single rand() instead of 2, no change in perf - torch.manual_seed(self.sequence_length_buffer.max()) - self.buffer['rand_data_sample'] = torch.rand([batch_size], - dtype=self.dtype, - device=self.device) - self.buffer['rand_data_validation'] = torch.rand([ - batch_size, self._model_config.redrafter_num_beams, - self._model_config.redrafter_draft_len_per_beam - ], 
- dtype=self.dtype, - device=self.device) - torch.cuda.nvtx.range_pop() + redrafter_prepare_random_tensors(self, batch_size) torch.cuda.nvtx.range_pop() return ret diff --git a/tensorrt_llm/runtime/redrafter_utils.py b/tensorrt_llm/runtime/redrafter_utils.py index 0877af4aa..706fab3be 100644 --- a/tensorrt_llm/runtime/redrafter_utils.py +++ b/tensorrt_llm/runtime/redrafter_utils.py @@ -2,6 +2,8 @@ import torch +REDRAFTER_DEFAULT_SEED = 0 + def get_redrafter_specific_tensor_names() -> List[str]: return [ @@ -128,18 +130,7 @@ def init_allocate_redrafter_tensors(session, batch_size): session.total_gen_token = torch.zeros(1, dtype=torch.int32, device=session.device) - torch.manual_seed(0) # use seed=0 for context - session.rand_data_sample = torch.rand([batch_size], - dtype=session.dtype, - device=session.device) - # print(session.rand_data_sample) - session.rand_data_validation = torch.rand([ - batch_size, session._model_config.redrafter_num_beams, - session._model_config.redrafter_draft_len_per_beam - ], - dtype=session.dtype, - device=session.device) - # print(session.rand_data_validation) + session.buffer['flat_tokens'] = session.flat_tokens session.buffer['next_flat_tokens'] = session.next_flat_tokens session.buffer['num_accepted_tokens'] = session.accept_lengths @@ -154,8 +145,6 @@ def init_allocate_redrafter_tensors(session, batch_size): 'spec_decoding_position_offsets'] = session.spec_decoding_position_offsets session.buffer[ 'spec_decoding_packed_mask'] = session.spec_decoding_packed_mask - session.buffer['rand_data_sample'] = session.rand_data_sample - session.buffer['rand_data_validation'] = session.rand_data_validation session.buffer[ 'next_spec_decoding_generation_lengths'] = session.next_spec_decoding_generation_lengths session.buffer['next_draft_tokens'] = session.next_draft_tokens @@ -304,3 +293,43 @@ def process_redrafter_outputs(session, step, batch_size, last_draft_tokens, session.end_ids[0]) #FIXME end id padding. torch.cuda.nvtx.range_pop() return best_path, best_path_lengths + + +def redrafter_prepare_random_tensors(session, batch_size, initialize=False): + torch.cuda.nvtx.range_push("torch_rand") + + def get_rand_tensors(): + rds = torch.rand([1], dtype=session.dtype, device=session.device) + rdv = torch.rand([ + 1, session._model_config.redrafter_num_beams, + session._model_config.redrafter_draft_len_per_beam + ], + dtype=session.dtype, + device=session.device) + return rds, rdv + + rand_data_sample = [] + rand_data_validation = [] + if initialize: # context phase + random_seed = session.random_seed + if random_seed is None: + random_seed = torch.full([batch_size], + REDRAFTER_DEFAULT_SEED, + dtype=torch.int64) + session.saved_rng_states = [] + for b in range(batch_size): + if initialize: # context phase + torch.manual_seed(random_seed[b].item()) + else: # generation phase + assert session.saved_rng_states is not None, "Couldn't find random states." 
+ torch.set_rng_state(session.saved_rng_states[b]) + rds, rdv = get_rand_tensors() + session.saved_rng_states.append(torch.get_rng_state()) + rand_data_sample.append(rds) + rand_data_validation.append(rdv) + session.rand_data_sample = torch.concat(rand_data_sample, dim=0) + session.rand_data_validation = torch.concat(rand_data_validation, dim=0) + session.buffer["rand_data_sample"] = session.rand_data_sample + session.buffer["rand_data_validation"] = session.rand_data_validation + torch.cuda.nvtx.range_pop() + return diff --git a/tensorrt_llm/top_model_mixin.py b/tensorrt_llm/top_model_mixin.py index 1e032dea5..abaf58cd6 100644 --- a/tensorrt_llm/top_model_mixin.py +++ b/tensorrt_llm/top_model_mixin.py @@ -45,17 +45,6 @@ def from_hugging_face(cls, ''' raise NotImplementedError("Subclass shall override this") - @classmethod - def convert_hf_checkpoint(cls, - hf_model_dir: str, - dtype: Optional[str] = "float16", - output_dir: Optional[str] = None, - **kwargs): - ''' - Convert Huggingface checkpoint to TRT-LLM checkpoint - ''' - raise NotImplementedError("Subclass shall override this") - def use_lora(self, lora_config: LoraConfig): ''' Load lora weights from the give config to the module diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index 98d72a042..b9c4fbd18 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.12.0.dev2024072302" +__version__ = "0.12.0.dev2024073000" diff --git a/tests/functional/test_moe.py b/tests/functional/test_moe.py index f2c9920bb..1e1e7e0a7 100644 --- a/tests/functional/test_moe.py +++ b/tests/functional/test_moe.py @@ -72,7 +72,7 @@ def config_is_allowed(config): # TODO: Support ootb path with getSMVersion() < 90: enable_ootb = getSMVersion() >= 90 enable_bf16 = getSMVersion() >= 80 - enable_fp8 = getSMVersion() >= 90 + enable_fp8 = getSMVersion() >= 89 DATA_TYPE_INDEX = 5 WEIGHT_TYPE_INDEX = 6 @@ -310,7 +310,7 @@ def get_params(): num_experts=8, topk=2, norm_mode=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, - hidden_size=4096, + hidden_size=2048, dtype='bfloat16', actfn='swiglu') ] diff --git a/tests/hlapi/test_llm_models.py b/tests/hlapi/test_llm_models.py index 23d8ffac1..318481b78 100644 --- a/tests/hlapi/test_llm_models.py +++ b/tests/hlapi/test_llm_models.py @@ -1,9 +1,10 @@ +import subprocess from typing import List, Optional import pytest import torch -from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm import LLM, BuildConfig, SamplingParams from tensorrt_llm.hlapi import QuantAlgo, QuantConfig try: @@ -20,6 +21,11 @@ gptj_model_path = get_model_path('gpt-j-6b') gpt2_model_path = get_model_path('gpt2-medium') starcoder2_model_path = get_model_path('starcoder2-3b') +phi_1_5_model_path = get_model_path('phi-1_5') +phi_2_model_path = get_model_path('phi-2') +phi_3_mini_4k_model_path = get_model_path('Phi-3/Phi-3-mini-4k-instruct') +phi_3_small_8k_model_path = get_model_path('Phi-3/Phi-3-small-8k-instruct') +phi_3_medium_4k_model_path = get_model_path('Phi-3/Phi-3-medium-4k-instruct') sampling_params = SamplingParams(max_new_tokens=10) @@ -43,6 +49,7 @@ def llm_test_harness(model_dir: str, llm = LLM(model_dir, tokenizer=model_dir, **llm_kwargs) outputs = llm.generate(prompts, sampling_params=sampling_params) + print(outputs) for out, ref in zip(outputs, references): assert similar(out.outputs[0].text, 
ref,
                        threshold=similar_threshold)
@@ -110,5 +117,54 @@ def test_llm_starcoder2_fp8():
                      quant_config=quant_config)
 
 
+def test_llm_phi_1_5():
+    llm_test_harness(phi_1_5_model_path,
+                     prompts=['A B C'],
+                     references=[' D E F G H I J K L M'],
+                     sampling_params=sampling_params)
+
+
+#@force_ampere
+def test_llm_phi_2():
+    llm_test_harness(phi_2_model_path,
+                     prompts=['A B C'],
+                     references=[' D E F G H I J K L M'],
+                     sampling_params=sampling_params)
+
+
+@force_ampere
+
+
+def test_llm_phi_3_mini_4k():
+    phi_requirement_path = os.path.join(os.getenv("LLM_ROOT"),
+                                        "examples/phi/requirements.txt")
+    command = f"pip install -r {phi_requirement_path}"
+    subprocess.run(command, shell=True, check=True, env=os.environ)
+    llm_test_harness(phi_3_mini_4k_model_path,
+                     prompts=['A B C'],
+                     references=[' D E F G H I J K L M'],
+                     sampling_params=sampling_params)
+
+
+@force_ampere
+def test_llm_phi_3_small_8k():
+    phi_requirement_path = os.path.join(os.getenv("LLM_ROOT"),
+                                        "examples/phi/requirements.txt")
+    command = f"pip install -r {phi_requirement_path}"
+    subprocess.run(command, shell=True, check=True, env=os.environ)
+    build_config = BuildConfig()
+    build_config.plugin_config._gemm_plugin = 'auto'
+    llm_test_harness(
+        phi_3_small_8k_model_path,
+        prompts=["where is France's capital?"],
+        references=[' Paris is the capital of France. It is known'],
+        sampling_params=sampling_params,
+        build_config=build_config)
+
+
 if __name__ == '__main__':
-    test_llm_gpt2()
+    test_llm_gptj()
+    test_llm_phi_1_5()
+    test_llm_phi_2()
+    test_llm_phi_3_mini_4k()
+    test_llm_phi_3_small_8k()
diff --git a/tests/model/test_gptneox.py b/tests/model/test_gptneox.py
index e246f706a..89df74f39 100644
--- a/tests/model/test_gptneox.py
+++ b/tests/model/test_gptneox.py
@@ -87,7 +87,7 @@ def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config,
             "num_attention_heads": num_heads,
             "hidden_size": hidden_size,
             "vocab_size": vocab_size,
-            "position_embedding_type": "learned_absolute",
+            "position_embedding_type": "rope_gpt_neox",
             "max_position_embeddings": max_position_embeddings,
             "rotary_emb_base": 10000,
             "rotary_pct": gpt_config.rotary_pct,
diff --git a/tests/model/test_phi.py b/tests/model/test_phi.py
index 5921e5db1..aedecc780 100644
--- a/tests/model/test_phi.py
+++ b/tests/model/test_phi.py
@@ -30,7 +30,7 @@
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 
-from tensorrt_llm.models.phi.convert import convert_hf_weights
+from tensorrt_llm.models.phi.convert import load_weights_from_hf_model
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 from utils.util import skip_fp32_accum_pre_ampere, unittest_name_func
@@ -66,7 +66,8 @@ def initialize_network(self, network: tensorrt_llm.Network, hf_model,
             'dtype': dtype,
             'num_hidden_layers': hf_config.num_hidden_layers,
             'num_attention_heads': hf_config.num_key_value_heads,
-            'partial_rotary_factor': hf_config.partial_rotary_factor,
+            'rotary_pct': hf_config.partial_rotary_factor,
+            'position_embedding_type': 'rope_gpt_neox',
             'rope_theta': hf_config.rope_theta,
             'hidden_size': hf_config.hidden_size,
             'intermediate_size': hf_config.intermediate_size,
@@ -84,7 +85,7 @@ def initialize_network(self, network: tensorrt_llm.Network, hf_model,
         }
         config = tensorrt_llm.models.PretrainedConfig.from_dict(config)
         config.set_rank(rank)
-        weights = convert_hf_weights(hf_model, dtype=dtype)
+        weights = load_weights_from_hf_model(hf_model, config)
         trtllm_model = tensorrt_llm.models.PhiForCausalLM(config)
         trtllm_model.load(weights)
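For context on the model-API refactor running through the `phi` and `phi3` changes above, here is a minimal usage sketch of the new `from_hugging_face` entry point. It assumes `PhiForCausalLM` is exported from `tensorrt_llm.models` (as the updated `tests/model/test_phi.py` suggests) and that `save_checkpoint` behaves as in the other TRT-LLM model classes; the model path, checkpoint directory, and parallelism settings are placeholders, not values taken from this patch.

```python
# Illustrative sketch only -- paths and parallelism settings are hypothetical.
from tensorrt_llm import Mapping
from tensorrt_llm.models import PhiForCausalLM

# Single-GPU mapping; from_hugging_face() resolves dtype='auto' against the
# HF config's torch_dtype, as in PhiConfig.from_hugging_face above.
mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1)
model = PhiForCausalLM.from_hugging_face(
    "/path/to/phi-2",  # HF checkpoint directory (placeholder)
    dtype="auto",
    mapping=mapping)

# Weights are converted via load_weights_from_hf_model() and loaded in place;
# saving as a TRT-LLM checkpoint is assumed to follow the common
# PretrainedModel.save_checkpoint() convention.
model.save_checkpoint("/path/to/phi2_trtllm_ckpt")
```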