diff --git a/src/c++/examples/simple_grpc_infer_client.cc b/src/c++/examples/simple_grpc_infer_client.cc index 5959d3cd6..13c5e8bc1 100644 --- a/src/c++/examples/simple_grpc_infer_client.cc +++ b/src/c++/examples/simple_grpc_infer_client.cc @@ -229,6 +229,10 @@ main(int argc, char** argv) use_cached_channel), err); + if (verbose) { + std::cout << "There are " << client->GetNumCachedChannels() + << " cached channels" << std::endl; + } // Create the data for the two input tensors. Initialize the first // to unique integers and the second to all ones. std::vector input0_data(16); diff --git a/src/c++/library/grpc_client.cc b/src/c++/library/grpc_client.cc index 6e67f859e..c9ee70125 100644 --- a/src/c++/library/grpc_client.cc +++ b/src/c++/library/grpc_client.cc @@ -94,7 +94,7 @@ GetStub( "TRITON_CLIENT_GRPC_CHANNEL_MAX_SHARE_COUNT", "6")); const auto& channel_itr = grpc_channel_stub_map_.find(url); // Reuse cached channel if the channel is found in the map and - // used_cached_channel flag is true + // use_cached_channel flag is true if ((channel_itr != grpc_channel_stub_map_.end()) && use_cached_channel) { // check if NewStub should be created const auto& shared_count = std::get<0>(channel_itr->second); @@ -136,6 +136,8 @@ GetStub( std::shared_ptr stub = inference::GRPCInferenceService::NewStub(channel); + // If `use_cached_channel` is true, create no new channels even if there + // are no cached channels. if (use_cached_channel) { // Replace if channel / stub have been in the map if (channel_itr != grpc_channel_stub_map_.end()) { @@ -1706,6 +1708,13 @@ InferenceServerGrpcClient::~InferenceServerGrpcClient() StopStream(); } +size_t +InferenceServerGrpcClient::GetNumCachedChannels() const +{ + std::lock_guard lock(grpc_channel_stub_map_mtx_); + return grpc_channel_stub_map_.size(); +} + //============================================================================== }} // namespace triton::client diff --git a/src/c++/library/grpc_client.h b/src/c++/library/grpc_client.h index 7609c10b7..641b017ee 100644 --- a/src/c++/library/grpc_client.h +++ b/src/c++/library/grpc_client.h @@ -600,6 +600,9 @@ class InferenceServerGrpcClient : public InferenceServerClient { const std::vector& outputs = std::vector()); + // Number of Cached Channels + size_t GetNumCachedChannels() const; + private: InferenceServerGrpcClient( const std::string& url, bool verbose, bool use_ssl,