From 9baaaab172720eb6b69322a8f5340844c23c3941 Mon Sep 17 00:00:00 2001
From: Elias Bermudez
Date: Thu, 1 Feb 2024 12:36:11 -0800
Subject: [PATCH] Add testing to verify use_cached_channel false setting

---
 src/c++/tests/cc_client_test.cc | 146 ++++++++++++++++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/src/c++/tests/cc_client_test.cc b/src/c++/tests/cc_client_test.cc
index df5981cfb..5ac538cdb 100644
--- a/src/c++/tests/cc_client_test.cc
+++ b/src/c++/tests/cc_client_test.cc
@@ -134,6 +134,100 @@ class ClientTest : public ::testing::Test {
   std::string dtype_;
 };
 
+class GRPCClientTest : public ::testing::Test {
+ public:
+  GRPCClientTest()
+      : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32")
+  {
+  }
+
+  void SetUp() override
+  {
+    std::string url = "localhost:8001";
+    bool verbose = false;
+    bool use_ssl = false;
+    const tc::SslOptions& ssl_options = tc::SslOptions();
+    const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions();
+    bool use_cached_channel = false;
+    auto err = tc::InferenceServerGrpcClient::Create(
+        &this->client_, url, verbose, use_ssl, ssl_options, keepalive_options,
+        use_cached_channel);
+    ASSERT_TRUE(err.IsOk())
+        << "failed to create GRPC client: " << err.Message();
+
+    // Initialize 3 sets of inputs, each with 16 elements
+    for (size_t i = 0; i < 3; ++i) {
+      this->input_data_.emplace_back();
+      for (size_t j = 0; j < 16; ++j) {
+        this->input_data_.back().emplace_back(i * 16 + j);
+      }
+    }
+  }
+
+  tc::Error PrepareInputs(
+      const std::vector<int32_t>& input_0, const std::vector<int32_t>& input_1,
+      std::vector<tc::InferInput*>* inputs)
+  {
+    inputs->emplace_back();
+    auto err = tc::InferInput::Create(
+        &inputs->back(), "INPUT0", this->shape_, this->dtype_);
+    if (!err.IsOk()) {
+      return err;
+    }
+    err = inputs->back()->AppendRaw(
+        reinterpret_cast<const uint8_t*>(input_0.data()),
+        input_0.size() * sizeof(int32_t));
+    if (!err.IsOk()) {
+      return err;
+    }
+    inputs->emplace_back();
+    err = tc::InferInput::Create(
+        &inputs->back(), "INPUT1", this->shape_, this->dtype_);
+    if (!err.IsOk()) {
+      return err;
+    }
+    err = inputs->back()->AppendRaw(
+        reinterpret_cast<const uint8_t*>(input_1.data()),
+        input_1.size() * sizeof(int32_t));
+    if (!err.IsOk()) {
+      return err;
+    }
+    return tc::Error::Success;
+  }
+
+  void ValidateOutput(
+      const std::vector<tc::InferResult*>& results,
+      const std::vector<std::map<std::string, std::vector<int32_t>>>&
+          expected_outputs)
+  {
+    ASSERT_EQ(results.size(), expected_outputs.size())
+        << "unexpected number of results";
+    for (size_t i = 0; i < results.size(); ++i) {
+      ASSERT_TRUE(results[i]->RequestStatus().IsOk());
+      for (const auto& expected : expected_outputs[i]) {
+        const uint8_t* buf = nullptr;
+        size_t byte_size = 0;
+        auto err = results[i]->RawData(expected.first, &buf, &byte_size);
+        ASSERT_TRUE(err.IsOk())
+            << "failed to retrieve output '" << expected.first
+            << "' for result " << i << ": " << err.Message();
+        ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t)));
+        EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0);
+      }
+    }
+  }
+
+  tc::Error LoadModel(
+      const std::string& model_name, const std::string& config,
+      const std::map<std::string, std::vector<char>>& files = {});
+
+  std::string model_name_;
+  std::unique_ptr<tc::InferenceServerGrpcClient> client_;
+  std::vector<std::vector<int32_t>> input_data_;
+  std::vector<int64_t> shape_;
+  std::string dtype_;
+};
+
 template <>
 tc::Error
 ClientTest<tc::InferenceServerGrpcClient>::LoadModel(
@@ -1615,6 +1709,58 @@ TEST_F(GRPCTraceTest, GRPCClearTraceSettings)
       << std::endl;
 }
 
+TEST_F(GRPCClientTest, InferMultiNoUseCachedChannel)
+{
+  // Create only one set of 'options', shared by all three requests.
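+  // The client was created in SetUp() with 'use_cached_channel' set to
+  // false, so requests travel over a dedicated, freshly created gRPC
+  // channel rather than one reused from the client's shared channel cache;
+  // the InferMulti call below verifies inference still succeeds on that path.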
+  tc::Error err = tc::Error::Success;
+  std::vector<tc::InferOptions> options;
+  std::vector<std::vector<tc::InferInput*>> inputs;
+  std::vector<std::vector<const tc::InferRequestedOutput*>> outputs;
+
+  std::vector<std::map<std::string, std::vector<int32_t>>> expected_outputs;
+  options.emplace_back(this->model_name_);
+  // Use model version 1, which does not swap the outputs:
+  // OUTPUT0 = INPUT0 + INPUT1 and OUTPUT1 = INPUT0 - INPUT1.
+  options.back().model_version_ = "1";
+  for (size_t i = 0; i < 3; ++i) {
+    const auto& input_0 = this->input_data_[i % this->input_data_.size()];
+    const auto& input_1 = this->input_data_[(i + 1) % this->input_data_.size()];
+    inputs.emplace_back();
+    err = this->PrepareInputs(input_0, input_1, &inputs.back());
+    ASSERT_TRUE(err.IsOk())
+        << "failed to prepare inputs: " << err.Message();
+
+    tc::InferRequestedOutput* output;
+    outputs.emplace_back();
+    err = tc::InferRequestedOutput::Create(&output, "OUTPUT0");
+    ASSERT_TRUE(err.IsOk())
+        << "failed to create inference output: " << err.Message();
+    outputs.back().emplace_back(output);
+    err = tc::InferRequestedOutput::Create(&output, "OUTPUT1");
+    ASSERT_TRUE(err.IsOk())
+        << "failed to create inference output: " << err.Message();
+    outputs.back().emplace_back(output);
+
+    expected_outputs.emplace_back();
+    {
+      auto& expected = expected_outputs.back()["OUTPUT0"];
+      for (size_t j = 0; j < 16; ++j) {
+        expected.emplace_back(input_0[j] + input_1[j]);
+      }
+    }
+    {
+      auto& expected = expected_outputs.back()["OUTPUT1"];
+      for (size_t j = 0; j < 16; ++j) {
+        expected.emplace_back(input_0[j] - input_1[j]);
+      }
+    }
+  }
+
+  std::vector<tc::InferResult*> results;
+  err = this->client_->InferMulti(&results, options, inputs, outputs);
+  ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: "
+                          << err.Message();
+
+  EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs));
+}
+
 class TestHttpInferRequest : public tc::HttpInferRequest {
  public:
   tc::Error ConvertBinaryInputsToJSON(