Update TensorRT-LLM (NVIDIA#2053)
kaiyux authored Jul 30, 2024
1 parent 93293aa commit a681853
Showing 128 changed files with 2,392 additions and 1,093 deletions.
63 changes: 51 additions & 12 deletions benchmarks/cpp/README.md
@@ -93,16 +93,55 @@ cd cpp/build
./benchmarks/gptManagerBenchmark --help
```

Take GPT-350M as an example for 2-GPU inflight batching
```
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
--engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
--request_rate 10 \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json
--max_num_samples 500
```
`gptManagerBenchmark` now supports decoder-only models and encoder-decoder models.

1. Decoder-only Models

To benchmark decoder-only models, pass in the engine path with `--engine_dir` as an executable input argument.

Take GPT-350M as an example for 2-GPU inflight batching
```
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
--engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
--request_rate 10 \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json \
--max_num_samples 500
```
`gptManagerBenchmark` by default uses the high-level C++ API defined by the `executor::Executor` class (see `cpp/include/tensorrt_llm/executor/executor.h`).
2. Encoder-Decoder Models
To benchmark encoder-decoder models, pass in the encoder engine path with `--encoder_engine_dir` and the decoder engine path with `--decoder_engine_dir` as executable input arguments. `--decoder_engine_dir` is an alias of `--engine_dir`.
Currently encoder-decoder engines only support `--api executor`, `--type IFB`, and `--enable_kv_cache_reuse false`, which are all default values, so no specific settings are required.
Prepare the t5-small engine from [examples/enc_dec](/examples/enc_dec/README.md#convert-and-split-weights) for the encoder-decoder 4-GPU inflight batching example.
Prepare a dataset suited to the engine input lengths.
```
python prepare_dataset.py \
--tokenizer <path/to/tokenizer> \
--output cnn_dailymail.json \
dataset \
--dataset-name cnn_dailymail \
--dataset-split validation \
--dataset-config-name 3.0.0 \
--dataset-input-key article \
--dataset-prompt "Summarize the following article:" \
--dataset-output-key "highlights" \
--num-requests 100 \
--max-input-len 512 \
--output-len-dist 128,20
```
Run the benchmark
```
mpirun --allow-run-as-root -np 4 ./benchmarks/gptManagerBenchmark \
--encoder_engine_dir ../../examples/enc_dec/tmp/trt_engines/t5-small-4gpu/bfloat16/encoder \
--decoder_engine_dir ../../examples/enc_dec/tmp/trt_engines/t5-small-4gpu/bfloat16/decoder \
--dataset cnn_dailymail.json
```
`gptManagerBenchmark` by default uses the high-level C++ API defined by the `executor::Executor` class (see `cpp/include/tensorrt_llm/executor/executor.h`).
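As a point of reference, the snippet below is a minimal sketch of driving a decoder-only engine through that executor API directly; the engine path and token ids are placeholders, and the exact signatures may vary between TensorRT-LLM versions.
```
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

int main()
{
    // Load a prebuilt decoder-only engine (placeholder path).
    tle::Executor executor("../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/",
                           tle::ModelType::kDECODER_ONLY, tle::ExecutorConfig());

    // Enqueue a request: input token ids plus the number of new tokens to generate.
    auto requestId = executor.enqueueRequest(tle::Request({1, 2, 3, 4}, /*maxNewTokens=*/16));

    // Wait for the request to complete and read back the generated tokens per beam.
    for (auto const& response : executor.awaitResponses(requestId))
    {
        if (!response.hasError())
        {
            auto const& beams = response.getResult().outputTokenIds;
            (void) beams;
        }
    }
    return 0;
}
```
Conceptually, `gptManagerBenchmark` layers dataset loading, request-rate control, and latency/throughput accounting on top of this enqueue/await loop.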
#### Emulated static batching
@@ -125,7 +164,7 @@ Take GPT-350M as an example for single GPU with static batching
```
./benchmarks/gptManagerBenchmark \
--engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
--request-rate -1 \
--request_rate -1 \
--static_emulated_batch_size 32 \
--static_emulated_timeout 100 \
--dataset ../../benchmarks/cpp/tokens-fixed-lengths.json
@@ -221,7 +260,7 @@ python benchmarks/cpp/utils/generate_rand_loras.py ${CPP_LORA} ${EG_DIR}/loras 1
# First run inference without LoRAs
mkdir -p ${EG_DIR}/log-base-lora
mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \
cpp/build_Debug/benchmarks/gptManagerBenchmark \
cpp/build/benchmarks/gptManagerBenchmark \
--engine_dir $LORA_ENGINE \
--type IFB \
--dataset "${EG_DIR}/data/token-norm-dist.json" \
@@ -239,7 +278,7 @@ mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \
for nloras in ${NUM_LORAS[@]}; do
mkdir -p ${EG_DIR}/log-lora-${nloras}
mpirun -n ${TP} --output-filename "${EG_DIR}/log-lora-${nloras}" \
cpp/build_Debug/benchmarks/gptManagerBenchmark \
cpp/build/benchmarks/gptManagerBenchmark \
--engine_dir $LORA_ENGINE \
--type IFB \
--dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
180 changes: 158 additions & 22 deletions benchmarks/cpp/gptManagerBenchmark.cpp

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py
@@ -16,7 +16,7 @@
"meta-llama/Llama-2-70b-hf": "rope_gpt_neox",
"meta-llama/Meta-Llama-3-8B": "rope_gpt_neox",
"meta-llama/Meta-Llama-3-70B": "rope_gpt_neox",
"EleutherAI/gpt-j-6b": "rope_gptj",
"gpt-j-6b": "rope_gptj",
"bigscience/bloom-560m": "alibi",
"mistralai/Mistral-7B-v0.1": "rope_gpt_neox",
"mistralai/Mixtral-8x7B-v0.1": "rope_gpt_neox",
@@ -126,7 +126,7 @@ class TRTLLMConfig(BaseModel):
moe_top_k: Optional[int] = Field(
default=0, validation_alias=AliasChoices("num_experts_per_tok"))
rotary_base: Optional[float] = Field(
default=None, validation_alias=AliasChoices("rope_theta"))
default=10000.0, validation_alias=AliasChoices("rope_theta"))

mapping: TRTLLM_Mapping
quantization: TRTLLM_Quantization
@@ -176,8 +176,9 @@ def populate_build_config(cls,
"quant_algo": quant_dtype,
"kv_cache_quant_algo": kv_cache_quant_dtype,
}
if model_name in PET_dict:
build_config["position_embedding_type"] = PET_dict.get(model_name)
for name, pet in PET_dict.items():
if name in str(model_name):
build_config["position_embedding_type"] = pet
return build_config

@classmethod
27 changes: 6 additions & 21 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -718,23 +718,17 @@ class GenericLlmRequest
mReturnGenerationLogits = returnGenerationLogits;
}

// Return all generation logits for model w/o draft token
[[nodiscard]] bool getReturnGenerationLogits() const
{
return mReturnGenerationLogits && (getNumDraftTokens() == 0);
}

// Return accepted tokens logits for target model
[[nodiscard]] bool getReturnTargetModelAcceptedLogits() const
{
return mReturnGenerationLogits && (getNumDraftTokens() > 0);
return mReturnGenerationLogits;
}

[[nodiscard]] TensorPtr const& getContextLogitsHost() const
{
return mContextLogitsHost;
}

/// @param contextLogitsHost Expected shape [promptLen, vocabSizePadded]
void setContextLogitsHost(TensorPtr contextLogitsHost)
{
mContextLogitsHost = std::move(contextLogitsHost);
@@ -751,6 +745,9 @@ class GenericLlmRequest
return mGenerationLogitsHost;
}

/// @param generationLogitsHost Expected shape
/// * [beamWidth, maxNewTokens, vocabSizePadded] for non-speculative decoding
/// * [1, numDraftTokens + 1, vocabSizePadded] for speculative decoding
void setGenerationLogitsHost(TensorPtr generationLogitsHost)
{
mGenerationLogitsHost = std::move(generationLogitsHost);
@@ -765,7 +762,7 @@
void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType);
runtime::ITensor::makeShape({1, getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType);
}

[[nodiscard]] std::vector<TensorPtr> const& getGenerationLogitsFragments() const
@@ -966,18 +963,6 @@ class GenericLlmRequest
result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost());
}

if (getReturnTargetModelAcceptedLogits())
{
auto targetModelAcceptedTokenLogitsShape = getGenerationLogitsHost()->getShape();
TLLM_CHECK(targetModelAcceptedTokenLogitsShape.nbDims == 2);
auto numAcceptedToken = targetModelAcceptedTokenLogitsShape.d[0];
auto vocabSizePadded = targetModelAcceptedTokenLogitsShape.d[1];
// Align the shape of accepted token logits and generation logits
TensorPtr targetModelAcceptedTokenLogitsHostView = runtime::ITensor::view(
getGenerationLogitsHost(), runtime::ITensor::makeShape({1, numAcceptedToken, vocabSizePadded}));
result.generationLogits = executor::detail::ofITensor(targetModelAcceptedTokenLogitsHostView);
}

if (getReturnEncoderOutput())
{
result.encoderOutput = executor::detail::ofITensor(getEncoderOutputHost());
@@ -349,7 +349,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
#ifndef ENABLE_FP8
static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
#endif
bool should_skip_unsupported_fp8 = getSMVersion() < 90 && FP8;
bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
return should_skip_unsupported_fp8;
}

Git LFS file not shown
Git LFS file not shown
6 changes: 3 additions & 3 deletions cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
08d59f31da00044ae21995c6573a55da libtensorrt_llm_batch_manager_static.a
abdb9b58e0a4587d2d2ce6bc83655f8a libtensorrt_llm_batch_manager_static.pre_cxx11.a
315e9f5ccd286e906d4c0d402fefbf2f69a1febe commit
ea8bb6e3a155175a0dcfc0e87d1e7f25 libtensorrt_llm_batch_manager_static.a
42ebee6d5349709e33ab36a82f6fef4d libtensorrt_llm_batch_manager_static.pre_cxx11.a
8baf57c648b66a48dbe29f766c6fdff505045f24 commit
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
@@ -21,6 +21,7 @@
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"

#include "cutlass_extensions/arch/mma.h"
@@ -138,11 +139,7 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value
#ifdef ENABLE_FP8
|| cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
#endif
>::type>
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
@@ -162,6 +159,32 @@ struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
using Operator = typename LayoutDetails::Operator;
};

// FP8 A/B = fp8, C/D = fp32
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
// be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>
using TypeC = __nv_bfloat16;
using LayoutB = typename LayoutDetails::Layout;

static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

using Operator = typename LayoutDetails::Operator;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
@@ -77,6 +77,23 @@ struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch
using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
static constexpr int ThreadblockK = 64;

private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
// for fast accumulation
// using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};

// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
@@ -446,8 +446,6 @@ struct MoeFCGemm
// Epilogue
//

EpilogueOutputOp output_op(params.output_op);

ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;

@@ -468,7 +466,20 @@ struct MoeFCGemm
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
if constexpr (platform::is_same<EpilogueOutputOp,
cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
EpilogueOutputOp::kRound>>::value)
{
EpilogueOutputOp output_op(params.output_op, problem_idx);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
else
{
EpilogueOutputOp output_op(params.output_op);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}

// Next tile
problem_visitor.advance(gridDim.x);
@@ -501,8 +512,19 @@ struct MoeFCGemm
run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
|| platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
if constexpr (isFp8)
{
run_kernel<arch::Sm89>(params, shared_storage);
}
else
{ // reuse sm80 kernel for other types, align with dispatchToArch
run_kernel<arch::Sm80>(params, shared_storage);
}
#elif (__CUDA_ARCH__ >= 900)
run_kernel<arch::Sm80>(params, shared_storage);
#else
@@ -60,7 +60,11 @@ enum class CutlassTileConfig
CtaShape256x128x64_WarpShape64x64x64,

// TensorCore config CTA_N = 256, CTA_K = 64
CtaShape16x256x64_WarpShape16x64x64
CtaShape16x256x64_WarpShape16x64x64,

// TensorCore config CTA_N = 256, CTA_K = 128
CtaShape16x256x128_WarpShape16x64x128

};

enum class SplitKStyle
Expand Down Expand Up @@ -129,6 +133,7 @@ struct CutlassGemmConfig
INT8_ONLY = 1u << 2,
HOPPER = 1u << 3,
GROUPED_GEMM = 1u << 4,
FP8_ONLY = 1u << 5,
};

CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
Git LFS file not shown
Git LFS file not shown