Skip to content

Commit

Permalink
Prune non-requested outputs from decoupled models
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui committed Jun 11, 2024
1 parent e63594f commit 37c199e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 24 deletions.
28 changes: 7 additions & 21 deletions src/response_sender.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,17 @@ ResponseSender::~ResponseSender()
PYTHONSTUB_DecoupledResponseFactoryCleanup);
}

bool
ResponseSender::IsDecoupled() const
void
ResponseSender::UpdateStateAndCounters(
const std::shared_ptr<InferResponse>& response, const uint32_t flags)
{
if (is_decoupled_ == nullptr) {
// TODO: Can a model access the response sender on a BLS infer request?
throw PythonBackendException(
"Unable to send response. Response sender has no reference to the "
"decoupled state of the model.");
}
return *is_decoupled_;
}

void
ResponseSender::UpdateStateAndCounters(
const std::shared_ptr<InferResponse>& response, const uint32_t flags)
{
bool is_decoupled = IsDecoupled();
bool is_decoupled = *is_decoupled_;

std::lock_guard<std::mutex> lk(mu_);

Expand Down Expand Up @@ -119,16 +113,6 @@ ResponseSender::UpdateStateAndCounters(
number_of_response_sent_++;
}

void
ResponseSender::PruneNonRequestedOutputs(
const std::shared_ptr<InferResponse>& infer_response) const
{
// TODO: should this be limited to non decoupled only?
if (!IsDecoupled() && infer_response) {
infer_response->PruneOutputTensors(requested_output_names_);
}
}

void
ResponseSender::Send(
std::shared_ptr<InferResponse> infer_response, const uint32_t flags)
Expand All @@ -142,7 +126,9 @@ ResponseSender::Send(

CheckResponseSenderArguments(infer_response, flags);
UpdateStateAndCounters(infer_response, flags);
PruneNonRequestedOutputs(infer_response);
if (infer_response) {
infer_response->PruneOutputTensors(requested_output_names_);
}

std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();

Expand Down
3 changes: 0 additions & 3 deletions src/response_sender.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ class ResponseSender {
void Close();

private:
bool IsDecoupled() const;
void UpdateStateAndCounters(
const std::shared_ptr<InferResponse>& response, const uint32_t flags);
void PruneNonRequestedOutputs(
const std::shared_ptr<InferResponse>& infer_response) const;

intptr_t request_address_;
intptr_t response_factory_address_;
Expand Down

0 comments on commit 37c199e

Please sign in to comment.