diff --git a/src/c++/tests/client_timeout_test.cc b/src/c++/tests/client_timeout_test.cc index 71226da53..d38af72d7 100644 --- a/src/c++/tests/client_timeout_test.cc +++ b/src/c++/tests/client_timeout_test.cc @@ -47,6 +47,22 @@ namespace tc = triton::client; namespace { +void +TestTimeoutAPIs( + const uint64_t timeout_ms, const std::string& name, + std::unique_ptr& grpc_client) +{ + std::cout << "testing other apis" << std::endl; + std::map headers; + FAIL_IF_ERR( + grpc_client->LoadModel(name, headers, "", {}, timeout_ms), + "Could not load model"); + bool isReady = true; + FAIL_IF_ERR( + grpc_client->IsModelReady(&isReady, name, "", headers, timeout_ms), + "Could not get model ready information"); +} + void ValidateShapeAndDatatype( const std::string& name, std::shared_ptr result) @@ -109,11 +125,11 @@ void RunSynchronousInference( std::unique_ptr& grpc_client, std::unique_ptr& http_client, - uint32_t client_timeout, std::vector& inputs, + uint32_t client_timeout_ms, std::vector& inputs, std::vector& outputs, tc::InferOptions& options, std::vector& input0_data) { - options.client_timeout_ = client_timeout; + options.client_timeout_ = client_timeout_ms; tc::InferResult* results; if (grpc_client.get() != nullptr) { FAIL_IF_ERR( @@ -141,7 +157,7 @@ void RunAsynchronousInference( std::unique_ptr& grpc_client, std::unique_ptr& http_client, - uint32_t client_timeout, std::vector& inputs, + uint32_t client_timeout_ms, std::vector& inputs, std::vector& outputs, tc::InferOptions& options, std::vector& input0_data) { @@ -167,7 +183,7 @@ RunAsynchronousInference( cv.notify_all(); }; - options.client_timeout_ = client_timeout; + options.client_timeout_ = client_timeout_ms; if (grpc_client.get() != nullptr) { FAIL_IF_ERR( grpc_client->AsyncInfer(callback, options, inputs, outputs), @@ -188,7 +204,7 @@ RunAsynchronousInference( void RunStreamingInference( std::unique_ptr& grpc_client, - uint32_t client_timeout, std::vector& inputs, + uint32_t client_timeout_ms, std::vector& inputs, std::vector& outputs, tc::InferOptions& options, std::vector& input0_data) { @@ -206,13 +222,13 @@ RunStreamingInference( } cv.notify_all(); }, - false /*ship_stats*/, client_timeout), + false /*ship_stats*/, client_timeout_ms), "Failed to start the stream"); FAIL_IF_ERR( grpc_client->AsyncStreamInfer(options, inputs), "unable to run model"); - auto timeout = std::chrono::microseconds(client_timeout); + auto timeout = std::chrono::microseconds(client_timeout_ms); // Wait until all callbacks are invoked or the timeout expires { std::unique_lock lk(mtx); @@ -263,11 +279,12 @@ main(int argc, char** argv) std::string url; bool async = false; bool streaming = false; - uint32_t client_timeout = 0; + uint32_t client_timeout_ms = 0; + bool test_client_apis = false; // Parse commandline... int opt; - while ((opt = getopt(argc, argv, "vi:u:ast:")) != -1) { + while ((opt = getopt(argc, argv, "vi:u:ast:p")) != -1) { switch (opt) { case 'v': verbose = true; @@ -292,7 +309,10 @@ main(int argc, char** argv) streaming = true; break; case 't': - client_timeout = std::stoi(optarg); + client_timeout_ms = std::stoi(optarg); + break; + case 'p': + test_client_apis = true; break; case '?': Usage(argv); @@ -335,6 +355,12 @@ main(int argc, char** argv) "unable to create grpc client"); } + // Test server timeouts for grpc client + if (protocol == "grpc" && test_client_apis) { + TestTimeoutAPIs(client_timeout_ms, model_name, grpc_client); + return 0; + } + // Initialize the tensor data std::vector input0_data(16); for (size_t i = 0; i < 16; ++i) { @@ -370,7 +396,7 @@ main(int argc, char** argv) // The inference settings. Will be using default for now. tc::InferOptions options(model_name); options.model_version_ = model_version; - options.client_timeout_ = client_timeout; + options.client_timeout_ = client_timeout_ms; std::vector inputs = {input0_ptr.get()}; std::vector outputs = {output0_ptr.get()}; @@ -378,14 +404,14 @@ main(int argc, char** argv) // Send inference request to the inference server. if (streaming) { RunStreamingInference( - grpc_client, client_timeout, inputs, outputs, options, input0_data); + grpc_client, client_timeout_ms, inputs, outputs, options, input0_data); } else if (async) { RunAsynchronousInference( - grpc_client, http_client, client_timeout, inputs, outputs, options, + grpc_client, http_client, client_timeout_ms, inputs, outputs, options, input0_data); } else { RunSynchronousInference( - grpc_client, http_client, client_timeout, inputs, outputs, options, + grpc_client, http_client, client_timeout_ms, inputs, outputs, options, input0_data); }