Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
 into yinggeh-DLIS-6657-client-input-byte-size-check
  • Loading branch information
yinggeh committed Jul 27, 2024
2 parents 07059a6 + 442915d commit 8b699c0
Show file tree
Hide file tree
Showing 91 changed files with 6,300 additions and 2,105 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
- id: isort
additional_dependencies: [toml]
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 24.4.0
hooks:
- id: black
types_or: [python, cython]
Expand Down
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ project(tritonclient LANGUAGES C CXX)
# Use C++17 standard as Triton's minimum required.
set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which features are requested to build this target.")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

#
# Options
#
Expand Down
6 changes: 6 additions & 0 deletions src/c++/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ cmake_minimum_required(VERSION 3.17)

project(cc-clients LANGUAGES C CXX)

# Use C++17 standard as Triton's minimum required.
set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which features are requested to build this target.")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

#
# Options
#
Expand Down
194 changes: 104 additions & 90 deletions src/c++/library/http_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1371,27 +1371,23 @@ InferenceServerHttpClient::InferenceServerHttpClient(

InferenceServerHttpClient::~InferenceServerHttpClient()
{
exiting_ = true;
{
std::lock_guard<std::mutex> lock(mutex_);
exiting_ = true;
}

curl_multi_wakeup(multi_handle_);

// thread not joinable if AsyncInfer() is not called
// (it is default constructed thread before the first AsyncInfer() call)
if (worker_.joinable()) {
cv_.notify_all();
worker_.join();
}

if (easy_handle_ != nullptr) {
curl_easy_cleanup(reinterpret_cast<CURL*>(easy_handle_));
}

if (multi_handle_ != nullptr) {
for (auto& request : ongoing_async_requests_) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_multi_remove_handle(multi_handle_, easy_handle);
curl_easy_cleanup(easy_handle);
}
curl_multi_cleanup(multi_handle_);
}
curl_multi_cleanup(multi_handle_);
}

Error
Expand Down Expand Up @@ -1887,25 +1883,28 @@ InferenceServerHttpClient::AsyncInfer(
{
std::lock_guard<std::mutex> lock(mutex_);

auto insert_result = ongoing_async_requests_.emplace(std::make_pair(
if (exiting_) {
return Error("Client is exiting.");
}

auto insert_result = new_async_requests_.emplace(std::make_pair(
reinterpret_cast<uintptr_t>(multi_easy_handle), async_request));
if (!insert_result.second) {
curl_easy_cleanup(multi_easy_handle);
return Error("Failed to insert new asynchronous request context.");
}
}

async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
if (async_request->total_input_byte_size_ == 0) {
// Set SEND_END here because CURLOPT_READFUNCTION will not be called if
// content length is 0. In that case, we can't measure SEND_END properly
// (send ends after sending request header).
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
}
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
curl_multi_wakeup(multi_handle_);

curl_multi_add_handle(multi_handle_, multi_easy_handle);
if (async_request->total_input_byte_size_ == 0) {
// Set SEND_END here because CURLOPT_READFUNCTION will not be called if
// content length is 0. In that case, we can't measure SEND_END properly
// (send ends after sending request header).
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
}

cv_.notify_all();
return Error::Success;
}

Expand Down Expand Up @@ -2254,88 +2253,103 @@ InferenceServerHttpClient::PreRunProcessing(
void
InferenceServerHttpClient::AsyncTransfer()
{
int place_holder = 0;
int messages_in_queue = 0;
int still_running = 0;
int numfds = 0;
CURLMsg* msg = nullptr;
AsyncReqMap ongoing_async_requests;
do {
std::vector<std::shared_ptr<HttpInferRequest>> request_list;
// Check for new requests and add them to ongoing requests
{
std::lock_guard<std::mutex> lock(mutex_);

for (auto& pair : new_async_requests_) {
curl_multi_add_handle(
multi_handle_, reinterpret_cast<CURL*>(pair.first));

// sleep if no work is available
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
if (this->exiting_) {
return true;
ongoing_async_requests[pair.first] = std::move(pair.second);
}
// wake up if an async request has been generated
return !this->ongoing_async_requests_.empty();
});

CURLMcode mc = curl_multi_perform(multi_handle_, &place_holder);
int numfds;
if (mc == CURLM_OK) {
// Wait for activity. If there are no descriptors in the multi_handle_
// then curl_multi_wait will return immediately
mc = curl_multi_wait(multi_handle_, NULL, 0, INT_MAX, &numfds);
if (mc == CURLM_OK) {
while ((msg = curl_multi_info_read(multi_handle_, &place_holder))) {
uintptr_t identifier = reinterpret_cast<uintptr_t>(msg->easy_handle);
auto itr = ongoing_async_requests_.find(identifier);
// This shouldn't happen
if (itr == ongoing_async_requests_.end()) {
std::cerr
<< "Unexpected error: received completed request that is not "
"in the list of asynchronous requests"
<< std::endl;
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
continue;
}
new_async_requests_.clear();
}

long http_code = 400;
if (msg->data.result == CURLE_OK) {
curl_easy_getinfo(
msg->easy_handle, CURLINFO_RESPONSE_CODE, &http_code);
} else if (msg->data.result == CURLE_OPERATION_TIMEDOUT) {
http_code = 499;
}
CURLMcode mc = curl_multi_perform(multi_handle_, &still_running);

request_list.emplace_back(itr->second);
ongoing_async_requests_.erase(itr);
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);

std::shared_ptr<HttpInferRequest> async_request = request_list.back();
async_request->http_code_ = http_code;

if (msg->msg != CURLMSG_DONE) {
// Something wrong happened.
std::cerr << "Unexpected error: received CURLMsg=" << msg->msg
<< std::endl;
} else {
async_request->Timer().CaptureTimestamp(
RequestTimers::Kind::REQUEST_END);
Error err = UpdateInferStat(async_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err
<< std::endl;
}
}
}
} else {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
<< std::endl;
}
} else {
if (mc != CURLM_OK) {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
<< std::endl;
continue;
}
lock.unlock();

for (auto& this_request : request_list) {
while ((msg = curl_multi_info_read(multi_handle_, &messages_in_queue))) {
if (msg->msg != CURLMSG_DONE) {
// Something wrong happened.
std::cerr << "Unexpected error: received CURLMsg=" << msg->msg
<< std::endl;
continue;
}

uintptr_t identifier = reinterpret_cast<uintptr_t>(msg->easy_handle);
auto itr = ongoing_async_requests.find(identifier);
// This shouldn't happen
if (itr == ongoing_async_requests.end()) {
std::cerr << "Unexpected error: received completed request that is not "
"in the list of asynchronous requests"
<< std::endl;
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
continue;
}
auto async_request = itr->second;

uint32_t http_code = 400;
if (msg->data.result == CURLE_OK) {
curl_easy_getinfo(msg->easy_handle, CURLINFO_RESPONSE_CODE, &http_code);
async_request->Timer().CaptureTimestamp(
RequestTimers::Kind::REQUEST_END);
Error err = UpdateInferStat(async_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err << std::endl;
}
} else if (msg->data.result == CURLE_OPERATION_TIMEDOUT) {
http_code = 499;
}

async_request->http_code_ = http_code;
InferResult* result;
InferResultHttp::Create(&result, this_request);
this_request->callback_(result);
InferResultHttp::Create(&result, async_request);
async_request->callback_(result);
ongoing_async_requests.erase(itr);
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
}

// Wait for activity on existing requests or
// explicit curl_multi_wakeup call
//
// If there are no descriptors in the multi_handle_
// then curl_multi_poll will wait until curl_multi_wakeup
// is called
//
// curl_multi_wakeup is called when adding a new request
// or exiting

mc = curl_multi_poll(multi_handle_, NULL, 0, INT_MAX, &numfds);
if (mc != CURLM_OK) {
std::cerr << "Unexpected error: curl_multi_poll failed. Code:" << mc
<< std::endl;
}
} while (!exiting_);

for (auto& request : ongoing_async_requests) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_multi_remove_handle(multi_handle_, easy_handle);
curl_easy_cleanup(easy_handle);
}

for (auto& request : new_async_requests_) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_easy_cleanup(easy_handle);
}
}

size_t
Expand Down
4 changes: 2 additions & 2 deletions src/c++/library/http_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,9 @@ class InferenceServerHttpClient : public InferenceServerClient {
void* easy_handle_;
// curl multi handle for processing asynchronous requests
void* multi_handle_;
// map to record ongoing asynchronous requests with pointer to easy handle
// map to record new asynchronous requests with pointer to easy handle
// or tag id as key
AsyncReqMap ongoing_async_requests_;
AsyncReqMap new_async_requests_;
};

}} // namespace triton::client
Loading

0 comments on commit 8b699c0

Please sign in to comment.