feat: Client-side input shape/element validation #742

Draft · wants to merge 14 commits into main
Changes from 6 commits
14 changes: 13 additions & 1 deletion src/c++/library/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -122,6 +122,7 @@ if(TRITON_ENABLE_CC_GRPC OR TRITON_ENABLE_PERF_ANALYZER)
grpcclient_static
PRIVATE gRPC::grpc++
PRIVATE gRPC::grpc
+ PRIVATE triton-common-model-config
PUBLIC protobuf::libprotobuf
PUBLIC Threads::Threads
)
@@ -150,6 +151,7 @@ if(TRITON_ENABLE_CC_GRPC OR TRITON_ENABLE_PERF_ANALYZER)
grpcclient
PRIVATE gRPC::grpc++
PRIVATE gRPC::grpc
+ PRIVATE triton-common-model-config
PUBLIC protobuf::libprotobuf
PUBLIC Threads::Threads
)
@@ -275,6 +277,10 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
http-client-library EXCLUDE_FROM_ALL OBJECT
${REQUEST_SRCS} ${REQUEST_HDRS}
)
+ add_dependencies(
+   http-client-library
+   proto-library
+ )
rmccorm4 (Contributor) commented on Jul 25, 2024:
Isn't proto-library used for protobuf<->grpc? Why is it needed for HTTP client?

edit: guessing the requirement is here.

Contributor:
Are there any concerns with introducing the new protobuf dependency to the HTTP client, or any alternatives? CC @GuanLuo @tanmayv25


if (NOT WIN32)
set_property(
@@ -287,12 +293,14 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
http-client-library
PUBLIC
triton-common-json # from repo-common
+ triton-common-model-config
)

# libhttpclient_static.a
add_library(
httpclient_static STATIC
$<TARGET_OBJECTS:http-client-library>
+ $<TARGET_OBJECTS:proto-library>
)
add_library(
TritonClient::httpclient_static ALIAS httpclient_static
@@ -301,6 +309,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
target_link_libraries(
httpclient_static
PRIVATE triton-common-json
+ PRIVATE triton-common-model-config
PUBLIC CURL::libcurl
PUBLIC Threads::Threads
)
@@ -316,6 +325,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
add_library(
httpclient SHARED
$<TARGET_OBJECTS:http-client-library>
+ $<TARGET_OBJECTS:proto-library>
)
add_library(
TritonClient::httpclient ALIAS httpclient
@@ -333,6 +343,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
target_link_libraries(
httpclient
PRIVATE triton-common-json
+ PRIVATE triton-common-model-config
PUBLIC CURL::libcurl
PUBLIC Threads::Threads
)
@@ -358,6 +369,7 @@ if(TRITON_ENABLE_CC_HTTP OR TRITON_ENABLE_PERF_ANALYZER)
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<TARGET_PROPERTY:CURL::libcurl,INTERFACE_INCLUDE_DIRECTORIES>
+ $<TARGET_PROPERTY:proto-library,INCLUDE_DIRECTORIES>
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
26 changes: 25 additions & 1 deletion src/c++/library/common.cc
@@ -1,4 +1,4 @@
-// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -26,6 +26,10 @@

#include "common.h"

+ #include <numeric>
+
+ #include "triton/common/model_config.h"

namespace triton { namespace client {

//==============================================================================
@@ -232,6 +236,26 @@ InferInput::SetBinaryData(const bool binary_data)
return Error::Success;
}

+ Error
+ InferInput::ValidateData() const
+ {
Contributor:
Moving TRT reformat conversation to a thread 🧵

Contributor:
Yingge:

> There is a known issue with TensorRT (Jira DLIS-6805) which causes TRT tests to fail again at the client (CI job 102924904). There is no way to know the platform of the inference model on the client side. Should we wait until @pskiran1 finishes his change first?
> CC @tanmayv25 @GuanLuo @rmccorm4

Sai:

> @yinggeh, I just merged DLIS-6805 changes, could you please try with the latest code?

Contributor:
> I just merged DLIS-6805 changes, could you please try with the latest code?

Sai's changes will allow the check to work on the core side, but probably not on the client side, right? @yinggeh

> There is no way to know the platform of the inference model on the client side.

You can query the platform/backend through the model config APIs on the client side, which would work when inferring on a TensorRT model directly. You can probably even query is_non_linear_format_io from the model config if needed.

For an ensemble model containing one of these TRT models with non-linear inputs, you may need to follow the ensemble definition to find out if it's calling a TRT model with its inputs, which can be a pain. It may be simpler to skip the check on ensemble models and let the core check handle it (but it feels like we're starting to introduce a lot of special checks and cases with this feature).

For a BLS model, I think it's fine: it will work like any other Python model and will trigger the core check internally if the BLS calls the TRT model.

CC @tanmayv25
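A rough sketch of the client-side query described above, using the C++ gRPC client's ModelConfig() call. The server address and model name are illustrative, and the is_non_linear_format_io field is assumed to be present in the model config proto (per the DLIS-6805 changes):

#include <iostream>
#include <memory>
#include <string>

#include "grpc_client.h"

namespace tc = triton::client;

int main()
{
  std::unique_ptr<tc::InferenceServerGrpcClient> client;
  tc::Error err =
      tc::InferenceServerGrpcClient::Create(&client, "localhost:8001");
  if (!err.IsOk()) {
    std::cerr << err << std::endl;
    return 1;
  }

  // Fetch the model configuration ("trt_model" is a placeholder name).
  inference::ModelConfigResponse config_response;
  err = client->ModelConfig(&config_response, "trt_model");
  if (!err.IsOk()) {
    std::cerr << err << std::endl;
    return 1;
  }
  const auto& config = config_response.config();

  // TensorRT plan models report the "tensorrt_plan" platform.
  const bool is_trt = (config.platform() == "tensorrt_plan");

  // Check whether any input uses a non-linear format, in which case the
  // linear shape-times-element-size byte count would not apply.
  bool has_non_linear_io = false;
  for (const auto& input : config.input()) {
    if (input.is_non_linear_format_io()) {
      has_non_linear_io = true;
    }
  }

  std::cout << "is_trt=" << is_trt
            << " non_linear_io=" << has_non_linear_io << std::endl;
  return 0;
}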

Contributor:
Another alternative is to introduce a new flag on the client Input/Output tensors to skip the byte-size check on the client side (sketched below). We can document when we expect the user to provide this option (when using a non-linear format). This way the user is aware of what they are doing.

Pros:

  1. A generic API change allows for all the flexibility.
  2. Powerful expression for the client-side code.

Cons:

  1. Adding a flag to skip these checks seems counter-intuitive and makes us question the requirement of such checks in the first place.
    a. This can be alleviated to some degree by an additional check validating that the skip-byte-size-check flag is set only in the correct scenario.
  2. Breaks backwards compatibility, as the user now has to set a new flag to use models with non-linear tensors.
Contributor:
> For an ensemble model containing one of these TRT models with non-linear inputs, you may need to follow the ensemble definition to find out if it's calling a TRT model with its inputs, which can be a pain.

@rmccorm4 Can you elaborate on this?

Contributor:
> Can you elaborate on this?

If you have an ensemble with ENSEMBLE_INPUT0 where the first step is a TRT model with non-linear IO INPUT0 and a mapping of ENSEMBLE_INPUT0 -> INPUT0, do we require the ensemble config to mention that ENSEMBLE_INPUT0 is non-linear IO too? Or is it inferred internally?

Contributor:
> Adding a flag to skip these checks seems to be counter-intuitive and makes us question even the requirement of such checks in the first place.

+1, I think this is counter-intuitive to the goal.

> This can be alleviated by an additional check to some degree by validating the skip_byte_size check flag is set for the correct scenario.

If we are able to determine "the correct scenario" programmatically, isn't that the same as being able to skip the check internally without user specification?

+   inference::DataType datatype =
+       triton::common::ProtocolStringToDataType(datatype_);
+   // String inputs will be checked at core and backend to reduce overhead.
+   if (datatype == inference::DataType::TYPE_STRING) {
+     return Error::Success;
+   }
+
+   int64_t expected_byte_size = triton::common::GetByteSize(datatype, shape_);
+   if ((int64_t)byte_size_ != expected_byte_size) {
+     return Error(
+         "input '" + name_ + "' got unexpected byte size " +
+         std::to_string(byte_size_) + ", expected " +
+         std::to_string(expected_byte_size));
+   }
+   return Error::Success;
+ }

Error
InferInput::PrepareForRequest()
{
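For a concrete sense of the check above: GetByteSize for a fixed-size datatype is the product of the dimensions times the element size, so an FP32 input declared with shape [2, 3] must carry exactly 2 * 3 * 4 = 24 bytes. A standalone sketch of the same arithmetic (names and shapes illustrative):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Mirrors the expected-size computation for fixed-size datatypes:
// element count (product of dims) times bytes per element.
int64_t ExpectedByteSize(const std::vector<int64_t>& shape, int64_t elem_size)
{
  const int64_t elements = std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  return elements * elem_size;
}

int main()
{
  const std::vector<int64_t> shape{2, 3};
  // FP32 is 4 bytes per element: 2 * 3 * 4 = 24.
  std::cout << ExpectedByteSize(shape, 4) << std::endl;  // prints 24

  // A caller who appends only 20 bytes for this shape now gets
  // "input 'INPUT0' got unexpected byte size 20, expected 24"
  // from ValidateData() instead of a round trip to the server.
  return 0;
}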
6 changes: 5 additions & 1 deletion src/c++/library/common.h
@@ -1,4 +1,4 @@
-// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -354,6 +354,10 @@ class InferInput {
/// \return Error object indicating success or failure.
Error SetBinaryData(const bool binary_data);

+ /// Validate input has data and input shape matches input data.
+ /// \return Error object indicating success or failure.
+ Error ValidateData() const;

private:
#ifdef TRITON_INFERENCE_SERVER_CLIENT_CLASS
friend class TRITON_INFERENCE_SERVER_CLIENT_CLASS;
8 changes: 7 additions & 1 deletion src/c++/library/grpc_client.cc
@@ -1,4 +1,4 @@
-// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -1470,7 +1470,13 @@ InferenceServerGrpcClient::PreRunProcessing(

int index = 0;
infer_request_.mutable_raw_input_contents()->Clear();
+ Error err;
for (const auto input : inputs) {
+   err = input->ValidateData();
+   if (!err.IsOk()) {
+     return err;
+   }

// Add new InferInputTensor submessages only if required, otherwise
// reuse the submessages already available.
auto grpc_input = (infer_request_.inputs().size() <= index)
7 changes: 6 additions & 1 deletion src/c++/library/http_client.cc
@@ -1,4 +1,4 @@
-// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -2117,6 +2117,11 @@ InferenceServerHttpClient::PreRunProcessing(
// Add the buffers holding input tensor data
bool all_inputs_are_json{true};
for (const auto this_input : inputs) {
+   err = this_input->ValidateData();
+   if (!err.IsOk()) {
+     return err;
+   }

if (this_input->BinaryData()) {
all_inputs_are_json = false;
}
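With both PreRunProcessing paths calling ValidateData(), a mismatched input now fails before any request leaves the client. A minimal usage sketch against the C++ gRPC client (server address, model, and tensor names are illustrative):

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

#include "grpc_client.h"

namespace tc = triton::client;

int main()
{
  std::unique_ptr<tc::InferenceServerGrpcClient> client;
  tc::InferenceServerGrpcClient::Create(&client, "localhost:8001");

  // Declare a [2, 3] FP32 input but append only 5 elements (20 bytes).
  tc::InferInput* raw_input;
  tc::InferInput::Create(&raw_input, "INPUT0", {2, 3}, "FP32");
  std::unique_ptr<tc::InferInput> input(raw_input);

  std::vector<float> data(5, 0.0f);
  input->AppendRaw(
      reinterpret_cast<const uint8_t*>(data.data()),
      data.size() * sizeof(float));

  // The byte-size mismatch is caught in PreRunProcessing, before the
  // RPC is issued.
  tc::InferOptions options("simple_model");
  tc::InferResult* result;
  tc::Error err = client->Infer(&result, options, {input.get()});
  if (!err.IsOk()) {
    std::cerr << err << std::endl;  // "... unexpected byte size 20, expected 24"
    return 1;
  }
  return 0;
}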
33 changes: 32 additions & 1 deletion src/c++/tests/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -70,6 +70,9 @@ install(
RUNTIME DESTINATION bin
)

+ #
+ # cc_client_test
+ #
add_executable(
cc_client_test
cc_client_test.cc
@@ -89,6 +92,34 @@ install(
RUNTIME DESTINATION bin
)

+ #
+ # client_input_test
+ #
+ add_executable(
+   client_input_test
+   client_input_test.cc
+   $<TARGET_OBJECTS:shm-utils-library>
+ )
+ target_include_directories(
+   client_input_test
+   PRIVATE
+     ${GTEST_INCLUDE_DIRS}
+ )
+ target_link_libraries(
+   client_input_test
+   PRIVATE
+     grpcclient_static
+     httpclient_static
+     gtest
+     ${GTEST_LIBRARY}
+     ${GTEST_MAIN_LIBRARY}
+     GTest::gmock
+ )
+ install(
+   TARGETS client_input_test
+   RUNTIME DESTINATION bin
+ )

endif() # TRITON_ENABLE_CC_HTTP AND TRITON_ENABLE_CC_GRPC

endif()