Skip to content

Commit

Permalink
Remove client checks for string inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
yinggeh committed Jul 23, 2024
1 parent e5e6b7e commit 07059a6
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 171 deletions.
109 changes: 11 additions & 98 deletions src/c++/library/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,109 +236,22 @@ InferInput::SetBinaryData(const bool binary_data)
return Error::Success;
}

// Counts the length-prefixed string elements stored across this input's
// buffers (bufs_ / buf_byte_sizes_) and writes the total to 'str_cnt'.
// Each element is read as a uint32_t byte-size indicator followed by that
// many bytes of payload; a payload may continue into the following buffers.
// Returns an error if fewer than sizeof(uint32_t) bytes remain in the current
// buffer when a new element's size indicator is expected, and Error::Success
// otherwise.
// NOTE(review): if the buffers end in the middle of an element's payload, the
// loop exits without an error and that partial element is still counted.
Error
InferInput::GetStringCount(size_t* str_cnt) const
{
// Number of elements whose size indicator has been read so far.
int64_t str_checked = 0;
// Payload bytes of the current element that have not been consumed yet.
size_t remaining_str_size = 0;

// Index of the next buffer to load from bufs_.
size_t next_buf_idx = 0;
const size_t buf_cnt = bufs_.size();

// Read cursor into the current buffer and the bytes left after it.
const uint8_t* buf = nullptr;
size_t remaining_buf_size = 0;

// Walk the elements until all buffers have been fully processed.
while (remaining_buf_size || next_buf_idx < buf_cnt) {
// Get the next buf if not currently processing one.
if (!remaining_buf_size) {
// Reset remaining buf size and pointers for next buf.
buf = bufs_[next_buf_idx];
remaining_buf_size = buf_byte_sizes_[next_buf_idx];
next_buf_idx++;
}

// Every element starts with a uint32_t holding its payload byte size.
constexpr size_t kStringSizeIndicator = sizeof(uint32_t);
// Get the next element if not currently processing one.
if (!remaining_str_size) {
// FIXME: Assume the string element's byte size indicator is not spread
// across buf boundaries for simplicity. Also needs better log msg.
if (remaining_buf_size < kStringSizeIndicator) {
return Error("element byte size indicator exceeds the end of the buf.");
}

// Start the next element and reset the remaining element size.
// The indicator is read in place in host byte order.
// NOTE(review): assumes 'buf' is suitably aligned for uint32_t -- confirm.
remaining_str_size = *(reinterpret_cast<const uint32_t*>(buf));
str_checked++;

// Advance pointer and remainder by the indicator size.
buf += kStringSizeIndicator;
remaining_buf_size -= kStringSizeIndicator;
}

// If the remaining buf fits it: consume the rest of the element, proceed
// to the next element.
if (remaining_buf_size >= remaining_str_size) {
buf += remaining_str_size;
remaining_buf_size -= remaining_str_size;
remaining_str_size = 0;
}
// Otherwise the remaining element is larger: consume the rest of the
// buf, proceed to the next buf.
else {
remaining_str_size -= remaining_buf_size;
remaining_buf_size = 0;
}
}

// FIXME: If more than expected, should stop earlier
// Report the number of elements found. No comparison against an expected
// count happens here; that is left to the caller.
*str_cnt = str_checked;
return Error::Success;
}

Error
InferInput::ValidateData() const
{
inference::DataType datatype =
triton::common::ProtocolStringToDataType(datatype_);
if (io_type_ == SHARED_MEMORY) {
if (datatype == inference::DataType::TYPE_STRING) {
// TODO Didn't find any shm and BYTES inputs inference example
} else {
int64_t expected_byte_size =
triton::common::GetByteSize(datatype, shape_);
if ((int64_t)byte_size_ != expected_byte_size) {
return Error(
"input '" + name_ + "' got unexpected byte size " +
std::to_string(byte_size_) + ", expected " +
std::to_string(expected_byte_size));
}
}
} else {
if (datatype == inference::DataType::TYPE_STRING) {
int64_t expected_str_cnt = triton::common::GetElementCount(shape_);
size_t str_cnt;
Error err = GetStringCount(&str_cnt);
if (!err.IsOk()) {
return err;
}
if ((int64_t)str_cnt != expected_str_cnt) {
return Error(
"input '" + name_ + "' got unexpected string count " +
std::to_string(str_cnt) + ", expected " +
std::to_string(expected_str_cnt));
}
} else {
int64_t expected_byte_size =
triton::common::GetByteSize(datatype, shape_);
if ((int64_t)byte_size_ != expected_byte_size) {
return Error(
"input '" + name_ + "' got unexpected byte size " +
std::to_string(byte_size_) + ", expected " +
std::to_string(expected_byte_size));
}
}
// String inputs will be checked at core and backend to reduce overhead.
if (datatype == inference::DataType::TYPE_STRING) {
return Error::Success;
}

int64_t expected_byte_size = triton::common::GetByteSize(datatype, shape_);
if ((int64_t)byte_size_ != expected_byte_size) {
return Error(
"input '" + name_ + "' got unexpected byte size " +
std::to_string(byte_size_) + ", expected " +
std::to_string(expected_byte_size));
}
return Error::Success;
}
Expand Down
5 changes: 0 additions & 5 deletions src/c++/library/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,6 @@ class InferInput {
/// \return Error object indicating success or failure.
Error SetBinaryData(const bool binary_data);

/// Gets the total number of strings in this input data.
/// \param str_cnt Returns the total number of strings.
/// \return Error object indicating success or failure.
Error GetStringCount(size_t* str_cnt) const;

/// Validate input has data and input shape matches input data.
/// \return Error object indicating success or failure.
Error ValidateData() const;
Expand Down
104 changes: 36 additions & 68 deletions src/c++/tests/client_input_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,23 @@ TYPED_TEST_P(ClientInputTest, AppendRaw)

std::vector<tc::InferInput*> inputs = {input0_ptr.get(), input1_ptr.get()};
tc::InferResult* results;

// Test 1
inputs[1]->SetShape({1, 15});
FAIL_IF_SUCCESS(
this->client_->Infer(&results, options, inputs),
"expect error with inference request",
"input 'INPUT1' got unexpected byte size 64, expected 60");

// Check error message and verify the request reaches the server
// Test 2
inputs[0]->SetShape({2, 8});
inputs[1]->SetShape({2, 8});
// Assert the request reaches the server
FAIL_IF_SUCCESS(
this->client_->Infer(&results, options, inputs),
"expect error with inference request",
"input 'INPUT0' batch size does not match other inputs for 'simple'");
"unexpected shape for input 'INPUT1' for model 'simple'. Expected "
"[-1,16], got [2,8]");
}

TYPED_TEST_P(ClientInputTest, SetSharedMemory)
Expand Down Expand Up @@ -198,22 +203,43 @@ TYPED_TEST_P(ClientInputTest, SetSharedMemory)
options.model_version_ = "";

std::vector<tc::InferInput*> inputs = {input0_ptr.get(), input1_ptr.get()};
inputs[1]->SetShape({1, 15});

tc::InferResult* results;

// Test 1
inputs[1]->SetShape({1, 15});
FAIL_IF_SUCCESS(
this->client_->Infer(&results, options, inputs),
"expect error with inference request",
("input 'INPUT1' got unexpected byte size " +
std::to_string(input_byte_size) + ", expected " +
std::to_string(input_byte_size - sizeof(int))));

// Test 2
inputs[0]->SetShape({2, 8});
inputs[1]->SetShape({2, 8});
// Assert the request reaches the server
FAIL_IF_SUCCESS(
this->client_->Infer(&results, options, inputs),
"expect error with inference request",
"unexpected shape for input 'INPUT1' for model 'simple'. Expected "
"[-1,16], got [2,8]");

// Get shared memory regions active/registered within triton
// std::string shm_status;
// FAIL_IF_ERR(
// this->client_->SystemSharedMemoryStatus(&shm_status),
// "failed to get shared memory status");
// std::cout << "Shared Memory Status:\n" << shm_status << "\n";
using ClientType = TypeParam;
if constexpr (std::is_same<
ClientType, tc::InferenceServerGrpcClient>::value) {
inference::SystemSharedMemoryStatusResponse shm_status;
FAIL_IF_ERR(
this->client_->SystemSharedMemoryStatus(&shm_status),
"failed to get shared memory status");
std::cout << "Shared Memory Status:\n" << shm_status.DebugString() << "\n";
} else {
std::string shm_status;
FAIL_IF_ERR(
this->client_->SystemSharedMemoryStatus(&shm_status),
"failed to get shared memory status");
std::cout << "Shared Memory Status:\n" << shm_status << "\n";
}

// Unregister shared memory
FAIL_IF_ERR(
Expand All @@ -225,65 +251,7 @@ TYPED_TEST_P(ClientInputTest, SetSharedMemory)
FAIL_IF_ERR(tc::UnlinkSharedMemoryRegion("/input_simple"), "");
}

// Exercises BYTES (string) inputs against the 'simple_string' model: 16
// strings are appended to each input, then INPUT1's shape is changed so it no
// longer matches the data, and Infer() is expected to fail in both cases.
// NOTE(review): FAIL_IF_ERR / FAIL_IF_SUCCESS are defined outside this view;
// presumably FAIL_IF_SUCCESS also checks the returned error against the last
// string argument -- confirm against the macro definition.
TYPED_TEST_P(ClientInputTest, AppendString)
{
// Create the data for the two input tensors. Initialize the first
// to unique integers and the second to all ones. The input tensors
// are the string representation of these values.
std::vector<std::string> input0_data(16);
std::vector<std::string> input1_data(16);
for (size_t i = 0; i < 16; ++i) {
input0_data[i] = std::to_string(i);
input1_data[i] = std::to_string(1);
}

// Initial shape: 1 x 16, i.e. 16 elements, matching the data above.
std::vector<int64_t> shape{1, 16};

// Initialize the inputs with the data.
tc::InferInput* input0;
tc::InferInput* input1;

// The raw pointers returned by Create() are handed to shared_ptrs so they
// are released when the test exits.
FAIL_IF_ERR(
tc::InferInput::Create(&input0, "INPUT0", shape, "BYTES"),
"unable to get INPUT0");
std::shared_ptr<tc::InferInput> input0_ptr;
input0_ptr.reset(input0);
FAIL_IF_ERR(
tc::InferInput::Create(&input1, "INPUT1", shape, "BYTES"),
"unable to get INPUT1");
std::shared_ptr<tc::InferInput> input1_ptr;
input1_ptr.reset(input1);

FAIL_IF_ERR(
input0_ptr->AppendFromString(input0_data),
"unable to set data for INPUT0");
FAIL_IF_ERR(
input1_ptr->AppendFromString(input1_data),
"unable to set data for INPUT1");

// The inference settings. Will be using default for now.
tc::InferOptions options("simple_string");
options.model_version_ = "";

std::vector<tc::InferInput*> inputs = {input0_ptr.get(), input1_ptr.get()};
tc::InferResult* results;
// Case 1: shape {1, 15} describes 15 elements while 16 strings were
// appended, so the element-count mismatch error is expected.
input1_ptr->SetShape({1, 15});
FAIL_IF_SUCCESS(
this->client_->Infer(&results, options, inputs),
"expect error with inference request",
"input 'INPUT1' got unexpected elements count 16, expected 15");

// Check error message and verify the request reaches the server
// Case 2: shape {2, 8} still describes 16 elements, but INPUT0 keeps shape
// {1, 16}; the expected error is the batch size mismatch below.
inputs[1]->SetShape({2, 8});
FAIL_IF_SUCCESS(
this->client_->Infer(&results, options, inputs),
"expect error with inference request",
"input 'INPUT0' batch size does not match other inputs for "
"'simple_string'");
}

REGISTER_TYPED_TEST_SUITE_P(
ClientInputTest, AppendRaw, SetSharedMemory, AppendString);
REGISTER_TYPED_TEST_SUITE_P(ClientInputTest, AppendRaw, SetSharedMemory);

INSTANTIATE_TYPED_TEST_SUITE_P(
GRPC, ClientInputTest, tc::InferenceServerGrpcClient);
Expand Down

0 comments on commit 07059a6

Please sign in to comment.