Commit: add timeout test
jbkyang-nvi committed Nov 7, 2023
1 parent 2ab7c22 commit 2302ba8
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions src/c++/tests/client_timeout_test.cc
@@ -47,6 +47,22 @@ namespace tc = triton::client;

namespace {

+void
+TestTimeoutAPIs(
+    const uint64_t timeout_ms, const std::string& name,
+    std::unique_ptr<tc::InferenceServerGrpcClient>& grpc_client)
+{
+  std::cout << "testing other apis" << std::endl;
+  std::map<std::string, std::string> headers;
+  FAIL_IF_ERR(
+      grpc_client->LoadModel(name, headers, "", {}, timeout_ms),
+      "Could not load model");
+  bool isReady = true;
+  FAIL_IF_ERR(
+      grpc_client->IsModelReady(&isReady, name, "", headers, timeout_ms),
+      "Could not get model ready information");
+}
+
void
ValidateShapeAndDatatype(
const std::string& name, std::shared_ptr<tc::InferResult> result)
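
The new TestTimeoutAPIs helper above exercises the client-side timeout overloads of the non-inference gRPC APIs (LoadModel and IsModelReady). A minimal sketch of driving it, assuming a local server URL and the "simple" model name (both are illustrative assumptions, not taken from this diff):

// Sketch only; URL, model name, and timeout value are illustrative.
std::unique_ptr<tc::InferenceServerGrpcClient> grpc_client;
FAIL_IF_ERR(
    tc::InferenceServerGrpcClient::Create(&grpc_client, "localhost:8001", false),
    "unable to create grpc client");
TestTimeoutAPIs(100 /* timeout_ms */, "simple", grpc_client);
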
@@ -109,11 +125,11 @@ void
RunSynchronousInference(
std::unique_ptr<tc::InferenceServerGrpcClient>& grpc_client,
std::unique_ptr<tc::InferenceServerHttpClient>& http_client,
-    uint32_t client_timeout, std::vector<tc::InferInput*>& inputs,
+    uint32_t client_timeout_ms, std::vector<tc::InferInput*>& inputs,
std::vector<const tc::InferRequestedOutput*>& outputs,
tc::InferOptions& options, std::vector<int32_t>& input0_data)
{
-  options.client_timeout_ = client_timeout;
+  options.client_timeout_ = client_timeout_ms;
tc::InferResult* results;
if (grpc_client.get() != nullptr) {
FAIL_IF_ERR(
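
The synchronous path continues under the fold: options.client_timeout_ now carries the millisecond-named value before the blocking Infer call. A compact sketch of that flow, using the same names as the function above:

// Sketch of the synchronous call path (its continuation is collapsed above).
options.client_timeout_ = client_timeout_ms;
tc::InferResult* results;
FAIL_IF_ERR(
    grpc_client->Infer(&results, options, inputs, outputs),
    "unable to run model");
std::shared_ptr<tc::InferResult> results_ptr(results);
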
Expand Down Expand Up @@ -141,7 +157,7 @@ void
RunAsynchronousInference(
std::unique_ptr<tc::InferenceServerGrpcClient>& grpc_client,
std::unique_ptr<tc::InferenceServerHttpClient>& http_client,
-    uint32_t client_timeout, std::vector<tc::InferInput*>& inputs,
+    uint32_t client_timeout_ms, std::vector<tc::InferInput*>& inputs,
std::vector<const tc::InferRequestedOutput*>& outputs,
tc::InferOptions& options, std::vector<int32_t>& input0_data)
{
@@ -167,7 +183,7 @@ RunAsynchronousInference(
cv.notify_all();
};

-  options.client_timeout_ = client_timeout;
+  options.client_timeout_ = client_timeout_ms;
if (grpc_client.get() != nullptr) {
FAIL_IF_ERR(
grpc_client->AsyncInfer(callback, options, inputs, outputs),
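
The asynchronous variant hands AsyncInfer a completion callback and then blocks on a condition variable, as the surrounding lines show. A self-contained sketch of that pattern, assuming <mutex> and <condition_variable> are included as in this file (names are illustrative):

// Sketch: wait for the async callback to deliver a result.
std::mutex mtx;
std::condition_variable cv;
tc::InferResult* async_result = nullptr;
auto callback = [&](tc::InferResult* result) {
  {
    std::lock_guard<std::mutex> lk(mtx);
    async_result = result;
  }
  cv.notify_all();
};
options.client_timeout_ = client_timeout_ms;
FAIL_IF_ERR(
    grpc_client->AsyncInfer(callback, options, inputs, outputs),
    "unable to run model");
std::unique_lock<std::mutex> lk(mtx);
cv.wait(lk, [&] { return async_result != nullptr; });
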
@@ -188,7 +204,7 @@ RunAsynchronousInference(
void
RunStreamingInference(
std::unique_ptr<tc::InferenceServerGrpcClient>& grpc_client,
-    uint32_t client_timeout, std::vector<tc::InferInput*>& inputs,
+    uint32_t client_timeout_ms, std::vector<tc::InferInput*>& inputs,
std::vector<const tc::InferRequestedOutput*>& outputs,
tc::InferOptions& options, std::vector<int32_t>& input0_data)
{
Expand All @@ -206,13 +222,13 @@ RunStreamingInference(
}
cv.notify_all();
},
-      false /*ship_stats*/, client_timeout),
+      false /*ship_stats*/, client_timeout_ms),
"Failed to start the stream");

FAIL_IF_ERR(
grpc_client->AsyncStreamInfer(options, inputs), "unable to run model");

-  auto timeout = std::chrono::microseconds(client_timeout);
+  auto timeout = std::chrono::microseconds(client_timeout_ms);
// Wait until all callbacks are invoked or the timeout expires
{
std::unique_lock<std::mutex> lk(mtx);
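
One detail worth flagging: the wait above still constructs std::chrono::microseconds from the newly millisecond-named variable. If the value really is milliseconds, a unit-consistent wait would look like the sketch below (an assumption about the intended unit, not a change this commit makes; received_all is an illustrative predicate):

// Hypothetical unit-consistent deadline for the stream wait.
auto timeout = std::chrono::milliseconds(client_timeout_ms);
std::unique_lock<std::mutex> lk(mtx);
if (!cv.wait_for(lk, timeout, [&] { return received_all; })) {
  std::cerr << "Stream has not finished within the deadline" << std::endl;
}
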
@@ -263,11 +279,12 @@ main(int argc, char** argv)
std::string url;
bool async = false;
bool streaming = false;
-  uint32_t client_timeout = 0;
+  uint32_t client_timeout_ms = 0;
+  bool test_client_apis = false;

// Parse commandline...
int opt;
-  while ((opt = getopt(argc, argv, "vi:u:ast:")) != -1) {
+  while ((opt = getopt(argc, argv, "vi:u:ast:p")) != -1) {
switch (opt) {
case 'v':
verbose = true;
@@ -292,7 +309,10 @@
streaming = true;
break;
case 't':
-        client_timeout = std::stoi(optarg);
+        client_timeout_ms = std::stoi(optarg);
break;
+      case 'p':
+        test_client_apis = true;
break;
case '?':
Usage(argv);
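
With the new 'p' option wired into getopt, the timeout-API path can be selected from the command line. Assuming the binary keeps the source file's name and -i selects the protocol as in the other Triton client examples (both assumptions), it might be invoked as:

./client_timeout_test -i grpc -t 5 -p
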
Expand Down Expand Up @@ -335,6 +355,12 @@ main(int argc, char** argv)
"unable to create grpc client");
}

+  // Test server timeouts for grpc client
+  if (protocol == "grpc" && test_client_apis) {
+    TestTimeoutAPIs(client_timeout_ms, model_name, grpc_client);
+    return 0;
+  }
+
// Initialize the tensor data
std::vector<int32_t> input0_data(16);
for (size_t i = 0; i < 16; ++i) {
@@ -370,22 +396,22 @@ main(int argc, char** argv)
// The inference settings. Will be using default for now.
tc::InferOptions options(model_name);
options.model_version_ = model_version;
-  options.client_timeout_ = client_timeout;
+  options.client_timeout_ = client_timeout_ms;

std::vector<tc::InferInput*> inputs = {input0_ptr.get()};
std::vector<const tc::InferRequestedOutput*> outputs = {output0_ptr.get()};

// Send inference request to the inference server.
if (streaming) {
RunStreamingInference(
-        grpc_client, client_timeout, inputs, outputs, options, input0_data);
+        grpc_client, client_timeout_ms, inputs, outputs, options, input0_data);
} else if (async) {
RunAsynchronousInference(
-        grpc_client, http_client, client_timeout, inputs, outputs, options,
+        grpc_client, http_client, client_timeout_ms, inputs, outputs, options,
input0_data);
} else {
RunSynchronousInference(
-        grpc_client, http_client, client_timeout, inputs, outputs, options,
+        grpc_client, http_client, client_timeout_ms, inputs, outputs, options,
input0_data);
}
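
For reference, the three modes dispatched above might be exercised like this (binary name and flag semantics assumed as before; the streaming branch uses the gRPC-only path):

./client_timeout_test -i grpc -t 1000        # synchronous
./client_timeout_test -i grpc -a -t 1000     # asynchronous
./client_timeout_test -i grpc -s -t 1000     # streaming
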
