Added new cmake flag TRITON_ENABLE_CIG to make the CiG support build conditional
Akarale98 committed Sep 19, 2024
1 parent 9ae5a09 commit 89ab580
Showing 6 changed files with 66 additions and 11 deletions.
CMakeLists.txt: 13 additions & 1 deletion

@@ -37,6 +37,8 @@ set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which feat
 option(TRITON_ENABLE_GPU "Enable GPU support in backend." ON)
 option(TRITON_ENABLE_STATS "Include statistics collections in backend." ON)
 option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." OFF)
+option(TRITON_ENABLE_CIG "Enable Cuda in Graphics (CiG) support in backend." OFF)
+
 set(TRITON_TENSORRT_LIB_PATHS "" CACHE PATH "Paths to TensorRT libraries. Multiple paths may be specified by separating them with a semicolon.")
 set(TRITON_TENSORRT_INCLUDE_PATHS "" CACHE PATH "Paths to TensorRT includes. Multiple paths may be specified by separating them with a semicolon.")
 

@@ -269,9 +271,19 @@ target_link_libraries(
   triton-tensorrt-backend
   PRIVATE
     CUDA::cudart
-    CUDA::cuda_driver
 )
 
+if(${TRITON_ENABLE_CIG})
+  target_compile_definitions(
+    triton-tensorrt-backend
+    PRIVATE TRITON_ENABLE_CIG
+  )
+  target_link_libraries(
+    triton-tensorrt-backend
+    PRIVATE
+      CUDA::cuda_driver
+  )
+endif()
 
 #
 # Install
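Note what the new conditional does: a default build now compiles without the TRITON_ENABLE_CIG definition and no longer links CUDA::cuda_driver, so the CUDA driver API is pulled in only when CiG is requested at configure time (for example, cmake -DTRITON_ENABLE_CIG=ON ..). A minimal sketch of how the compile definition is expected to surface in the sources, assuming nothing beyond what the diffs below show:

// Minimal sketch: the CMake option gates both this macro and the driver
// library link, so driver-API declarations vanish from a default build.
#ifdef TRITON_ENABLE_CIG
#include <cuda.h>  // CUDA driver API: CUcontext, cuCtxPushCurrent, ...
#endif  // TRITON_ENABLE_CIG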
src/instance_state.cc: 12 additions & 3 deletions

@@ -257,8 +257,11 @@ ModelInstanceState::ModelInstanceState(
 
 ModelInstanceState::~ModelInstanceState()
 {
+#ifdef TRITON_ENABLE_CIG
   // Set device if CiG is disabled
-  if (!model_state_->isCiGEnabled()) {
+  if (!model_state_->isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+  {
     cudaSetDevice(DeviceId());
   }
   for (auto& io_binding_infos : io_binding_infos_) {

@@ -427,8 +430,11 @@ ModelInstanceState::Run(
   payload_.reset(new Payload(next_set_, requests, request_count));
   SET_TIMESTAMP(payload_->compute_start_ns_);
 
+#ifdef TRITON_ENABLE_CIG
   // Set device if CiG is disabled
-  if (!model_state_->isCiGEnabled()) {
+  if (!model_state_->isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+  {
     cudaSetDevice(DeviceId());
   }
 #ifdef TRITON_ENABLE_STATS

@@ -1557,8 +1563,11 @@ ModelInstanceState::EvaluateTensorRTContext(
 TRITONSERVER_Error*
 ModelInstanceState::InitStreamsAndEvents()
 {
+#ifdef TRITON_ENABLE_CIG
   // Set device if CiG is disabled
-  if (!model_state_->isCiGEnabled()) {
+  if (!model_state_->isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+  {
     // Set the device before preparing the context.
     auto cuerr = cudaSetDevice(DeviceId());
     if (cuerr != cudaSuccess) {
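The hunks above all use the same trick, repeated in src/model_state.cc below: the #ifdef wraps only the if line, leaving the braced block outside. With the flag on, the block is the body of the runtime isCiGEnabled() check; with the flag off, the preprocessor strips the if and the block becomes a bare compound statement that always runs. A standalone sketch under that reading (the helper names here are illustrative, not from the diff):

#include <cstdio>

// Stand-in for model_state_->isCiGEnabled(); illustrative only.
static bool cig_enabled() { return false; }

static void set_device_if_needed() {
#ifdef TRITON_ENABLE_CIG
  // Compiled only when the flag is ON: skip device selection while a
  // shared CiG context is active.
  if (!cig_enabled())
#endif  // TRITON_ENABLE_CIG
  {
    // With the flag OFF, the if-line above is stripped and this block
    // executes unconditionally, i.e. the original pre-CiG behavior.
    std::printf("cudaSetDevice(...)\n");
  }
}

int main() { set_device_if_needed(); }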
src/model_state.cc: 17 additions & 4 deletions

@@ -175,8 +175,11 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
 ModelState::~ModelState()
 {
   for (auto& device_engine : device_engines_) {
+#ifdef TRITON_ENABLE_CIG
     // Set device if CiG is disabled
-    if (!isCiGEnabled()) {
+    if (!isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+    {
       cudaSetDevice(device_engine.first.first);
     }
     auto& runtime = device_engine.second.first;

@@ -212,8 +215,12 @@ ModelState::CreateEngine(
   // We share the engine (for models that don't have dynamic shapes) and
   // runtime across instances that have access to the same GPU/NVDLA.
   if (eit->second.second == nullptr) {
+
+#ifdef TRITON_ENABLE_CIG
     // Set device if CiG is disabled
-    if (!isCiGEnabled()) {
+    if (!isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+    {
      auto cuerr = cudaSetDevice(gpu_device);
      if (cuerr != cudaSuccess) {
        return TRITONSERVER_ErrorNew(

@@ -326,8 +333,11 @@ ModelState::AutoCompleteConfig()
           " to auto-complete config for " + Name())
              .c_str()));
 
+#ifdef TRITON_ENABLE_CIG
   // Set device if CiG is disabled
-  if (!isCiGEnabled()) {
+  if (!isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+  {
     cuerr = cudaSetDevice(device_id);
     if (cuerr != cudaSuccess) {
       return TRITONSERVER_ErrorNew(

@@ -381,8 +391,11 @@ ModelState::AutoCompleteConfig()
 
   RETURN_IF_ERROR(AutoCompleteConfigHelper(model_path));
 
+#ifdef TRITON_ENABLE_CIG
   // Set device if CiG is disabled
-  if (!isCiGEnabled()) {
+  if (!isCiGEnabled())
+#endif //TRITON_ENABLE_CIG
+  {
     cuerr = cudaSetDevice(current_device);
     if (cuerr != cudaSuccess) {
       return TRITONSERVER_ErrorNew(
src/tensorrt.cc: 7 additions & 1 deletion

@@ -318,7 +318,9 @@ TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
     DeviceMemoryTracker::TrackThreadMemoryUsage(lusage.get());
   }
 
+#ifdef TRITON_ENABLE_CIG
   ScopedRuntimeCiGContext cig_scope(model_state);
+#endif //TRITON_ENABLE_CIG
 
   // With each instance we create a ModelInstanceState object and
   // associate it with the TRITONBACKEND_ModelInstance.

@@ -357,7 +359,9 @@ TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance)
   if (!instance_state) {
     return nullptr;
   }
+#ifdef TRITON_ENABLE_CIG
   ScopedRuntimeCiGContext cig_scope(instance_state->StateForModel());
+#endif //TRITON_ENABLE_CIG
 
   delete instance_state;
 

@@ -382,7 +386,9 @@ TRITONBACKEND_ModelInstanceExecute(
       instance, reinterpret_cast<void**>(&instance_state)));
   ModelState* model_state = instance_state->StateForModel();
 
-  ScopedRuntimeCiGContext cig_scope(instance_state->StateForModel());
+#ifdef TRITON_ENABLE_CIG
+  ScopedRuntimeCiGContext cig_scope(model_state);
+#endif //TRITON_ENABLE_CIG
 
   // For TensorRT backend, the executing instance may not closely tie to
   // TRITONBACKEND_ModelInstance, the instance will be assigned based on
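Because cig_scope is a stack object, the shared context is released on every exit path, including the early return nullptr in Finalize and any error returns. The execute hunk also switches the constructor argument to the local model_state, which holds the same pointer previously fetched via instance_state->StateForModel(). A toy sketch of the scope-guard behavior (ScopedGuard is a made-up stand-in, not the real class):

#include <cstdio>

// Made-up stand-in for ScopedRuntimeCiGContext: push a context on entry,
// pop it when the scope unwinds, no matter how the function exits.
struct ScopedGuard {
  ScopedGuard() { std::puts("push CiG context"); }
  ~ScopedGuard() { std::puts("pop CiG context"); }
};

static bool Execute(bool fail_early) {
  ScopedGuard guard;             // pushed on entry
  if (fail_early) return false;  // popped here too
  std::puts("run inference");
  return true;                   // ...and here
}

int main() {
  Execute(true);
  Execute(false);
}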
src/tensorrt_model.cc: 8 additions & 1 deletion

@@ -55,7 +55,10 @@ TensorRTModel::TensorRTModel(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), priority_(Priority::DEFAULT),
       use_cuda_graphs_(false), gather_kernel_buffer_threshold_(0),
       separate_output_stream_(false), eager_batching_(false),
-      busy_wait_events_(false), cig_ctx_(nullptr)
+      busy_wait_events_(false)
+#ifdef TRITON_ENABLE_CIG
+      ,cig_ctx_(nullptr)
+#endif // TRITON_ENABLE_CIG
 {
   ParseModelConfig();
 }

@@ -91,6 +94,8 @@ TensorRTModel::ParseModelConfig()
         cuda.MemberAsBool("output_copy_stream", &separate_output_stream_));
     }
   }
+
+#ifdef TRITON_ENABLE_CIG
   triton::common::TritonJson::Value parameters;
   if (model_config_.Find("parameters", &parameters)) {
     triton::common::TritonJson::Value value;

@@ -105,6 +110,8 @@
       LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "CiG Context pointer is set");
     }
   }
+#endif //TRITON_ENABLE_CIG
+
   return nullptr;  // Success
 }
 
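The collapsed middle of the parameters hunk evidently pulls the CiG context pointer out of the model config's parameters section; only the Find("parameters", ...) lookup and the "CiG Context pointer is set" log line are visible here. A hedged sketch of what that parsing could look like; the key name kCiGParamKey and the string-encoded pointer format are assumptions, and only the TritonJson calls mirror ones visible in this file:

#include <string>

#include <cuda.h>

#include "triton/common/triton_json.h"

// Assumed key name; the real one is elided in the diff above.
constexpr const char* kCiGParamKey = "CIG_CONTEXT_PTR";

CUcontext
ParseCiGContext(triton::common::TritonJson::Value& model_config)
{
  triton::common::TritonJson::Value params, value;
  std::string ptr_str;
  if (model_config.Find("parameters", &params) &&
      params.Find(kCiGParamKey, &value)) {
    // Triton model-config parameters carry their payload in "string_value";
    // decoding the text as a pointer is an assumption of this sketch.
    value.MemberAsString("string_value", &ptr_str);
    if (!ptr_str.empty()) {
      return reinterpret_cast<CUcontext>(std::stoull(ptr_str, nullptr, 0));
    }
  }
  return nullptr;
}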
src/tensorrt_model.h: 9 additions & 1 deletion

@@ -25,7 +25,9 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #pragma once
 
+#ifdef TRITON_ENABLE_CIG
 #include <cuda.h>
+#endif //TRITON_ENABLE_CIG
 
 #include "triton/backend/backend_model.h"
 

@@ -55,7 +57,7 @@ class TensorRTModel : public BackendModel {
   bool EagerBatching() { return eager_batching_; }
   bool BusyWaitEvents() { return busy_wait_events_; }
 
-
+#ifdef TRITON_ENABLE_CIG
   //! Following functions are related to CiG (Cuda in Graphics) context sharing
   //! for gaming use case. Creating a shared contexts reduces context switching
   //! overhead and leads to better performance of model execution along side

@@ -88,6 +90,7 @@
     }
     return nullptr;
   }
+#endif //TRITON_ENABLE_CIG
 
  protected:
   common::TritonJson::Value graph_specs_;

@@ -97,9 +100,13 @@
   bool separate_output_stream_;
   bool eager_batching_;
   bool busy_wait_events_;
+#ifdef TRITON_ENABLE_CIG
   CUcontext cig_ctx_;
+#endif //TRITON_ENABLE_CIG
+
 };
 
+#ifdef TRITON_ENABLE_CIG
 struct ScopedRuntimeCiGContext {
   ScopedRuntimeCiGContext(TensorRTModel* model_state)
       : model_state_(model_state)

@@ -116,5 +123,6 @@
   }
   TensorRTModel* model_state_;
 };
+#endif //TRITON_ENABLE_CIG
 
 }}}  // namespace triton::backend::tensorrt
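The constructor and destructor bodies of ScopedRuntimeCiGContext are collapsed in this view. A hedged reconstruction of what they plausibly do, pushing and popping the shared context with the CUDA driver API; ModelLike stands in for TensorRTModel, CiGContext() is a hypothetical accessor, and only isCiGEnabled(), model_state_, and the constructor shape are taken from the diff:

#include <cuda.h>

// Hedged reconstruction of the elided guard bodies, not the real class.
struct ModelLike {
  bool isCiGEnabled() { return cig_ctx_ != nullptr; }
  CUcontext CiGContext() { return cig_ctx_; }  // hypothetical accessor
  CUcontext cig_ctx_ = nullptr;
};

struct ScopedCiGContextSketch {
  explicit ScopedCiGContextSketch(ModelLike* model_state)
      : model_state_(model_state)
  {
    if (model_state_->isCiGEnabled()) {
      cuCtxPushCurrent(model_state_->CiGContext());  // make it current
    }
  }
  ~ScopedCiGContextSketch()
  {
    if (model_state_->isCiGEnabled()) {
      CUcontext popped = nullptr;
      cuCtxPopCurrent(&popped);  // restore the previous current context
    }
  }
  ModelLike* model_state_;
};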
