From 37c199ea911be7eda1e899328353fe569a91a752 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Tue, 11 Jun 2024 10:53:44 -0700 Subject: [PATCH] Prune non requested outputs from decoupled models --- src/response_sender.cc | 28 +++++++--------------------- src/response_sender.h | 3 --- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/src/response_sender.cc b/src/response_sender.cc index b8ac4603..1831601f 100644 --- a/src/response_sender.cc +++ b/src/response_sender.cc @@ -74,8 +74,9 @@ ResponseSender::~ResponseSender() PYTHONSTUB_DecoupledResponseFactoryCleanup); } -bool -ResponseSender::IsDecoupled() const +void +ResponseSender::UpdateStateAndCounters( + const std::shared_ptr<InferResponse>& response, const uint32_t flags) { if (is_decoupled_ == nullptr) { // TODO: Can a model access the response sender on a BLS infer request? @@ -83,14 +84,7 @@ ResponseSender::IsDecoupled() const "Unable to send response. Response sender has no reference to the " "decoupled state of the model."); } - return *is_decoupled_; -} - -void -ResponseSender::UpdateStateAndCounters( - const std::shared_ptr<InferResponse>& response, const uint32_t flags) -{ - bool is_decoupled = IsDecoupled(); + bool is_decoupled = *is_decoupled_; std::lock_guard lk(mu_); @@ -119,16 +113,6 @@ ResponseSender::UpdateStateAndCounters( number_of_response_sent_++; } -void -ResponseSender::PruneNonRequestedOutputs( - const std::shared_ptr<InferResponse>& infer_response) const -{ - // TODO: should this be limited to non decoupled only? 
- if (!IsDecoupled() && infer_response) { - infer_response->PruneOutputTensors(requested_output_names_); - } -} - void ResponseSender::Send( std::shared_ptr<InferResponse> infer_response, const uint32_t flags) @@ -142,7 +126,9 @@ ResponseSender::Send( CheckResponseSenderArguments(infer_response, flags); UpdateStateAndCounters(infer_response, flags); - PruneNonRequestedOutputs(infer_response); + if (infer_response) { + infer_response->PruneOutputTensors(requested_output_names_); + } std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance(); diff --git a/src/response_sender.h b/src/response_sender.h index 05ad8069..f274f5b4 100644 --- a/src/response_sender.h +++ b/src/response_sender.h @@ -50,11 +50,8 @@ class ResponseSender { void Close(); private: - bool IsDecoupled() const; void UpdateStateAndCounters( const std::shared_ptr<InferResponse>& response, const uint32_t flags); - void PruneNonRequestedOutputs( - const std::shared_ptr<InferResponse>& infer_response) const; intptr_t request_address_; intptr_t response_factory_address_;