From 712d86977507708219f36538c39540cb946bf458 Mon Sep 17 00:00:00 2001 From: Markus Hennerbichler Date: Mon, 4 Sep 2023 10:17:31 +0100 Subject: [PATCH 1/6] Only store channel if use_cached_channel is true --- src/c++/library/grpc_client.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/c++/library/grpc_client.cc b/src/c++/library/grpc_client.cc index fe91f5c17..6e67f859e 100644 --- a/src/c++/library/grpc_client.cc +++ b/src/c++/library/grpc_client.cc @@ -135,12 +135,15 @@ GetStub( grpc::CreateCustomChannel(url, credentials, arguments); std::shared_ptr stub = inference::GRPCInferenceService::NewStub(channel); - // Replace if channel / stub have been in the map - if (channel_itr != grpc_channel_stub_map_.end()) { - channel_itr->second = std::make_tuple(1, channel, stub); - } else { - grpc_channel_stub_map_.insert( - std::make_pair(url, std::make_tuple(1, channel, stub))); + + if (use_cached_channel) { + // Replace if channel / stub have been in the map + if (channel_itr != grpc_channel_stub_map_.end()) { + channel_itr->second = std::make_tuple(1, channel, stub); + } else { + grpc_channel_stub_map_.insert( + std::make_pair(url, std::make_tuple(1, channel, stub))); + } } return stub; From 9baaaab172720eb6b69322a8f5340844c23c3941 Mon Sep 17 00:00:00 2001 From: Elias Bermudez Date: Thu, 1 Feb 2024 12:36:11 -0800 Subject: [PATCH 2/6] Add testing to verify use_cached_channel false setting --- src/c++/tests/cc_client_test.cc | 146 ++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/src/c++/tests/cc_client_test.cc b/src/c++/tests/cc_client_test.cc index df5981cfb..5ac538cdb 100644 --- a/src/c++/tests/cc_client_test.cc +++ b/src/c++/tests/cc_client_test.cc @@ -134,6 +134,100 @@ class ClientTest : public ::testing::Test { std::string dtype_; }; +class GRPCClientTest : public ::testing::Test { + public: + GRPCClientTest() + : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32") + { + } + + void 
SetUp() override + { + std::string url = "localhost:8001"; + bool verbose = false; + bool use_ssl = false; + const tc::SslOptions& ssl_options = tc::SslOptions(); + const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(); + bool use_cached_channel = false; + auto err = tc::InferenceServerGrpcClient::Create( + &this->client_, url, verbose, use_ssl, ssl_options, keepalive_options, + use_cached_channel); + ASSERT_TRUE(err.IsOk()) + << "failed to create GRPC client: " << err.Message(); + + // Initialize 3 sets of inputs, each with 16 elements + for (size_t i = 0; i < 3; ++i) { + this->input_data_.emplace_back(); + for (size_t j = 0; j < 16; ++j) { + this->input_data_.back().emplace_back(i * 16 + j); + } + } + } + + tc::Error PrepareInputs( + const std::vector& input_0, const std::vector& input_1, + std::vector* inputs) + { + inputs->emplace_back(); + auto err = tc::InferInput::Create( + &inputs->back(), "INPUT0", this->shape_, this->dtype_); + if (!err.IsOk()) { + return err; + } + err = inputs->back()->AppendRaw( + reinterpret_cast(input_0.data()), + input_0.size() * sizeof(int32_t)); + if (!err.IsOk()) { + return err; + } + inputs->emplace_back(); + err = tc::InferInput::Create( + &inputs->back(), "INPUT1", this->shape_, this->dtype_); + if (!err.IsOk()) { + return err; + } + err = inputs->back()->AppendRaw( + reinterpret_cast(input_1.data()), + input_1.size() * sizeof(int32_t)); + if (!err.IsOk()) { + return err; + } + return tc::Error::Success; + } + + void ValidateOutput( + const std::vector& results, + const std::vector>>& + expected_outputs) + { + ASSERT_EQ(results.size(), expected_outputs.size()) + << "unexpected number of results"; + for (size_t i = 0; i < results.size(); ++i) { + ASSERT_TRUE(results[i]->RequestStatus().IsOk()); + for (const auto& expected : expected_outputs[i]) { + const uint8_t* buf = nullptr; + size_t byte_size = 0; + auto err = results[i]->RawData(expected.first, &buf, &byte_size); + ASSERT_TRUE(err.IsOk()) + << "failed to 
retrieve output '" << expected.first + << "' for result " << i << ": " << err.Message(); + ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); + EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); + } + } + } + + tc::Error LoadModel( + const std::string& model_name, const std::string& config, + const std::map>& files = {}); + + std::string model_name_; + std::unique_ptr client_; + std::vector> input_data_; + std::vector shape_; + std::string dtype_; +}; + template <> tc::Error ClientTest::LoadModel( @@ -1615,6 +1709,58 @@ TEST_F(GRPCTraceTest, GRPCClearTraceSettings) << std::endl; } +TEST_F(GRPCClientTest, InferMultiNoUseCachedChannel) +{ + // Create only 1 sets of 'options'. + tc::Error err = tc::Error::Success; + std::vector options; + std::vector> inputs; + std::vector> outputs; + + std::vector>> expected_outputs; + options.emplace_back(this->model_name_); + // Not swap + options.back().model_version_ = "1"; + for (size_t i = 0; i < 3; ++i) { + const auto& input_0 = this->input_data_[i % this->input_data_.size()]; + const auto& input_1 = this->input_data_[(i + 1) % this->input_data_.size()]; + inputs.emplace_back(); + err = this->PrepareInputs(input_0, input_1, &inputs.back()); + + tc::InferRequestedOutput* output; + outputs.emplace_back(); + err = tc::InferRequestedOutput::Create(&output, "OUTPUT0"); + ASSERT_TRUE(err.IsOk()) + << "failed to create inference output: " << err.Message(); + outputs.back().emplace_back(output); + err = tc::InferRequestedOutput::Create(&output, "OUTPUT1"); + ASSERT_TRUE(err.IsOk()) + << "failed to create inference output: " << err.Message(); + outputs.back().emplace_back(output); + + expected_outputs.emplace_back(); + { + auto& expected = expected_outputs.back()["OUTPUT0"]; + for (size_t i = 0; i < 16; ++i) { + expected.emplace_back(input_0[i] + input_1[i]); + } + } + { + auto& expected = expected_outputs.back()["OUTPUT1"]; + for (size_t i = 0; i < 16; ++i) { + expected.emplace_back(input_0[i] - input_1[i]); 
+ } + } + } + + std::vector results; + err = this->client_->InferMulti(&results, options, inputs, outputs); + ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " + << err.Message(); + + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); +} + class TestHttpInferRequest : public tc::HttpInferRequest { public: tc::Error ConvertBinaryInputsToJSON( From d22208d875f5171b70a5aad2d7606a1154e39431 Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Mon, 15 Apr 2024 14:24:59 -0700 Subject: [PATCH 3/6] working build with suggestions --- src/c++/tests/cc_client_test.cc | 294 +++++++++++++------------------- 1 file changed, 122 insertions(+), 172 deletions(-) diff --git a/src/c++/tests/cc_client_test.cc b/src/c++/tests/cc_client_test.cc index 5ac538cdb..3475de803 100644 --- a/src/c++/tests/cc_client_test.cc +++ b/src/c++/tests/cc_client_test.cc @@ -42,24 +42,28 @@ template class ClientTest : public ::testing::Test { public: ClientTest() - : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32") + : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32"), + client_type_("") { } void SetUp() override { - std::string url; + std::string url = ""; + std::string client_type = ""; if (std::is_same::value) { url = "localhost:8001"; + this->client_type_ = "grpc"; } else if (std::is_same::value) { url = "localhost:8000"; + this->client_type_ = "http"; } else { ASSERT_TRUE(false) << "Unrecognized client class type '" << typeid(ClientType).name() << "'"; } auto err = ClientType::Create(&this->client_, url); - ASSERT_TRUE(err.IsOk()) - << "failed to create GRPC client: " << err.Message(); + ASSERT_TRUE(err.IsOk()) << "failed to create " << this->client_type_ + << " client: " << err.Message(); // Initialize 3 sets of inputs, each with 16 elements for (size_t i = 0; i < 3; ++i) { @@ -102,100 +106,21 @@ class ClientTest : public ::testing::Test { } void ValidateOutput( - const std::vector& results, - const std::vector>>& - 
expected_outputs) + const tc::InferResult* result, + const std::map>& expected_output) { - ASSERT_EQ(results.size(), expected_outputs.size()) - << "unexpected number of results"; - for (size_t i = 0; i < results.size(); ++i) { - ASSERT_TRUE(results[i]->RequestStatus().IsOk()); - for (const auto& expected : expected_outputs[i]) { - const uint8_t* buf = nullptr; - size_t byte_size = 0; - auto err = results[i]->RawData(expected.first, &buf, &byte_size); - ASSERT_TRUE(err.IsOk()) - << "failed to retrieve output '" << expected.first - << "' for result " << i << ": " << err.Message(); - ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); - EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); - } + ASSERT_TRUE(result->RequestStatus().IsOk()); + for (const auto& expected : expected_output) { + const uint8_t* buf = nullptr; + size_t byte_size = 0; + auto err = result->RawData(expected.first, &buf, &byte_size); + ASSERT_TRUE(err.IsOk()) << "failed to retrieve output '" << expected.first + << "' for result: " << err.Message(); + ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); + EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); } } - - tc::Error LoadModel( - const std::string& model_name, const std::string& config, - const std::map>& files = {}); - - std::string model_name_; - std::unique_ptr client_; - std::vector> input_data_; - std::vector shape_; - std::string dtype_; -}; - -class GRPCClientTest : public ::testing::Test { - public: - GRPCClientTest() - : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32") - { - } - - void SetUp() override - { - std::string url = "localhost:8001"; - bool verbose = false; - bool use_ssl = false; - const tc::SslOptions& ssl_options = tc::SslOptions(); - const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(); - bool use_cached_channel = false; - auto err = tc::InferenceServerGrpcClient::Create( - &this->client_, url, verbose, use_ssl, ssl_options, 
keepalive_options, - use_cached_channel); - ASSERT_TRUE(err.IsOk()) - << "failed to create GRPC client: " << err.Message(); - - // Initialize 3 sets of inputs, each with 16 elements - for (size_t i = 0; i < 3; ++i) { - this->input_data_.emplace_back(); - for (size_t j = 0; j < 16; ++j) { - this->input_data_.back().emplace_back(i * 16 + j); - } - } - } - - tc::Error PrepareInputs( - const std::vector& input_0, const std::vector& input_1, - std::vector* inputs) - { - inputs->emplace_back(); - auto err = tc::InferInput::Create( - &inputs->back(), "INPUT0", this->shape_, this->dtype_); - if (!err.IsOk()) { - return err; - } - err = inputs->back()->AppendRaw( - reinterpret_cast(input_0.data()), - input_0.size() * sizeof(int32_t)); - if (!err.IsOk()) { - return err; - } - inputs->emplace_back(); - err = tc::InferInput::Create( - &inputs->back(), "INPUT1", this->shape_, this->dtype_); - if (!err.IsOk()) { - return err; - } - err = inputs->back()->AppendRaw( - reinterpret_cast(input_1.data()), - input_1.size() * sizeof(int32_t)); - if (!err.IsOk()) { - return err; - } - return tc::Error::Success; - } - - void ValidateOutput( + void ValidateOutputs( const std::vector& results, const std::vector>>& expected_outputs) @@ -203,17 +128,7 @@ class GRPCClientTest : public ::testing::Test { ASSERT_EQ(results.size(), expected_outputs.size()) << "unexpected number of results"; for (size_t i = 0; i < results.size(); ++i) { - ASSERT_TRUE(results[i]->RequestStatus().IsOk()); - for (const auto& expected : expected_outputs[i]) { - const uint8_t* buf = nullptr; - size_t byte_size = 0; - auto err = results[i]->RawData(expected.first, &buf, &byte_size); - ASSERT_TRUE(err.IsOk()) - << "failed to retrieve output '" << expected.first - << "' for result " << i << ": " << err.Message(); - ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); - EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); - } + ValidateOutput(results[i], expected_outputs[i]); } } @@ -221,13 +136,43 @@ 
class GRPCClientTest : public ::testing::Test { const std::string& model_name, const std::string& config, const std::map>& files = {}); + tc::Error CreateClientWithSpecifications( + const std::string& server_url, bool verbose = false, bool use_ssl = false, + const tc::SslOptions& ssl_options = tc::SslOptions(), + const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(), + const bool use_cached_channel = true); + std::string model_name_; - std::unique_ptr client_; + std::unique_ptr client_; std::vector> input_data_; std::vector shape_; std::string dtype_; + std::string client_type_; }; +template <> +tc::Error +ClientTest::CreateClientWithSpecifications( + const std::string& server_url, bool verbose, bool use_ssl, + const tc::SslOptions& ssl_options, + const tc::KeepAliveOptions& keepalive_options, + const bool use_cached_channel) +{ + return this->client_->Create( + &this->client_, server_url, verbose, use_ssl, ssl_options, + keepalive_options, use_cached_channel); +} + +template <> +tc::Error +ClientTest::CreateClientWithSpecifications( + const std::string& server_url, bool verbose, bool use_ssl, + const tc::SslOptions& ssl_options, + const tc::KeepAliveOptions& keepalive_option, const bool use_cached_channel) +{ + return this->client_->Create(&this->client_, server_url); +} + template <> tc::Error ClientTest::LoadModel( @@ -445,7 +390,7 @@ TYPED_TEST_P(ClientTest, InferMulti) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiDifferentOutputs) @@ -510,7 +455,7 @@ TYPED_TEST_P(ClientTest, InferMultiDifferentOutputs) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + 
EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiDifferentOptions) @@ -572,7 +517,7 @@ TYPED_TEST_P(ClientTest, InferMultiDifferentOptions) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiOneOption) @@ -624,7 +569,7 @@ TYPED_TEST_P(ClientTest, InferMultiOneOption) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiOneOutput) @@ -677,7 +622,7 @@ TYPED_TEST_P(ClientTest, InferMultiOneOutput) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiNoOutput) @@ -725,7 +670,7 @@ TYPED_TEST_P(ClientTest, InferMultiNoOutput) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiMismatchOptions) @@ -862,7 +807,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMulti) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOutputs) @@ -939,7 +884,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOutputs) 
std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOptions) @@ -1013,7 +958,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOptions) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiOneOption) @@ -1077,7 +1022,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiOneOption) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiOneOutput) @@ -1142,7 +1087,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiOneOutput) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiNoOutput) @@ -1202,7 +1147,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiNoOutput) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiMismatchOptions) @@ -1443,6 +1388,63 @@ TYPED_TEST_P(ClientTest, LoadWithConfigOverride) } } +TYPED_TEST_P(ClientTest, InferNoUseCachedChannel) +{ + if (this->client_type_ == "grpc") { + std::string url = "localhost:8001"; + bool verbose = false; + 
bool use_ssl = false; + const tc::SslOptions& ssl_options = tc::SslOptions(); + const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(); + bool use_cached_channel = false; + auto err = this->CreateClientWithSpecifications( + url, verbose, use_ssl, ssl_options, keepalive_options, + use_cached_channel); + ASSERT_TRUE(err.IsOk()) + << "failed to create GRPC client: " << err.Message(); + + tc::InferOptions option(this->model_name_); + std::vector inputs; + std::vector outputs; + + std::map> expected_outputs; + // Not swap + option.model_version_ = "1"; + const auto& input_0 = this->input_data_[0]; + const auto& input_1 = this->input_data_[1]; + err = this->PrepareInputs(input_0, input_1, &inputs); + + tc::InferRequestedOutput* output; + err = tc::InferRequestedOutput::Create(&output, "OUTPUT0"); + ASSERT_TRUE(err.IsOk()) + << "failed to create inference output: " << err.Message(); + outputs.emplace_back(output); + err = tc::InferRequestedOutput::Create(&output, "OUTPUT1"); + ASSERT_TRUE(err.IsOk()) + << "failed to create inference output: " << err.Message(); + outputs.emplace_back(output); + + { + auto& expected = expected_outputs["OUTPUT0"]; + for (size_t i = 0; i < 16; ++i) { + expected.emplace_back(input_0[i] + input_1[i]); + } + } + { + auto& expected = expected_outputs["OUTPUT1"]; + for (size_t i = 0; i < 16; ++i) { + expected.emplace_back(input_0[i] - input_1[i]); + } + } + + tc::InferResult* result; + err = this->client_->Infer(&result, option, inputs, outputs); + ASSERT_TRUE(err.IsOk()) + << "failed to perform multiple inferences: " << err.Message(); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(result, expected_outputs)); + } +} + TEST_F(HTTPTraceTest, HTTPUpdateTraceSettings) { // Update model and global trace settings in order, and expect the global @@ -1709,58 +1711,6 @@ TEST_F(GRPCTraceTest, GRPCClearTraceSettings) << std::endl; } -TEST_F(GRPCClientTest, InferMultiNoUseCachedChannel) -{ - // Create only 1 sets of 'options'. 
- tc::Error err = tc::Error::Success; - std::vector options; - std::vector> inputs; - std::vector> outputs; - - std::vector>> expected_outputs; - options.emplace_back(this->model_name_); - // Not swap - options.back().model_version_ = "1"; - for (size_t i = 0; i < 3; ++i) { - const auto& input_0 = this->input_data_[i % this->input_data_.size()]; - const auto& input_1 = this->input_data_[(i + 1) % this->input_data_.size()]; - inputs.emplace_back(); - err = this->PrepareInputs(input_0, input_1, &inputs.back()); - - tc::InferRequestedOutput* output; - outputs.emplace_back(); - err = tc::InferRequestedOutput::Create(&output, "OUTPUT0"); - ASSERT_TRUE(err.IsOk()) - << "failed to create inference output: " << err.Message(); - outputs.back().emplace_back(output); - err = tc::InferRequestedOutput::Create(&output, "OUTPUT1"); - ASSERT_TRUE(err.IsOk()) - << "failed to create inference output: " << err.Message(); - outputs.back().emplace_back(output); - - expected_outputs.emplace_back(); - { - auto& expected = expected_outputs.back()["OUTPUT0"]; - for (size_t i = 0; i < 16; ++i) { - expected.emplace_back(input_0[i] + input_1[i]); - } - } - { - auto& expected = expected_outputs.back()["OUTPUT1"]; - for (size_t i = 0; i < 16; ++i) { - expected.emplace_back(input_0[i] - input_1[i]); - } - } - } - - std::vector results; - err = this->client_->InferMulti(&results, options, inputs, outputs); - ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " - << err.Message(); - - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); -} - class TestHttpInferRequest : public tc::HttpInferRequest { public: tc::Error ConvertBinaryInputsToJSON( @@ -2301,7 +2251,7 @@ REGISTER_TYPED_TEST_SUITE_P( AsyncInferMultiDifferentOptions, AsyncInferMultiOneOption, AsyncInferMultiOneOutput, AsyncInferMultiNoOutput, AsyncInferMultiMismatchOptions, AsyncInferMultiMismatchOutputs, - LoadWithFileOverride, LoadWithConfigOverride); + LoadWithFileOverride, LoadWithConfigOverride, 
InferNoUseCachedChannel); INSTANTIATE_TYPED_TEST_SUITE_P(GRPC, ClientTest, tc::InferenceServerGrpcClient); INSTANTIATE_TYPED_TEST_SUITE_P(HTTP, ClientTest, tc::InferenceServerHttpClient); From 0080426abc7b998bc718eef4ca37fd073cbd651a Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Mon, 15 Apr 2024 15:43:02 -0700 Subject: [PATCH 4/6] Revert "working build with suggestions" This reverts commit d22208d875f5171b70a5aad2d7606a1154e39431. --- src/c++/tests/cc_client_test.cc | 294 +++++++++++++++++++------------- 1 file changed, 172 insertions(+), 122 deletions(-) diff --git a/src/c++/tests/cc_client_test.cc b/src/c++/tests/cc_client_test.cc index 3475de803..5ac538cdb 100644 --- a/src/c++/tests/cc_client_test.cc +++ b/src/c++/tests/cc_client_test.cc @@ -42,28 +42,24 @@ template class ClientTest : public ::testing::Test { public: ClientTest() - : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32"), - client_type_("") + : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32") { } void SetUp() override { - std::string url = ""; - std::string client_type = ""; + std::string url; if (std::is_same::value) { url = "localhost:8001"; - this->client_type_ = "grpc"; } else if (std::is_same::value) { url = "localhost:8000"; - this->client_type_ = "http"; } else { ASSERT_TRUE(false) << "Unrecognized client class type '" << typeid(ClientType).name() << "'"; } auto err = ClientType::Create(&this->client_, url); - ASSERT_TRUE(err.IsOk()) << "failed to create " << this->client_type_ - << " client: " << err.Message(); + ASSERT_TRUE(err.IsOk()) + << "failed to create GRPC client: " << err.Message(); // Initialize 3 sets of inputs, each with 16 elements for (size_t i = 0; i < 3; ++i) { @@ -106,21 +102,6 @@ class ClientTest : public ::testing::Test { } void ValidateOutput( - const tc::InferResult* result, - const std::map>& expected_output) - { - ASSERT_TRUE(result->RequestStatus().IsOk()); - for (const auto& expected : expected_output) { - const uint8_t* 
buf = nullptr; - size_t byte_size = 0; - auto err = result->RawData(expected.first, &buf, &byte_size); - ASSERT_TRUE(err.IsOk()) << "failed to retrieve output '" << expected.first - << "' for result: " << err.Message(); - ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); - EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); - } - } - void ValidateOutputs( const std::vector& results, const std::vector>>& expected_outputs) @@ -128,7 +109,17 @@ class ClientTest : public ::testing::Test { ASSERT_EQ(results.size(), expected_outputs.size()) << "unexpected number of results"; for (size_t i = 0; i < results.size(); ++i) { - ValidateOutput(results[i], expected_outputs[i]); + ASSERT_TRUE(results[i]->RequestStatus().IsOk()); + for (const auto& expected : expected_outputs[i]) { + const uint8_t* buf = nullptr; + size_t byte_size = 0; + auto err = results[i]->RawData(expected.first, &buf, &byte_size); + ASSERT_TRUE(err.IsOk()) + << "failed to retrieve output '" << expected.first + << "' for result " << i << ": " << err.Message(); + ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); + EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); + } } } @@ -136,42 +127,106 @@ class ClientTest : public ::testing::Test { const std::string& model_name, const std::string& config, const std::map>& files = {}); - tc::Error CreateClientWithSpecifications( - const std::string& server_url, bool verbose = false, bool use_ssl = false, - const tc::SslOptions& ssl_options = tc::SslOptions(), - const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(), - const bool use_cached_channel = true); - std::string model_name_; std::unique_ptr client_; std::vector> input_data_; std::vector shape_; std::string dtype_; - std::string client_type_; }; -template <> -tc::Error -ClientTest::CreateClientWithSpecifications( - const std::string& server_url, bool verbose, bool use_ssl, - const tc::SslOptions& ssl_options, - const tc::KeepAliveOptions& 
keepalive_options, - const bool use_cached_channel) -{ - return this->client_->Create( - &this->client_, server_url, verbose, use_ssl, ssl_options, - keepalive_options, use_cached_channel); -} +class GRPCClientTest : public ::testing::Test { + public: + GRPCClientTest() + : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32") + { + } -template <> -tc::Error -ClientTest::CreateClientWithSpecifications( - const std::string& server_url, bool verbose, bool use_ssl, - const tc::SslOptions& ssl_options, - const tc::KeepAliveOptions& keepalive_option, const bool use_cached_channel) -{ - return this->client_->Create(&this->client_, server_url); -} + void SetUp() override + { + std::string url = "localhost:8001"; + bool verbose = false; + bool use_ssl = false; + const tc::SslOptions& ssl_options = tc::SslOptions(); + const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(); + bool use_cached_channel = false; + auto err = tc::InferenceServerGrpcClient::Create( + &this->client_, url, verbose, use_ssl, ssl_options, keepalive_options, + use_cached_channel); + ASSERT_TRUE(err.IsOk()) + << "failed to create GRPC client: " << err.Message(); + + // Initialize 3 sets of inputs, each with 16 elements + for (size_t i = 0; i < 3; ++i) { + this->input_data_.emplace_back(); + for (size_t j = 0; j < 16; ++j) { + this->input_data_.back().emplace_back(i * 16 + j); + } + } + } + + tc::Error PrepareInputs( + const std::vector& input_0, const std::vector& input_1, + std::vector* inputs) + { + inputs->emplace_back(); + auto err = tc::InferInput::Create( + &inputs->back(), "INPUT0", this->shape_, this->dtype_); + if (!err.IsOk()) { + return err; + } + err = inputs->back()->AppendRaw( + reinterpret_cast(input_0.data()), + input_0.size() * sizeof(int32_t)); + if (!err.IsOk()) { + return err; + } + inputs->emplace_back(); + err = tc::InferInput::Create( + &inputs->back(), "INPUT1", this->shape_, this->dtype_); + if (!err.IsOk()) { + return err; + } + err = 
inputs->back()->AppendRaw( + reinterpret_cast(input_1.data()), + input_1.size() * sizeof(int32_t)); + if (!err.IsOk()) { + return err; + } + return tc::Error::Success; + } + + void ValidateOutput( + const std::vector& results, + const std::vector>>& + expected_outputs) + { + ASSERT_EQ(results.size(), expected_outputs.size()) + << "unexpected number of results"; + for (size_t i = 0; i < results.size(); ++i) { + ASSERT_TRUE(results[i]->RequestStatus().IsOk()); + for (const auto& expected : expected_outputs[i]) { + const uint8_t* buf = nullptr; + size_t byte_size = 0; + auto err = results[i]->RawData(expected.first, &buf, &byte_size); + ASSERT_TRUE(err.IsOk()) + << "failed to retrieve output '" << expected.first + << "' for result " << i << ": " << err.Message(); + ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); + EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); + } + } + } + + tc::Error LoadModel( + const std::string& model_name, const std::string& config, + const std::map>& files = {}); + + std::string model_name_; + std::unique_ptr client_; + std::vector> input_data_; + std::vector shape_; + std::string dtype_; +}; template <> tc::Error @@ -390,7 +445,7 @@ TYPED_TEST_P(ClientTest, InferMulti) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiDifferentOutputs) @@ -455,7 +510,7 @@ TYPED_TEST_P(ClientTest, InferMultiDifferentOutputs) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiDifferentOptions) @@ -517,7 +572,7 @@ TYPED_TEST_P(ClientTest, InferMultiDifferentOptions) 
ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiOneOption) @@ -569,7 +624,7 @@ TYPED_TEST_P(ClientTest, InferMultiOneOption) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiOneOutput) @@ -622,7 +677,7 @@ TYPED_TEST_P(ClientTest, InferMultiOneOutput) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiNoOutput) @@ -670,7 +725,7 @@ TYPED_TEST_P(ClientTest, InferMultiNoOutput) ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, InferMultiMismatchOptions) @@ -807,7 +862,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMulti) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOutputs) @@ -884,7 +939,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOutputs) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + 
EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOptions) @@ -958,7 +1013,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiDifferentOptions) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiOneOption) @@ -1022,7 +1077,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiOneOption) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiOneOutput) @@ -1087,7 +1142,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiOneOutput) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiNoOutput) @@ -1147,7 +1202,7 @@ TYPED_TEST_P(ClientTest, AsyncInferMultiNoOutput) std::unique_lock lk(mu); cv.wait(lk, [this, &results] { return !results.empty(); }); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutputs(results, expected_outputs)); + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); } TYPED_TEST_P(ClientTest, AsyncInferMultiMismatchOptions) @@ -1388,63 +1443,6 @@ TYPED_TEST_P(ClientTest, LoadWithConfigOverride) } } -TYPED_TEST_P(ClientTest, InferNoUseCachedChannel) -{ - if (this->client_type_ == "grpc") { - std::string url = "localhost:8001"; - bool verbose = false; - bool use_ssl = false; - const tc::SslOptions& ssl_options = tc::SslOptions(); - const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(); - bool 
use_cached_channel = false; - auto err = this->CreateClientWithSpecifications( - url, verbose, use_ssl, ssl_options, keepalive_options, - use_cached_channel); - ASSERT_TRUE(err.IsOk()) - << "failed to create GRPC client: " << err.Message(); - - tc::InferOptions option(this->model_name_); - std::vector inputs; - std::vector outputs; - - std::map> expected_outputs; - // Not swap - option.model_version_ = "1"; - const auto& input_0 = this->input_data_[0]; - const auto& input_1 = this->input_data_[1]; - err = this->PrepareInputs(input_0, input_1, &inputs); - - tc::InferRequestedOutput* output; - err = tc::InferRequestedOutput::Create(&output, "OUTPUT0"); - ASSERT_TRUE(err.IsOk()) - << "failed to create inference output: " << err.Message(); - outputs.emplace_back(output); - err = tc::InferRequestedOutput::Create(&output, "OUTPUT1"); - ASSERT_TRUE(err.IsOk()) - << "failed to create inference output: " << err.Message(); - outputs.emplace_back(output); - - { - auto& expected = expected_outputs["OUTPUT0"]; - for (size_t i = 0; i < 16; ++i) { - expected.emplace_back(input_0[i] + input_1[i]); - } - } - { - auto& expected = expected_outputs["OUTPUT1"]; - for (size_t i = 0; i < 16; ++i) { - expected.emplace_back(input_0[i] - input_1[i]); - } - } - - tc::InferResult* result; - err = this->client_->Infer(&result, option, inputs, outputs); - ASSERT_TRUE(err.IsOk()) - << "failed to perform multiple inferences: " << err.Message(); - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(result, expected_outputs)); - } -} - TEST_F(HTTPTraceTest, HTTPUpdateTraceSettings) { // Update model and global trace settings in order, and expect the global @@ -1711,6 +1709,58 @@ TEST_F(GRPCTraceTest, GRPCClearTraceSettings) << std::endl; } +TEST_F(GRPCClientTest, InferMultiNoUseCachedChannel) +{ + // Create only 1 sets of 'options'. 
+ tc::Error err = tc::Error::Success; + std::vector options; + std::vector> inputs; + std::vector> outputs; + + std::vector>> expected_outputs; + options.emplace_back(this->model_name_); + // Not swap + options.back().model_version_ = "1"; + for (size_t i = 0; i < 3; ++i) { + const auto& input_0 = this->input_data_[i % this->input_data_.size()]; + const auto& input_1 = this->input_data_[(i + 1) % this->input_data_.size()]; + inputs.emplace_back(); + err = this->PrepareInputs(input_0, input_1, &inputs.back()); + + tc::InferRequestedOutput* output; + outputs.emplace_back(); + err = tc::InferRequestedOutput::Create(&output, "OUTPUT0"); + ASSERT_TRUE(err.IsOk()) + << "failed to create inference output: " << err.Message(); + outputs.back().emplace_back(output); + err = tc::InferRequestedOutput::Create(&output, "OUTPUT1"); + ASSERT_TRUE(err.IsOk()) + << "failed to create inference output: " << err.Message(); + outputs.back().emplace_back(output); + + expected_outputs.emplace_back(); + { + auto& expected = expected_outputs.back()["OUTPUT0"]; + for (size_t i = 0; i < 16; ++i) { + expected.emplace_back(input_0[i] + input_1[i]); + } + } + { + auto& expected = expected_outputs.back()["OUTPUT1"]; + for (size_t i = 0; i < 16; ++i) { + expected.emplace_back(input_0[i] - input_1[i]); + } + } + } + + std::vector results; + err = this->client_->InferMulti(&results, options, inputs, outputs); + ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " + << err.Message(); + + EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); +} + class TestHttpInferRequest : public tc::HttpInferRequest { public: tc::Error ConvertBinaryInputsToJSON( @@ -2251,7 +2301,7 @@ REGISTER_TYPED_TEST_SUITE_P( AsyncInferMultiDifferentOptions, AsyncInferMultiOneOption, AsyncInferMultiOneOutput, AsyncInferMultiNoOutput, AsyncInferMultiMismatchOptions, AsyncInferMultiMismatchOutputs, - LoadWithFileOverride, LoadWithConfigOverride, InferNoUseCachedChannel); + LoadWithFileOverride, 
LoadWithConfigOverride); INSTANTIATE_TYPED_TEST_SUITE_P(GRPC, ClientTest, tc::InferenceServerGrpcClient); INSTANTIATE_TYPED_TEST_SUITE_P(HTTP, ClientTest, tc::InferenceServerHttpClient); From 9630af8610895ef808351f114e053c4f5d42b2ac Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Mon, 15 Apr 2024 15:43:18 -0700 Subject: [PATCH 5/6] Revert "Add testing to verify use_cached_channel false setting" This reverts commit 9baaaab172720eb6b69322a8f5340844c23c3941. --- src/c++/tests/cc_client_test.cc | 146 -------------------------------- 1 file changed, 146 deletions(-) diff --git a/src/c++/tests/cc_client_test.cc b/src/c++/tests/cc_client_test.cc index 5ac538cdb..df5981cfb 100644 --- a/src/c++/tests/cc_client_test.cc +++ b/src/c++/tests/cc_client_test.cc @@ -134,100 +134,6 @@ class ClientTest : public ::testing::Test { std::string dtype_; }; -class GRPCClientTest : public ::testing::Test { - public: - GRPCClientTest() - : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32") - { - } - - void SetUp() override - { - std::string url = "localhost:8001"; - bool verbose = false; - bool use_ssl = false; - const tc::SslOptions& ssl_options = tc::SslOptions(); - const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions(); - bool use_cached_channel = false; - auto err = tc::InferenceServerGrpcClient::Create( - &this->client_, url, verbose, use_ssl, ssl_options, keepalive_options, - use_cached_channel); - ASSERT_TRUE(err.IsOk()) - << "failed to create GRPC client: " << err.Message(); - - // Initialize 3 sets of inputs, each with 16 elements - for (size_t i = 0; i < 3; ++i) { - this->input_data_.emplace_back(); - for (size_t j = 0; j < 16; ++j) { - this->input_data_.back().emplace_back(i * 16 + j); - } - } - } - - tc::Error PrepareInputs( - const std::vector& input_0, const std::vector& input_1, - std::vector* inputs) - { - inputs->emplace_back(); - auto err = tc::InferInput::Create( - &inputs->back(), "INPUT0", this->shape_, this->dtype_); - if 
(!err.IsOk()) { - return err; - } - err = inputs->back()->AppendRaw( - reinterpret_cast(input_0.data()), - input_0.size() * sizeof(int32_t)); - if (!err.IsOk()) { - return err; - } - inputs->emplace_back(); - err = tc::InferInput::Create( - &inputs->back(), "INPUT1", this->shape_, this->dtype_); - if (!err.IsOk()) { - return err; - } - err = inputs->back()->AppendRaw( - reinterpret_cast(input_1.data()), - input_1.size() * sizeof(int32_t)); - if (!err.IsOk()) { - return err; - } - return tc::Error::Success; - } - - void ValidateOutput( - const std::vector& results, - const std::vector>>& - expected_outputs) - { - ASSERT_EQ(results.size(), expected_outputs.size()) - << "unexpected number of results"; - for (size_t i = 0; i < results.size(); ++i) { - ASSERT_TRUE(results[i]->RequestStatus().IsOk()); - for (const auto& expected : expected_outputs[i]) { - const uint8_t* buf = nullptr; - size_t byte_size = 0; - auto err = results[i]->RawData(expected.first, &buf, &byte_size); - ASSERT_TRUE(err.IsOk()) - << "failed to retrieve output '" << expected.first - << "' for result " << i << ": " << err.Message(); - ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t))); - EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0); - } - } - } - - tc::Error LoadModel( - const std::string& model_name, const std::string& config, - const std::map>& files = {}); - - std::string model_name_; - std::unique_ptr client_; - std::vector> input_data_; - std::vector shape_; - std::string dtype_; -}; - template <> tc::Error ClientTest::LoadModel( @@ -1709,58 +1615,6 @@ TEST_F(GRPCTraceTest, GRPCClearTraceSettings) << std::endl; } -TEST_F(GRPCClientTest, InferMultiNoUseCachedChannel) -{ - // Create only 1 sets of 'options'. 
- tc::Error err = tc::Error::Success; - std::vector options; - std::vector> inputs; - std::vector> outputs; - - std::vector>> expected_outputs; - options.emplace_back(this->model_name_); - // Not swap - options.back().model_version_ = "1"; - for (size_t i = 0; i < 3; ++i) { - const auto& input_0 = this->input_data_[i % this->input_data_.size()]; - const auto& input_1 = this->input_data_[(i + 1) % this->input_data_.size()]; - inputs.emplace_back(); - err = this->PrepareInputs(input_0, input_1, &inputs.back()); - - tc::InferRequestedOutput* output; - outputs.emplace_back(); - err = tc::InferRequestedOutput::Create(&output, "OUTPUT0"); - ASSERT_TRUE(err.IsOk()) - << "failed to create inference output: " << err.Message(); - outputs.back().emplace_back(output); - err = tc::InferRequestedOutput::Create(&output, "OUTPUT1"); - ASSERT_TRUE(err.IsOk()) - << "failed to create inference output: " << err.Message(); - outputs.back().emplace_back(output); - - expected_outputs.emplace_back(); - { - auto& expected = expected_outputs.back()["OUTPUT0"]; - for (size_t i = 0; i < 16; ++i) { - expected.emplace_back(input_0[i] + input_1[i]); - } - } - { - auto& expected = expected_outputs.back()["OUTPUT1"]; - for (size_t i = 0; i < 16; ++i) { - expected.emplace_back(input_0[i] - input_1[i]); - } - } - } - - std::vector results; - err = this->client_->InferMulti(&results, options, inputs, outputs); - ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: " - << err.Message(); - - EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs)); -} - class TestHttpInferRequest : public tc::HttpInferRequest { public: tc::Error ConvertBinaryInputsToJSON( From be5b7a1b6818c5a46072c1d7e0e4386746e0cc60 Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Mon, 15 Apr 2024 17:23:24 -0700 Subject: [PATCH 6/6] add updated tests --- src/c++/examples/simple_grpc_infer_client.cc | 4 ++++ src/c++/library/grpc_client.cc | 11 ++++++++++- src/c++/library/grpc_client.h | 3 +++ 3 files 
changed, 17 insertions(+), 1 deletion(-) diff --git a/src/c++/examples/simple_grpc_infer_client.cc b/src/c++/examples/simple_grpc_infer_client.cc index 5959d3cd6..13c5e8bc1 100644 --- a/src/c++/examples/simple_grpc_infer_client.cc +++ b/src/c++/examples/simple_grpc_infer_client.cc @@ -229,6 +229,10 @@ main(int argc, char** argv) use_cached_channel), err); + if (verbose) { + std::cout << "There are " << client->GetNumCachedChannels() + << " cached channels" << std::endl; + } // Create the data for the two input tensors. Initialize the first // to unique integers and the second to all ones. std::vector input0_data(16); diff --git a/src/c++/library/grpc_client.cc b/src/c++/library/grpc_client.cc index 6e67f859e..c9ee70125 100644 --- a/src/c++/library/grpc_client.cc +++ b/src/c++/library/grpc_client.cc @@ -94,7 +94,7 @@ GetStub( "TRITON_CLIENT_GRPC_CHANNEL_MAX_SHARE_COUNT", "6")); const auto& channel_itr = grpc_channel_stub_map_.find(url); // Reuse cached channel if the channel is found in the map and - // used_cached_channel flag is true + // use_cached_channel flag is true if ((channel_itr != grpc_channel_stub_map_.end()) && use_cached_channel) { // check if NewStub should be created const auto& shared_count = std::get<0>(channel_itr->second); @@ -136,6 +136,8 @@ GetStub( std::shared_ptr stub = inference::GRPCInferenceService::NewStub(channel); + // If `use_cached_channel` is false, new channels are created even if there + // are cached channels, and the new channel is never stored in the cache. 
if (use_cached_channel) { // Replace if channel / stub have been in the map if (channel_itr != grpc_channel_stub_map_.end()) { @@ -1706,6 +1708,13 @@ InferenceServerGrpcClient::~InferenceServerGrpcClient() StopStream(); } +size_t +InferenceServerGrpcClient::GetNumCachedChannels() const +{ + std::lock_guard lock(grpc_channel_stub_map_mtx_); + return grpc_channel_stub_map_.size(); +} + //============================================================================== }} // namespace triton::client diff --git a/src/c++/library/grpc_client.h b/src/c++/library/grpc_client.h index cc90b12de..f13336214 100644 --- a/src/c++/library/grpc_client.h +++ b/src/c++/library/grpc_client.h @@ -602,6 +602,9 @@ class InferenceServerGrpcClient : public InferenceServerClient { const std::vector& outputs = std::vector()); + // Number of Cached Channels + size_t GetNumCachedChannels() const; + private: InferenceServerGrpcClient( const std::string& url, bool verbose, bool use_ssl,