Update TensorRT-LLM #2130

Merged: 1 commit, merged on Aug 20, 2024
4 changes: 3 additions & 1 deletion .gitignore
@@ -5,7 +5,9 @@ __pycache__/
*.cache
*.nsys-rep
.VSCodeCounter
build*/
cpp/build*
build
!tensorrt_llm/bench/build
!builders/
*.egg-info/
.coverage
9 changes: 6 additions & 3 deletions README.md
@@ -17,12 +17,15 @@ TensorRT-LLM
<div align="left">

## Latest News
* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫
🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647)
* [2024/08/13] 🐍 DIY Code Completion with #Mamba ⚡ #TensorRT #LLM for speed 🤖 NIM for ease ☁️ deploy anywhere
[➡️ link](https://developer.nvidia.com/blog/revolutionizing-code-completion-with-codestral-mamba-the-next-gen-coding-llm/)
<div align="center">
<img src="docs/source/media/picture-08-06-2024.png" width="50%">
<img src="docs/source/media/picture-08-13-2024.png" width="50%">
<div align="left">

* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫
🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647)

* [2024/07/30] Introducing🍊 @SliceXAI ELM Turbo 🤖 train ELM once ⚡ #TensorRT #LLM optimize ☁️ deploy anywhere
[➡️ link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms)

2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -7,5 +7,5 @@ There are currently three workflows to benchmark TensorRT-LLM:
- The recommended workflow that uses TensorRT-LLM C++ API and can take advantage of the latest features of TensorRT-LLM.
* [Python benchmarks](./python)
- The Python benchmarking scripts can only benchmark the Python runtime, which does not support the latest features, such as in-flight batching.
* [The Python benchmarking suite](./suite)
* [The Python benchmarking suite](./Suite.md)
- This benchmarking suite is a current work in progress and is prone to large changes.
316 changes: 316 additions & 0 deletions benchmarks/Suite.md

Large diffs are not rendered by default.

109 changes: 84 additions & 25 deletions benchmarks/cpp/gptManagerBenchmark.cpp
@@ -24,6 +24,7 @@
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
@@ -173,6 +174,9 @@ struct BenchmarkParams

// Decoding params
std::optional<std::vector<std::vector<SizeType32>>> medusaChoices;

std::optional<texec::LookaheadDecodingConfig> executorLookaheadConfig;
std::optional<texec::LookaheadDecodingConfig> requestLookaheadConfig;
};

class InferenceRequestsAsyncSend
@@ -509,6 +513,7 @@ class Recorder
{
if (!mStreaming)
{
TLLM_LOG_DEBUG("response.getResult().outputTokenIds");
auto outputTokenIds = response.getResult().outputTokenIds;

int32_t outSeqLen = 0;
@@ -824,9 +829,11 @@ class ExecutorServer
executorConfig.setMaxNumTokens(benchmarkParams.maxNumTokens.value());
}

executorConfig.setDecodingConfig(texec::DecodingConfig(
benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(),
std::nullopt, benchmarkParams.medusaChoices));
executorConfig.setDecodingConfig(
texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices));
executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig);

if (executorModelType == texec::ModelType::kDECODER_ONLY)
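Note: the replacement above selects the decoding mode by precedence: Medusa choices first, then an executor-level lookahead config, otherwise Auto. A minimal sketch of the equivalent selection logic (the helper name `selectDecodingMode` is illustrative and not part of this PR):

```cpp
// Sketch only: equivalent decoding-mode selection, assuming the BenchmarkParams
// fields added in this PR (medusaChoices, executorLookaheadConfig).
texec::DecodingMode selectDecodingMode(BenchmarkParams const& params)
{
    if (params.medusaChoices.has_value())
    {
        return texec::DecodingMode::Medusa();
    }
    if (params.executorLookaheadConfig.has_value())
    {
        return texec::DecodingMode::Lookahead();
    }
    return texec::DecodingMode::Auto();
}
```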
@@ -910,7 +917,7 @@ class ExecutorServer
for (auto const& response : responses)
{
auto const reqId = response.getRequestId();

TLLM_LOG_DEBUG("response.getResult().isFinal");
if (response.getResult().isFinal)
{
mActiveCount--;
@@ -1323,7 +1330,8 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr,
ITensor::SharedPtr const& loraConfig = nullptr)
ITensor::SharedPtr const& loraConfig = nullptr,
std::optional<tensorrt_llm::executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt)
{
auto request = std::make_shared<InferenceRequest>(reqId);
auto const& inputIds = sample.inputIds;
@@ -1361,6 +1369,10 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
{
request->setLoraConfig(loraConfig);
}
if (lookaheadConfig)
{
request->setLookaheadConfig(lookaheadConfig.value());
}
if (streaming)
{
request->setIsStreaming(true);
@@ -1372,18 +1384,20 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::optional<SizeType32> const& eosId, std::optional<SizeType32> const& padId, bool streaming = false,
bool const& returnContextLogits = false, bool const& returnGenerationLogits = false,
std::optional<texec::LoraConfig> const& loraConfig = std::nullopt,
std::optional<texec::LookaheadDecodingConfig> const& lookaheadConfig = std::nullopt,
std::optional<texec::VecTokens> encoderInputTokenIds = std::nullopt)
{
auto samplingConfig = texec::SamplingConfig{beamWidth};
auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId,
std::nullopt, // badWords
std::nullopt, // stopWords
std::nullopt, // embeddingBias
std::nullopt, // speculativeDecoding
std::nullopt, // pTuning
loraConfig,
std::nullopt, // logitsPostProcessorName
std::nullopt, // badWords
std::nullopt, // stopWords
std::nullopt, // embeddingBias
std::nullopt, // speculativeDecoding
std::nullopt, // pTuning
loraConfig, // loraConfig
lookaheadConfig, // lookaheadConfig
std::nullopt, // logitsPostProcessorName
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
}
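For reference, a hedged sketch of passing a per-request lookahead config through the new `lookaheadConfig` parameter; the (4, 4, 4) values and the surrounding variables (`sample`, `eosId`, `padId`) are illustrative, mirroring the call sites later in this file:

```cpp
// Sketch only: window size, ngram size, and verification set size of (4, 4, 4) are
// illustrative values, not defaults from this PR.
auto lookahead = texec::LookaheadDecodingConfig(4, 4, 4);
auto request = makeExecutorRequest(sample, /*beamWidth=*/1, eosId, padId,
    /*streaming=*/false, /*returnContextLogits=*/false, /*returnGenerationLogits=*/false,
    /*loraConfig=*/std::nullopt, lookahead);
```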

@@ -1429,9 +1443,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
optionalParams.maxNumTokens = benchmarkParams.maxNumTokens;
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};
optionalParams.decodingConfig = texec::DecodingConfig(
benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(),
std::nullopt, benchmarkParams.medusaChoices);
optionalParams.decodingConfig
= texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);

@@ -1501,8 +1517,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
++reqId;
if (i == terminateReqId)
++reqId;
auto request = makeRequest(
reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
auto request = makeRequest(reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, nullptr, nullptr, nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@@ -1517,7 +1533,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, nullptr,
nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);

if (i < numSamples - 1)
@@ -1541,7 +1558,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor,
nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@@ -1644,13 +1662,13 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
Sample s{std::vector<int32_t>{decoderStartTokenId}, 1, static_cast<int32_t>(taskId)};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false,
loraConfig, std::vector<int32_t>{1, 2, 3, 4, 5}));
loraConfig, std::nullopt, std::vector<int32_t>{1, 2, 3, 4, 5}));
}
else
{
Sample s{std::vector<int32_t>{1, 2, 3, 4, 5}, 1, static_cast<int32_t>(taskId)};
requests.emplace_back(
makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig));
makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig, std::nullopt));
}
}
executorServer->enqueue(std::move(requests), true);
@@ -1668,12 +1686,14 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
Sample s{std::vector<int32_t>{decoderStartTokenId}, samples[0].outputLen, samples[0].taskId};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming,
returnContextLogits, returnGenerationLogits, std::nullopt, samples[0].inputIds));
returnContextLogits, returnGenerationLogits, std::nullopt,
benchmarkParams.requestLookaheadConfig, samples[0].inputIds));
}
else
{
requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId,
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits));
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, std::nullopt,
benchmarkParams.requestLookaheadConfig));
}
}
executorServer->enqueue(std::move(requests), true);
@@ -1699,12 +1719,14 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
Sample s{std::vector<int32_t>{decoderStartTokenId}, samples[i].outputLen, samples[i].taskId};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming,
returnContextLogits, returnGenerationLogits, loraConfig, samples[i].inputIds));
returnContextLogits, returnGenerationLogits, loraConfig, benchmarkParams.requestLookaheadConfig,
samples[i].inputIds));
}
else
{
requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId,
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig));
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig,
benchmarkParams.requestLookaheadConfig));
}
}

@@ -1789,6 +1811,25 @@ std::vector<std::vector<SizeType32>> parseVectorOfVectors(std::string const& inp
return result;
}

texec::LookaheadDecodingConfig parseLookaheadConfig(std::string const& input)
{
std::regex regex("\\[ *(\\d+) *, *(\\d+) *, *(\\d+) *\\]");
std::smatch match;
if (std::regex_match(input, match, regex))
{
TLLM_CHECK(match.size() == 4);
auto w = std::stoi(match[1]);
auto n = std::stoi(match[2]);
auto g = std::stoi(match[3]);
return texec::LookaheadDecodingConfig(w, n, g);
}
else
{
TLLM_LOG_WARNING("cannot parse lookahead config from '%s'", input.c_str());
return texec::LookaheadDecodingConfig();
}
}

} // namespace

int main(int argc, char* argv[])
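A short usage sketch of the parser added above; the input strings are illustrative only:

```cpp
// Sketch: the new CLI flags expect "[max_window_size, max_ngram_size, max_verification_set_size]".
auto cfg = parseLookaheadConfig("[7, 7, 7]"); // -> texec::LookaheadDecodingConfig(7, 7, 7)
// Input without the bracketed form fails the regex, logs a warning, and falls back to a
// default-constructed LookaheadDecodingConfig.
auto fallback = parseLookaheadConfig("7,7,7");
```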
@@ -1898,6 +1939,14 @@ int main(int argc, char* argv[])

options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("executor_lookahead_config",
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size]",
cxxopts::value<std::string>());

options.add_options()("request_lookahead_config",
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size], and each <= "
"executor lookahead config",
cxxopts::value<std::string>());

auto result = options.parse(argc, argv);
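Per the help text above, the request-level config should not exceed the executor-level config in any field; an illustrative pairing (values are not from this PR):

```cpp
// Sketch, illustrative values only: each request-level value must be <= the
// corresponding executor-level value (window, ngram, verification set size).
auto executorCfg = parseLookaheadConfig("[7, 7, 7]"); // --executor_lookahead_config
auto requestCfg = parseLookaheadConfig("[5, 5, 5]");  // --request_lookahead_config
```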

@@ -2055,6 +2104,16 @@ int main(int argc, char* argv[])
{
benchmarkParams.medusaChoices = parseVectorOfVectors(result["medusa_choices"].as<std::string>());
}
if (result.count("executor_lookahead_config"))
{
benchmarkParams.executorLookaheadConfig
= parseLookaheadConfig(result["executor_lookahead_config"].as<std::string>());
}
if (result.count("request_lookahead_config"))
{
benchmarkParams.requestLookaheadConfig
= parseLookaheadConfig(result["request_lookahead_config"].as<std::string>());
}

// Argument: multi_block_mode
benchmarkParams.multiBlockMode = result["multi_block_mode"].as<bool>();
25 changes: 12 additions & 13 deletions benchmarks/python/all_reduce.py
@@ -23,7 +23,6 @@

import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
from tensorrt_llm._ipc_utils import peer_access
from tensorrt_llm._utils import OMPI_COMM_TYPE_HOST, mpi_comm
from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
@@ -106,18 +105,18 @@ def allreduce_benchmark(dtype: str,
_, start = cuda.cuEventCreate(0)
_, stop = cuda.cuEventCreate(0)
runtimes = []
with peer_access(mapping):
tllm.mpi_barrier()

for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
runtimes.append(ms)

tllm.mpi_barrier()

for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
runtimes.append(ms)

median_ms = sorted(runtimes)[len(runtimes) // 2]
assert torch.allclose(output, (input * world_size)**inner_loop)