Add testing to verify use_cached_channel false setting
debermudez committed Feb 1, 2024
1 parent 712d869 commit 9baaaab
Showing 1 changed file with 146 additions and 0 deletions.
src/c++/tests/cc_client_test.cc
@@ -134,6 +134,100 @@ class ClientTest : public ::testing::Test {
  std::string dtype_;
};

class GRPCClientTest : public ::testing::Test {
 public:
  GRPCClientTest()
      : model_name_("onnx_int32_int32_int32"), shape_{1, 16}, dtype_("INT32")
  {
  }

  void SetUp() override
  {
    std::string url = "localhost:8001";
    bool verbose = false;
    bool use_ssl = false;
    const tc::SslOptions& ssl_options = tc::SslOptions();
    const tc::KeepAliveOptions& keepalive_options = tc::KeepAliveOptions();
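    // Disable channel caching so the client connects over a newly created
    // gRPC channel rather than reusing a cached one -- the setting this
    // commit adds coverage for.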
    bool use_cached_channel = false;
    auto err = tc::InferenceServerGrpcClient::Create(
        &this->client_, url, verbose, use_ssl, ssl_options, keepalive_options,
        use_cached_channel);
    ASSERT_TRUE(err.IsOk())
        << "failed to create GRPC client: " << err.Message();

    // Initialize 3 sets of inputs, each with 16 elements
    for (size_t i = 0; i < 3; ++i) {
      this->input_data_.emplace_back();
      for (size_t j = 0; j < 16; ++j) {
        this->input_data_.back().emplace_back(i * 16 + j);
      }
    }
  }

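  // Build the INPUT0 and INPUT1 tensors for a single request and attach the
  // raw int32 data.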
  tc::Error PrepareInputs(
      const std::vector<int32_t>& input_0, const std::vector<int32_t>& input_1,
      std::vector<tc::InferInput*>* inputs)
  {
    inputs->emplace_back();
    auto err = tc::InferInput::Create(
        &inputs->back(), "INPUT0", this->shape_, this->dtype_);
    if (!err.IsOk()) {
      return err;
    }
    err = inputs->back()->AppendRaw(
        reinterpret_cast<const uint8_t*>(input_0.data()),
        input_0.size() * sizeof(int32_t));
    if (!err.IsOk()) {
      return err;
    }
    inputs->emplace_back();
    err = tc::InferInput::Create(
        &inputs->back(), "INPUT1", this->shape_, this->dtype_);
    if (!err.IsOk()) {
      return err;
    }
    err = inputs->back()->AppendRaw(
        reinterpret_cast<const uint8_t*>(input_1.data()),
        input_1.size() * sizeof(int32_t));
    if (!err.IsOk()) {
      return err;
    }
    return tc::Error::Success;
  }

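  // Assert that every result completed successfully and that each expected
  // output tensor matches the returned raw data byte-for-byte.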
  void ValidateOutput(
      const std::vector<tc::InferResult*>& results,
      const std::vector<std::map<std::string, std::vector<int32_t>>>&
          expected_outputs)
  {
    ASSERT_EQ(results.size(), expected_outputs.size())
        << "unexpected number of results";
    for (size_t i = 0; i < results.size(); ++i) {
      ASSERT_TRUE(results[i]->RequestStatus().IsOk());
      for (const auto& expected : expected_outputs[i]) {
        const uint8_t* buf = nullptr;
        size_t byte_size = 0;
        auto err = results[i]->RawData(expected.first, &buf, &byte_size);
        ASSERT_TRUE(err.IsOk())
            << "failed to retrieve output '" << expected.first
            << "' for result " << i << ": " << err.Message();
        ASSERT_EQ(byte_size, (expected.second.size() * sizeof(int32_t)));
        EXPECT_EQ(memcmp(buf, expected.second.data(), byte_size), 0);
      }
    }
  }

  tc::Error LoadModel(
      const std::string& model_name, const std::string& config,
      const std::map<std::string, std::vector<char>>& files = {});

  std::string model_name_;
  std::unique_ptr<tc::InferenceServerGrpcClient> client_;
  std::vector<std::vector<int32_t>> input_data_;
  std::vector<int64_t> shape_;
  std::string dtype_;
};

template <>
tc::Error
ClientTest<tc::InferenceServerGrpcClient>::LoadModel(
@@ -1615,6 +1709,58 @@ TEST_F(GRPCTraceTest, GRPCClearTraceSettings)
<< std::endl;
}

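// Verify that multiple inferences succeed when the client was created with
// use_cached_channel set to false.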
TEST_F(GRPCClientTest, InferMultiNoUseCachedChannel)
{
  // Create only one set of 'options'.
  tc::Error err = tc::Error::Success;
  std::vector<tc::InferOptions> options;
  std::vector<std::vector<tc::InferInput*>> inputs;
  std::vector<std::vector<const tc::InferRequestedOutput*>> outputs;

  std::vector<std::map<std::string, std::vector<int32_t>>> expected_outputs;
  options.emplace_back(this->model_name_);
  // Use model version 1, which does not swap the input tensors.
  options.back().model_version_ = "1";
  for (size_t i = 0; i < 3; ++i) {
    const auto& input_0 = this->input_data_[i % this->input_data_.size()];
    const auto& input_1 = this->input_data_[(i + 1) % this->input_data_.size()];
    inputs.emplace_back();
    err = this->PrepareInputs(input_0, input_1, &inputs.back());
    ASSERT_TRUE(err.IsOk()) << "failed to prepare inputs: " << err.Message();

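    // Request both OUTPUT0 and OUTPUT1 for this inference.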
    tc::InferRequestedOutput* output;
    outputs.emplace_back();
    err = tc::InferRequestedOutput::Create(&output, "OUTPUT0");
    ASSERT_TRUE(err.IsOk())
        << "failed to create inference output: " << err.Message();
    outputs.back().emplace_back(output);
    err = tc::InferRequestedOutput::Create(&output, "OUTPUT1");
    ASSERT_TRUE(err.IsOk())
        << "failed to create inference output: " << err.Message();
    outputs.back().emplace_back(output);

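    // Expected results: OUTPUT0 is the element-wise sum and OUTPUT1 the
    // element-wise difference of the two inputs.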
    expected_outputs.emplace_back();
    {
      auto& expected = expected_outputs.back()["OUTPUT0"];
      for (size_t j = 0; j < 16; ++j) {
        expected.emplace_back(input_0[j] + input_1[j]);
      }
    }
    {
      auto& expected = expected_outputs.back()["OUTPUT1"];
      for (size_t j = 0; j < 16; ++j) {
        expected.emplace_back(input_0[j] - input_1[j]);
      }
    }
  }

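  // Issue all three requests through a single InferMulti call and validate
  // every response.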
  std::vector<tc::InferResult*> results;
  err = this->client_->InferMulti(&results, options, inputs, outputs);
  ASSERT_TRUE(err.IsOk()) << "failed to perform multiple inferences: "
                          << err.Message();

  EXPECT_NO_FATAL_FAILURE(this->ValidateOutput(results, expected_outputs));
}

class TestHttpInferRequest : public tc::HttpInferRequest {
 public:
  tc::Error ConvertBinaryInputsToJSON(
