From 5add5293c9268f613ce5e9a6a7856469f6fa6fa2 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Thu, 29 Feb 2024 23:42:23 -0800 Subject: [PATCH] Add OpenAI client --- .../client_backend/openai/CMakeLists.txt | 8 +- .../client_backend/openai/http_client.cc | 267 ++++++++++++++++ .../client_backend/openai/http_client.h | 191 +++++++++++ .../client_backend/openai/openai_client.cc | 298 ++++++++++++++++++ .../client_backend/openai/openai_client.h | 186 +++++++++++ .../openai/openai_client_backend.cc | 67 +--- .../openai/openai_client_backend.h | 25 +- .../openai/openai_http_client.cc | 120 ------- .../openai/openai_http_client.h | 106 ------- .../openai/openai_infer_input.cc | 38 +-- .../openai/openai_infer_input.h | 8 +- 11 files changed, 974 insertions(+), 340 deletions(-) create mode 100644 src/c++/perf_analyzer/client_backend/openai/http_client.cc create mode 100644 src/c++/perf_analyzer/client_backend/openai/http_client.h create mode 100644 src/c++/perf_analyzer/client_backend/openai/openai_client.cc create mode 100644 src/c++/perf_analyzer/client_backend/openai/openai_client.h delete mode 100644 src/c++/perf_analyzer/client_backend/openai/openai_http_client.cc delete mode 100644 src/c++/perf_analyzer/client_backend/openai/openai_http_client.h diff --git a/src/c++/perf_analyzer/client_backend/openai/CMakeLists.txt b/src/c++/perf_analyzer/client_backend/openai/CMakeLists.txt index 3ef867e9f..93963e378 100644 --- a/src/c++/perf_analyzer/client_backend/openai/CMakeLists.txt +++ b/src/c++/perf_analyzer/client_backend/openai/CMakeLists.txt @@ -28,15 +28,17 @@ cmake_minimum_required (VERSION 3.18) set( OPENAI_CLIENT_BACKEND_SRCS + http_client.cc openai_client_backend.cc - openai_http_client.cc + openai_client.cc openai_infer_input.cc ) set( OPENAI_CLIENT_BACKEND_HDRS + http_client.h openai_client_backend.h - openai_http_client.h + openai_client.h openai_infer_input.h ) @@ -48,7 +50,7 @@ add_library( target_link_libraries( openai-client-backend-library - PUBLIC CURL::libcurl + PUBLIC CURL::libcurl PUBLIC httpclient_static ) diff --git a/src/c++/perf_analyzer/client_backend/openai/http_client.cc b/src/c++/perf_analyzer/client_backend/openai/http_client.cc new file mode 100644 index 000000000..4c8632c52 --- /dev/null +++ b/src/c++/perf_analyzer/client_backend/openai/http_client.cc @@ -0,0 +1,267 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "http_client.h" + +#include +#include + +namespace triton { namespace perfanalyzer { namespace clientbackend { +namespace openai { + +HttpRequest::HttpRequest( + std::function&& completion_callback, const bool verbose) + : completion_callback_(std::move(completion_callback)), verbose_(verbose) +{ +} + +HttpRequest::~HttpRequest() +{ + if (header_list_ != nullptr) { + curl_slist_free_all(header_list_); + header_list_ = nullptr; + } +} + +void +HttpRequest::AddInput(uint8_t* buf, size_t byte_size) +{ + data_buffers_.push_back(std::pair(buf, byte_size)); + total_input_byte_size_ += byte_size; +} + +void +HttpRequest::GetNextInput(uint8_t* buf, size_t size, size_t* input_bytes) +{ + *input_bytes = 0; + + while (!data_buffers_.empty() && size > 0) { + const size_t csz = std::min(data_buffers_.front().second, size); + if (csz > 0) { + const uint8_t* input_ptr = data_buffers_.front().first; + std::copy(input_ptr, input_ptr + csz, buf); + size -= csz; + buf += csz; + *input_bytes += csz; + + data_buffers_.front().first += csz; + data_buffers_.front().second -= csz; + } + if (data_buffers_.front().second == 0) { + data_buffers_.pop_front(); + } + } +} + +HttpClient::HttpClient( + const std::string& server_url, bool verbose, + const HttpSslOptions& ssl_options) + : url_(server_url), verbose_(verbose), ssl_options_(ssl_options) +{ + auto* ver = curl_version_info(CURLVERSION_NOW); + if (ver->features & CURL_VERSION_THREADSAFE == 0) { + throw std::runtime_error( + "HTTP client has dependency on CURL library to have thread-safe " + "support (CURL_VERSION_THREADSAFE set)"); + } + if (curl_global_init(CURL_GLOBAL_ALL) != 0) { + throw std::runtime_error("CURL global initialization failed"); + } + + multi_handle_ = curl_multi_init(); + + worker_ = std::thread(&HttpClient::AsyncTransfer, this); +} + +HttpClient::~HttpClient() +{ + exiting_ = true; + + // thread not joinable if AsyncInfer() is not called + // (it is default constructed thread before the first AsyncInfer() call) + if (worker_.joinable()) { + cv_.notify_all(); + worker_.join(); + } + + for (auto& request : ongoing_async_requests_) { + CURL* easy_handle = reinterpret_cast(request.first); + curl_multi_remove_handle(multi_handle_, easy_handle); + curl_easy_cleanup(easy_handle); + } + curl_multi_cleanup(multi_handle_); + + curl_global_cleanup(); +} + +const std::string& +HttpClient::ParseSslCertType(HttpSslOptions::CERTTYPE cert_type) +{ + static std::string pem_str{"PEM"}; + static std::string der_str{"DER"}; + switch (cert_type) { + case HttpSslOptions::CERTTYPE::CERT_PEM: + return pem_str; + case HttpSslOptions::CERTTYPE::CERT_DER: + return der_str; + } + throw std::runtime_error( + "Unexpected SSL certificate type encountered. Only PEM and DER are " + "supported."); +} + +const std::string& +HttpClient::ParseSslKeyType(HttpSslOptions::KEYTYPE key_type) +{ + static std::string pem_str{"PEM"}; + static std::string der_str{"DER"}; + switch (key_type) { + case HttpSslOptions::KEYTYPE::KEY_PEM: + return pem_str; + case HttpSslOptions::KEYTYPE::KEY_DER: + return der_str; + } + throw std::runtime_error( + "unsupported SSL key type encountered. Only PEM and DER are " + "supported."); +} + +void +HttpClient::SetSSLCurlOptions(CURL* curl_handle) +{ + curl_easy_setopt( + curl_handle, CURLOPT_SSL_VERIFYPEER, ssl_options_.verify_peer); + curl_easy_setopt( + curl_handle, CURLOPT_SSL_VERIFYHOST, ssl_options_.verify_host); + if (!ssl_options_.ca_info.empty()) { + curl_easy_setopt(curl_handle, CURLOPT_CAINFO, ssl_options_.ca_info.c_str()); + } + const auto& curl_cert_type = ParseSslCertType(ssl_options_.cert_type); + curl_easy_setopt(curl_handle, CURLOPT_SSLCERTTYPE, curl_cert_type.c_str()); + if (!ssl_options_.cert.empty()) { + curl_easy_setopt(curl_handle, CURLOPT_SSLCERT, ssl_options_.cert.c_str()); + } + const auto& curl_key_type = ParseSslKeyType(ssl_options_.key_type); + curl_easy_setopt(curl_handle, CURLOPT_SSLKEYTYPE, curl_key_type.c_str()); + if (!ssl_options_.key.empty()) { + curl_easy_setopt(curl_handle, CURLOPT_SSLKEY, ssl_options_.key.c_str()); + } +} + +void +HttpClient::Send(CURL* handle, std::unique_ptr&& request) +{ + std::lock_guard lock(mutex_); + + auto insert_result = ongoing_async_requests_.emplace( + std::make_pair(reinterpret_cast(handle), std::move(request))); + if (!insert_result.second) { + curl_easy_cleanup(handle); + throw std::runtime_error( + "Failed to insert new asynchronous request context."); + } + curl_multi_add_handle(multi_handle_, handle); + cv_.notify_all(); +} + +void +HttpClient::AsyncTransfer() +{ + int place_holder = 0; + CURLMsg* msg = nullptr; + do { + std::vector> request_list; + + // sleep if no work is available + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { + if (this->exiting_) { + return true; + } + // wake up if an async request has been generated + return !this->ongoing_async_requests_.empty(); + }); + + CURLMcode mc = curl_multi_perform(multi_handle_, &place_holder); + int numfds; + if (mc == CURLM_OK) { + // Wait for activity. If there are no descriptors in the multi_handle_ + // then curl_multi_wait will return immediately + mc = curl_multi_wait(multi_handle_, NULL, 0, INT_MAX, &numfds); + if (mc == CURLM_OK) { + while ((msg = curl_multi_info_read(multi_handle_, &place_holder))) { + uintptr_t identifier = reinterpret_cast(msg->easy_handle); + auto itr = ongoing_async_requests_.find(identifier); + // This shouldn't happen + if (itr == ongoing_async_requests_.end()) { + std::cerr + << "Unexpected error: received completed request that is not " + "in the list of asynchronous requests" + << std::endl; + curl_multi_remove_handle(multi_handle_, msg->easy_handle); + curl_easy_cleanup(msg->easy_handle); + continue; + } + + long http_code = 400; + if (msg->data.result == CURLE_OK) { + curl_easy_getinfo( + msg->easy_handle, CURLINFO_RESPONSE_CODE, &http_code); + } else if (msg->data.result == CURLE_OPERATION_TIMEDOUT) { + http_code = 499; + } + + request_list.emplace_back(std::move(itr->second)); + ongoing_async_requests_.erase(itr); + curl_multi_remove_handle(multi_handle_, msg->easy_handle); + curl_easy_cleanup(msg->easy_handle); + + std::unique_ptr& async_request = request_list.back(); + async_request->http_code_ = http_code; + + if (msg->msg != CURLMSG_DONE) { + // Something wrong happened. + std::cerr << "Unexpected error: received CURLMsg=" << msg->msg + << std::endl; + } + } + } else { + std::cerr << "Unexpected error: curl_multi failed. Code:" << mc + << std::endl; + } + } else { + std::cerr << "Unexpected error: curl_multi failed. Code:" << mc + << std::endl; + } + lock.unlock(); + + for (auto& this_request : request_list) { + this_request->completion_callback_(this_request.get()); + } + } while (!exiting_); +} + +}}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/http_client.h b/src/c++/perf_analyzer/client_backend/openai/http_client.h new file mode 100644 index 000000000..3caa94992 --- /dev/null +++ b/src/c++/perf_analyzer/client_backend/openai/http_client.h @@ -0,0 +1,191 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +// [TODO] Below should already be a generic class for any HTTP use, +// relocate it so that it can be used elsewhere +namespace triton { namespace perfanalyzer { namespace clientbackend { +namespace openai { + +// [FIXME] add back "parameter" handling +// [FIXME] add back "compression" handling + +/// The key-value map type to be included in the request +/// as custom headers. +typedef std::map Headers; +/// The key-value map type to be included as URL parameters. +typedef std::map Parameters; + +// The options for authorizing and authenticating SSL/TLS connections. +struct HttpSslOptions { + enum CERTTYPE { CERT_PEM = 0, CERT_DER = 1 }; + enum KEYTYPE { + KEY_PEM = 0, + KEY_DER = 1 + // TODO: Support loading private key from crypto engine + // KEY_ENG = 2 + }; + explicit HttpSslOptions() + : verify_peer(1), verify_host(2), cert_type(CERTTYPE::CERT_PEM), + key_type(KEYTYPE::KEY_PEM) + { + } + // This option determines whether curl verifies the authenticity of the peer's + // certificate. A value of 1 means curl verifies; 0 (zero) means it does not. + // Default value is 1. See here for more details: + // https://curl.se/libcurl/c/CURLOPT_SSL_VERIFYPEER.html + long verify_peer; + // This option determines whether libcurl verifies that the server cert is for + // the server it is known as. The default value for this option is 2 which + // means that certificate must indicate that the server is the server to which + // you meant to connect, or the connection fails. See here for more details: + // https://curl.se/libcurl/c/CURLOPT_SSL_VERIFYHOST.html + long verify_host; + // File holding one or more certificates to verify the peer with. If not + // specified, client will look for the system path where cacert bundle is + // assumed to be stored, as established at build time. See here for more + // information: https://curl.se/libcurl/c/CURLOPT_CAINFO.html + std::string ca_info; + // The format of client certificate. By default it is CERT_PEM. See here for + // more details: https://curl.se/libcurl/c/CURLOPT_SSLCERTTYPE.html + CERTTYPE cert_type; + // The file name of your client certificate. See here for more details: + // https://curl.se/libcurl/c/CURLOPT_SSLCERT.html + std::string cert; + // The format of the private key. By default it is KEY_PEM. See here for more + // details: https://curl.se/libcurl/c/CURLOPT_SSLKEYTYPE.html. + KEYTYPE key_type; + // The private key. See here for more details: + // https://curl.se/libcurl/c/CURLOPT_SSLKEY.html. + std::string key; +}; + +// an HttpRequest object represents the context of a HTTP transaction. currently +// it is also designed to be the placeholder for response data, but how the +// response is stored can be revisited later. +// 'completion_callback' doesn't transfer ownership of HttpRequest, caller must +// not keep the reference and access HttpRequest object after +// 'completion_callback' returns +class HttpRequest { + public: + HttpRequest( + std::function&& completion_callback, + const bool verbose = false); + virtual ~HttpRequest(); + + // Adds the input data to be delivered to the server, note that the HTTP + // request does not own the buffer. + void AddInput(uint8_t* buf, size_t byte_size); + + // Helper function for CURL + // Copy into 'buf' up to 'size' bytes of input data. Return the + // actual amount copied in 'input_bytes'. + void GetNextInput(uint8_t* buf, size_t size, size_t* input_bytes); + + // [FIXME] define default callback like + // CURLOPT_READFUNCTION, CURLOPT_WRITEFUNCTION here? + // the specialized HttpRequest can override the callbacks when read / write + // schema has changed. + + // Buffer that accumulates the response body. + std::string response_buffer_; + + size_t total_input_byte_size_{0}; + + // HTTP response code for the inference request + long http_code_{200}; + + std::function completion_callback_{nullptr}; + + // Pointer to the list of the HTTP request header, keep it such that it will + // be valid during the transfer and can be freed once transfer is completed. + struct curl_slist* header_list_{nullptr}; + + protected: + const bool verbose_{false}; + + // The pointers to the input data. + std::deque> data_buffers_; +}; + +// Base class for common HTTP functionalities +class HttpClient { + public: + enum class CompressionType { NONE, DEFLATE, GZIP }; + + virtual ~HttpClient(); + + protected: + void SetSSLCurlOptions(CURL* curl_handle); + + HttpClient( + const std::string& server_url, bool verbose = false, + const HttpSslOptions& ssl_options = HttpSslOptions()); + + // Note that this function does not block + void Send(CURL* handle, std::unique_ptr&& request); + + // [FIXME] provide more helper functions to encapsulate CURL detail + + protected: + void AsyncTransfer(); + + bool exiting_{false}; + + std::thread worker_; + std::mutex mutex_; + std::condition_variable cv_; + + // The server url + const std::string url_; + // The options for authorizing and authenticating SSL/TLS connections + HttpSslOptions ssl_options_; + + using AsyncReqMap = std::map>; + // curl multi handle for processing asynchronous requests + void* multi_handle_; + // map to record ongoing asynchronous requests with pointer to easy handle + // or tag id as key + AsyncReqMap ongoing_async_requests_; + + bool verbose_; + + private: + // [FIXME] should belong to SSL option struct as helper function + const std::string& ParseSslKeyType(HttpSslOptions::KEYTYPE key_type); + const std::string& ParseSslCertType(HttpSslOptions::CERTTYPE cert_type); +}; +}}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_client.cc b/src/c++/perf_analyzer/client_backend/openai/openai_client.cc new file mode 100644 index 000000000..0b7c85c00 --- /dev/null +++ b/src/c++/perf_analyzer/client_backend/openai/openai_client.cc @@ -0,0 +1,298 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Include this first to make sure we are a friend of common classes. +#define TRITON_INFERENCE_SERVER_CLIENT_CLASS InferenceServerHttpClient +#include "openai_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" + +#ifdef TRITON_ENABLE_ZLIB +#include +#endif + +extern "C" { +#include "cencode.h" +} + +#ifdef _WIN32 +#define strncasecmp(x, y, z) _strnicmp(x, y, z) +#undef min // NOMINMAX did not resolve std::min compile error +#endif //_WIN32 + +namespace triton { namespace perfanalyzer { namespace clientbackend { +namespace openai { + +//============================================================================== + +void +ChatCompletionRequest::SendResponse(bool is_final, bool is_null) +{ + response_callback_(new ChatCompletionResult( + http_code_, std::move(response_buffer_), is_final, is_null, request_id_)); +} + +ChatCompletionClient::ChatCompletionClient( + const std::string& url, bool verbose, const HttpSslOptions& ssl_options) + : HttpClient(url, verbose, ssl_options) +{ +} + +size_t +ChatCompletionClient::RequestProvider( + void* contents, size_t size, size_t nmemb, void* userp) +{ + auto request = reinterpret_cast(userp); + + size_t input_bytes = 0; + request->GetNextInput( + reinterpret_cast(contents), size * nmemb, &input_bytes); + + if (input_bytes == 0) { + request->timer_.CaptureTimestamp( + triton::client::RequestTimers::Kind::SEND_END); + } + + return input_bytes; +} + +size_t +ChatCompletionClient::ResponseHeaderHandler( + void* contents, size_t size, size_t nmemb, void* userp) +{ + auto request = reinterpret_cast(userp); + + char* buf = reinterpret_cast(contents); + size_t byte_size = size * nmemb; + + std::string hdr(buf, byte_size); + std::transform(hdr.begin(), hdr.end(), hdr.begin(), [](unsigned char c) { + return std::tolower(c); + }); + if (hdr.find("content-type") != std::string::npos) { + request->is_stream_ = (hdr.find("text/event-stream") != std::string::npos); + } + + return byte_size; +} + +size_t +ChatCompletionClient::ResponseHandler( + void* contents, size_t size, size_t nmemb, void* userp) +{ + // [WIP] verify if the SSE responses received are complete, or the response + // need to be stitched first + auto request = reinterpret_cast(userp); + if (request->timer_.Timestamp( + triton::client::RequestTimers::Kind::RECV_START) == 0) { + request->timer_.CaptureTimestamp( + triton::client::RequestTimers::Kind::RECV_START); + } + + char* buf = reinterpret_cast(contents); + size_t result_bytes = size * nmemb; + request->response_buffer_.append(buf, result_bytes); + // Send response now if streaming, otherwise wait until request has been + // completed + if (request->is_stream_) { + // [FIXME] assume it is proper chunked of response + auto done_signal = + (request->response_buffer_.find("data: [DONE]") != std::string::npos); + request->SendResponse( + done_signal /* is_final */, done_signal /* is_null */); + } + + // ResponseHandler may be called multiple times so we overwrite + // RECV_END so that we always have the time of the last. + request->timer_.CaptureTimestamp( + triton::client::RequestTimers::Kind::RECV_END); + + return result_bytes; +} + + +Error +ChatCompletionClient::AsyncInfer( + std::function callback, + std::string& serialized_request_body, + const std::string& request_id) +{ + if (callback == nullptr) { + return Error( + "Callback function must be provided along with AsyncInfer() call."); + } + + auto completion_callback = [this](HttpRequest* req) { + auto request = static_cast(req); + if (!request->is_stream_) { + request->SendResponse(true /* is_final */, false /* is_null */); + } + request->timer_.CaptureTimestamp( + triton::client::RequestTimers::Kind::REQUEST_END); + UpdateInferStat(request->timer_); + }; + std::unique_ptr request(new ChatCompletionRequest( + std::move(completion_callback), std::move(callback), request_id, verbose_)); + auto raw_request = static_cast(request.get()); + raw_request->timer_.CaptureTimestamp( + triton::client::RequestTimers::Kind::REQUEST_START); + request->AddInput( + reinterpret_cast(serialized_request_body.data()), + serialized_request_body.size()); + std::string request_uri(url_ + "/v1/chat/completions"); + + CURL* multi_easy_handle = curl_easy_init(); + Error err = PreRunProcessing(multi_easy_handle, request_uri, raw_request); + if (!err.IsOk()) { + curl_easy_cleanup(multi_easy_handle); + return err; + } + + raw_request->timer_.CaptureTimestamp( + triton::client::RequestTimers::Kind::SEND_START); + Send(multi_easy_handle, std::move(request)); + return Error::Success; +} + +Error +ChatCompletionClient::PreRunProcessing( + CURL* curl, std::string& request_uri, ChatCompletionRequest* request) +{ + curl_easy_setopt(curl, CURLOPT_URL, request_uri.c_str()); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "libcurl-agent/1.0"); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_TCP_NODELAY, 1L); + + if (verbose_) { + curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L); + } + + const long buffer_byte_size = 16 * 1024 * 1024; + curl_easy_setopt(curl, CURLOPT_UPLOAD_BUFFERSIZE, buffer_byte_size); + curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, buffer_byte_size); + + // request data provided by RequestProvider() + curl_easy_setopt(curl, CURLOPT_READFUNCTION, RequestProvider); + curl_easy_setopt(curl, CURLOPT_READDATA, request); + + // response headers handled by ResponseHeaderHandler() + curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, ResponseHeaderHandler); + curl_easy_setopt(curl, CURLOPT_HEADERDATA, request); + + // response data handled by ResponseHandler() + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, ResponseHandler); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, request); + + const curl_off_t post_byte_size = request->total_input_byte_size_; + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE_LARGE, post_byte_size); + + SetSSLCurlOptions(curl); + + struct curl_slist* list = nullptr; + list = curl_slist_append(list, "Expect:"); + list = curl_slist_append(list, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, list); + + // The list will be freed when the request is destructed + request->header_list_ = list; + + return Error::Success; +} + +Error +ChatCompletionClient::UpdateInferStat( + const triton::client::RequestTimers& timer) +{ + const uint64_t request_time_ns = timer.Duration( + triton::client::RequestTimers::Kind::REQUEST_START, + triton::client::RequestTimers::Kind::REQUEST_END); + const uint64_t send_time_ns = timer.Duration( + triton::client::RequestTimers::Kind::SEND_START, + triton::client::RequestTimers::Kind::SEND_END); + const uint64_t recv_time_ns = timer.Duration( + triton::client::RequestTimers::Kind::RECV_START, + triton::client::RequestTimers::Kind::RECV_END); + + if ((request_time_ns == std::numeric_limits::max()) || + (send_time_ns == std::numeric_limits::max()) || + (recv_time_ns == std::numeric_limits::max())) { + return Error( + "Timer not set correctly." + + ((timer.Timestamp(triton::client::RequestTimers::Kind::REQUEST_START) > + timer.Timestamp(triton::client::RequestTimers::Kind::REQUEST_END)) + ? (" Request time from " + + std::to_string(timer.Timestamp( + triton::client::RequestTimers::Kind::REQUEST_START)) + + " to " + + std::to_string(timer.Timestamp( + triton::client::RequestTimers::Kind::REQUEST_END)) + + ".") + : "") + + ((timer.Timestamp(triton::client::RequestTimers::Kind::SEND_START) > + timer.Timestamp(triton::client::RequestTimers::Kind::SEND_END)) + ? (" Send time from " + + std::to_string(timer.Timestamp( + triton::client::RequestTimers::Kind::SEND_START)) + + " to " + + std::to_string(timer.Timestamp( + triton::client::RequestTimers::Kind::SEND_END)) + + ".") + : "") + + ((timer.Timestamp(triton::client::RequestTimers::Kind::RECV_START) > + timer.Timestamp(triton::client::RequestTimers::Kind::RECV_END)) + ? (" Receive time from " + + std::to_string(timer.Timestamp( + triton::client::RequestTimers::Kind::RECV_START)) + + " to " + + std::to_string(timer.Timestamp( + triton::client::RequestTimers::Kind::RECV_END)) + + ".") + : "")); + } + + infer_stat_.completed_request_count++; + infer_stat_.cumulative_total_request_time_ns += request_time_ns; + infer_stat_.cumulative_send_time_ns += send_time_ns; + infer_stat_.cumulative_receive_time_ns += recv_time_ns; + + return Error::Success; +} + +//============================================================================== + +}}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_client.h b/src/c++/perf_analyzer/client_backend/openai/openai_client.h new file mode 100644 index 000000000..38d0f8f04 --- /dev/null +++ b/src/c++/perf_analyzer/client_backend/openai/openai_client.h @@ -0,0 +1,186 @@ +// Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +/// \file + +#include +#include + +#include "../client_backend.h" +#include "common.h" +#include "http_client.h" + + +namespace triton { namespace perfanalyzer { namespace clientbackend { +namespace openai { + +class ChatCompletionResult : public InferResult { + public: + ChatCompletionResult( + uint32_t http_code, std::string&& serialized_response, bool is_final, + bool is_null, const std::string& request_id) + : http_code_(http_code), + serialized_response_(std::move(serialized_response)), + is_final_(is_final), is_null_(is_null), request_id_(request_id) + { + } + virtual ~ChatCompletionResult() = default; + + /// Get the id of the request which generated this response. + /// \param id Returns the request id that generated the result. + /// \return Error object indicating success or failure. + Error Id(std::string* id) const override + { + *id = request_id_; + return Error::Success; + } + + + /// Returns the status of the request. + /// \return Error object indicating the success or failure of the + /// request. + Error RequestStatus() const override + { + if ((http_code_ >= 400) && (http_code_ <= 599)) { + return Error( + "OpenAI response returns HTTP code" + std::to_string(http_code_)); + } + return Error::Success; + } + + /// Returns the raw data of the output. + /// \return Error object indicating the success or failure of the + /// request. + Error RawData( + const std::string& output_name, const uint8_t** buf, + size_t* byte_size) const override + { + // [FIXME] disregard "output_name" which is not compatible to + // OpenAI protocol + *buf = reinterpret_cast(serialized_response_.c_str()); + *byte_size = serialized_response_.size(); + return Error::Success; + } + + /// Get final response bool for this response. + /// \return Error object indicating the success or failure. + Error IsFinalResponse(bool* is_final_response) const override + { + *is_final_response = is_final_; + return Error::Success; + }; + + /// Get null response bool for this response. + /// \return Error object indicating the success or failure. + Error IsNullResponse(bool* is_null_response) const override + { + *is_null_response = is_null_; + return Error::Success; + }; + + private: + const uint32_t http_code_{200}; + const std::string serialized_response_; + const bool is_final_{false}; + const bool is_null_{false}; + const std::string request_id_; +}; + + +class ChatCompletionRequest : public HttpRequest { + public: + virtual ~ChatCompletionRequest() {} + ChatCompletionRequest( + std::function&& completion_callback, + std::function&& response_callback, + const std::string& request_id, + const bool verbose = false) + : HttpRequest(std::move(completion_callback), verbose), + response_callback_(std::move(response_callback)), + request_id_(request_id) + { + } + void SendResponse(bool is_final, bool is_null); + bool is_stream_{false}; + std::function response_callback_{nullptr}; + // The timers for infer request. + triton::client::RequestTimers timer_; + const std::string request_id_; +}; + +class ChatCompletionClient : public HttpClient { + public: + virtual ~ChatCompletionClient() = default; + + /// Create a client that can be used to communicate with the server. + /// \param client Returns a new InferenceServerHttpClient object. + /// \param server_url The inference server name, port, optional + /// scheme and optional base path in the following format: + /// host:port/. + /// \param verbose If true generate verbose output when contacting + /// the inference server. + /// \param ssl_options Specifies the settings for configuring + /// SSL encryption and authorization. Providing these options + /// do not ensure that SSL/TLS will be used in communication. + /// The use of SSL/TLS depends entirely on the server endpoint. + /// These options will be ignored if the server_url does not + /// expose `https://` scheme. + /// \return Error object indicating success or failure. + ChatCompletionClient( + const std::string& server_url, bool verbose = false, + const HttpSslOptions& ssl_options = HttpSslOptions()); + + /// Simplified AsyncInfer() where the request body is expected to be + /// prepared by the caller, the client here is responsible to communicate + /// with a OpenAI-compatible server in both streaming and non-streaming case. + Error AsyncInfer( + std::function callback, + std::string& serialized_request_body, + const std::string& request_id); + + const InferStat& ClientInferStat() { return infer_stat_; } + + /// [TODO?] Add AsyncInfer() variant that prepare the request body from + /// function arguments. Similar to Triton client library. + + private: + // setup curl handle + Error PreRunProcessing( + CURL* curl, std::string& request_uri, ChatCompletionRequest* request); + + static size_t ResponseHandler( + void* contents, size_t size, size_t nmemb, void* userp); + static size_t RequestProvider( + void* contents, size_t size, size_t nmemb, void* userp); + static size_t ResponseHeaderHandler( + void* contents, size_t size, size_t nmemb, void* userp); + + Error UpdateInferStat(const triton::client::RequestTimers& timer); + InferStat infer_stat_; +}; + +}}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.cc b/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.cc index d017b8b23..968973d42 100644 --- a/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.cc +++ b/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.cc @@ -26,6 +26,8 @@ #include "openai_client_backend.h" +#include "openai_infer_input.h" + namespace triton { namespace perfanalyzer { namespace clientbackend { namespace openai { @@ -44,8 +46,8 @@ OpenAiClientBackend::Create( std::unique_ptr openai_client_backend( new OpenAiClientBackend(http_headers)); - RETURN_IF_CB_ERROR( - HttpClient::Create(&(openai_client_backend->http_client_), url, verbose)); + openai_client_backend->http_client_.reset( + new ChatCompletionClient(url, verbose)); *client_backend = std::move(openai_client_backend); @@ -58,14 +60,14 @@ OpenAiClientBackend::AsyncInfer( const std::vector& inputs, const std::vector& outputs) { - auto wrapped_callback = [callback](cb::openai::InferResult* client_result) { - cb::InferResult* result = new OpenAiInferResult(client_result); - callback(result); - }; - - RETURN_IF_CB_ERROR(http_client_->AsyncInfer( - wrapped_callback, options, inputs, outputs, *http_headers_)); + if (inputs.size() != 1) { + return Error("Only expecting one input"); + } + auto raw_input = dynamic_cast(inputs[0]); + raw_input->PrepareForRequest(); + RETURN_IF_CB_ERROR( + http_client_->AsyncInfer(callback, raw_input->DataString(), options.request_id_)); return Error::Success; } @@ -73,25 +75,10 @@ OpenAiClientBackend::AsyncInfer( Error OpenAiClientBackend::ClientInferStat(InferStat* infer_stat) { - // Reusing the common library utilities to collect and report the - // client side statistics. - tc::InferStat client_infer_stat; - - RETURN_IF_TRITON_ERROR(http_client_->ClientInferStat(&client_infer_stat)); - - ParseInferStat(client_infer_stat, infer_stat); - + *infer_stat = http_client_->ClientInferStat(); return Error::Success; } -void -OpenAiClientBackend::ParseInferStat( - const tc::InferStat& tfserve_infer_stat, InferStat* infer_stat) -{ - // TODO: Implement - return; -} - //============================================================================== Error @@ -118,35 +105,5 @@ OpenAiInferRequestedOutput::OpenAiInferRequestedOutput(const std::string& name) //============================================================================== -OpenAiInferResult::OpenAiInferResult(cb::openai::InferResult* result) -{ - result_.reset(result); -} - -Error -OpenAiInferResult::Id(std::string* id) const -{ - id->clear(); - return Error::Success; -} - -Error -OpenAiInferResult::RequestStatus() const -{ - RETURN_IF_CB_ERROR(result_->RequestStatus()); - return Error::Success; -} - -Error -OpenAiInferResult::RawData( - const std::string& output_name, const uint8_t** buf, - size_t* byte_size) const -{ - return Error( - "Output retrieval is not currently supported for OpenAi client backend"); -} - -//============================================================================== - }}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.h b/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.h index c6c83222f..ea9a49a82 100644 --- a/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.h +++ b/src/c++/perf_analyzer/client_backend/openai/openai_client_backend.h @@ -29,7 +29,8 @@ #include "../../perf_utils.h" #include "../client_backend.h" -#include "openai_http_client.h" +#include "openai_client.h" +#include "openai_infer_input.h" #define RETURN_IF_TRITON_ERROR(S) \ do { \ @@ -85,7 +86,7 @@ class OpenAiClientBackend : public ClientBackend { void ParseInferStat( const tc::InferStat& openai_infer_stat, InferStat* infer_stat); - std::unique_ptr http_client_; + std::unique_ptr http_client_; std::shared_ptr http_headers_; }; @@ -107,24 +108,4 @@ class OpenAiInferRequestedOutput : public InferRequestedOutput { std::unique_ptr output_; }; -//============================================================== -/// OpenAiInferResult is a wrapper around InferResult object of -/// OpenAi InferResult object. -/// -class OpenAiInferResult : public cb::InferResult { - public: - explicit OpenAiInferResult(cb::openai::InferResult* result); - /// See InferResult::Id() - Error Id(std::string* id) const override; - /// See InferResult::RequestStatus() - Error RequestStatus() const override; - /// See InferResult::RawData() - Error RawData( - const std::string& output_name, const uint8_t** buf, - size_t* byte_size) const override; - - private: - std::unique_ptr result_; -}; - }}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_http_client.cc b/src/c++/perf_analyzer/client_backend/openai/openai_http_client.cc deleted file mode 100644 index 151eca2a6..000000000 --- a/src/c++/perf_analyzer/client_backend/openai/openai_http_client.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// * Neither the name of NVIDIA CORPORATION nor the names of its -// contributors may be used to endorse or promote products derived -// from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -#include "openai_http_client.h" - -#include - - -namespace triton { namespace perfanalyzer { namespace clientbackend { -namespace openai { - - -Error -HttpClient::Create( - std::unique_ptr* client, const std::string& server_url, - bool verbose) -{ - client->reset(new HttpClient(server_url, verbose)); - return Error::Success; -} - -Error -HttpClient::AsyncInfer( - OpenAiOnCompleteFn callback, const InferOptions& options, - const std::vector& inputs, - const std::vector& outputs, - const Headers& headers) -{ - // TODO FIXME implement - - // TODO FIXME cleanup or remove this. It just proves the json data arrives - rapidjson::Document d{}; - - if (inputs.size() != 1) { - return Error("Only expecting one input"); - } - - auto raw_input = dynamic_cast(inputs[0]); - - raw_input->PrepareForRequest(); - bool end_of_input = false; - const uint8_t* buf; - size_t buf_size; - raw_input->GetNext(&buf, &buf_size, &end_of_input); - if (!end_of_input) { - return Error("Unexpected multiple json data inputs"); - } - if (buf == nullptr) { - return Error("Unexpected null json data"); - } - - std::string json_str(reinterpret_cast(buf), buf_size); - std::cout << "FIXME TODO: JSON data string is " << json_str << std::endl; - - - if (d.Parse(json_str.c_str()).HasParseError()) { - return Error("Unable to parse json string: " + json_str); - } - - // FIXME TKG -- where/how would the 'streaming' option get plugged in? - - // FIXME TKG -- GOOD GOD! Is it this hard to add a single value into a json - // object?? - // FIXME TKG -- what if the user supplied this in the input json file? - d.AddMember( - "model", - rapidjson::Value().SetString( - options.model_name_.c_str(), - static_cast(options.model_name_.length()), - d.GetAllocator()), - d.GetAllocator()); - - for (auto itr = d.MemberBegin(); itr != d.MemberEnd(); ++itr) { - std::cout << "FIXME TODO: valid JSON object has key " - << itr->name.GetString() << std::endl; - } - - return Error::Success; -} - -HttpClient::HttpClient(const std::string& url, bool verbose) - : InferenceServerClient(verbose), url_(url) -// ,easy_handle_(reinterpret_cast(curl_easy_init()) // TODO FIXME TKG -{ -} - -HttpClient::~HttpClient() -{ - exiting_ = true; - - // FIXME TODO TKG - // if (easy_handle_ != nullptr) { - // curl_easy_cleanup(reinterpret_cast(easy_handle_)); - //} -} - -}}}} // namespace triton::perfanalyzer::clientbackend::openai \ No newline at end of file diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_http_client.h b/src/c++/perf_analyzer/client_backend/openai/openai_http_client.h deleted file mode 100644 index bbdaddfe9..000000000 --- a/src/c++/perf_analyzer/client_backend/openai/openai_http_client.h +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// * Neither the name of NVIDIA CORPORATION nor the names of its -// contributors may be used to endorse or promote products derived -// from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#pragma once - -#include "../client_backend.h" -#include "common.h" -#include "openai_infer_input.h" - - -namespace tc = triton::client; - -namespace triton { namespace perfanalyzer { namespace clientbackend { -namespace openai { - -class InferResult; -class HttpInferRequest; - -using OpenAiOnCompleteFn = std::function; - -//============================================================================== -/// An HttpClient object is used to perform any kind of communication with the -/// OpenAi service using -/// -/// \code -/// std::unique_ptr client; -/// HttpClient::Create(&client, "localhost:8080"); -/// ... -/// ... -/// \endcode -/// -class HttpClient : public tc::InferenceServerClient { - public: - ~HttpClient(); - - /// TODO: Adjust as needed - /// Create a client that can be used to communicate with the server. - /// \param client Returns a new InferenceServerHttpClient object. - /// \param server_url The inference server name and port. - /// \param verbose If true generate verbose output when contacting - /// the inference server. - /// \return Error object indicating success or failure. - static Error Create( - std::unique_ptr* client, const std::string& server_url, - const bool verbose); - - /// TODO FIXME: Update - /// Run asynchronous inference on server. - Error AsyncInfer( - OpenAiOnCompleteFn callback, const InferOptions& options, - const std::vector& inputs, - const std::vector& outputs = - std::vector(), - const Headers& headers = Headers()); - - private: - HttpClient(const std::string& url, bool verbose); - - // The server url - const std::string url_; -}; - -//====================================================================== - -class InferResult { - public: - static Error Create( - InferResult** infer_result, - std::shared_ptr infer_request); - Error RequestStatus() const { return Error::Success; } // TODO FIXME TKG - Error Id(std::string* id) const { return Error::Success; } // TODO FIXME TKG - - private: - InferResult(std::shared_ptr infer_request); - - // The status of the inference - Error status_; - // The pointer to the HttpInferRequest object - std::shared_ptr infer_request_; -}; - -//====================================================================== - -}}}} // namespace triton::perfanalyzer::clientbackend::openai diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.cc b/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.cc index 70d827e85..834e27788 100644 --- a/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.cc +++ b/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.cc @@ -51,9 +51,10 @@ OpenAiInferInput::SetShape(const std::vector& shape) Error OpenAiInferInput::Reset() { + data_str_.clear(); + bufs_.clear(); buf_byte_sizes_.clear(); - bufs_idx_ = 0; byte_size_ = 0; return Error::Success; } @@ -61,18 +62,12 @@ OpenAiInferInput::Reset() Error OpenAiInferInput::AppendRaw(const uint8_t* input, size_t input_byte_size) { + data_str_.clear(); + byte_size_ += input_byte_size; bufs_.push_back(input); buf_byte_sizes_.push_back(input_byte_size); - - return Error::Success; -} - -Error -OpenAiInferInput::ByteSize(size_t* byte_size) const -{ - *byte_size = byte_size_; return Error::Success; } @@ -80,32 +75,19 @@ Error OpenAiInferInput::PrepareForRequest() { // Reset position so request sends entire input. - bufs_idx_ = 0; - buf_pos_ = 0; - return Error::Success; -} - -Error -OpenAiInferInput::GetNext( - const uint8_t** buf, size_t* input_bytes, bool* end_of_input) -{ - if (bufs_idx_ < bufs_.size()) { - *buf = bufs_[bufs_idx_]; - *input_bytes = buf_byte_sizes_[bufs_idx_]; - bufs_idx_++; - } else { - *buf = nullptr; - *input_bytes = 0; + if (data_str_.empty() && (byte_size_ != 0)) { + for (size_t i = 0; i < bufs_.size(); ++i) { + data_str_.append( + reinterpret_cast(bufs_[i]), buf_byte_sizes_[i]); + } } - *end_of_input = (bufs_idx_ >= bufs_.size()); - return Error::Success; } OpenAiInferInput::OpenAiInferInput( const std::string& name, const std::vector& dims, const std::string& datatype) - : InferInput(BackendKind::TENSORFLOW_SERVING, name, datatype), shape_(dims) + : InferInput(BackendKind::OPENAI, name, datatype), shape_(dims) { } diff --git a/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.h b/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.h index a10b9312f..9ccf0945c 100644 --- a/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.h +++ b/src/c++/perf_analyzer/client_backend/openai/openai_infer_input.h @@ -51,14 +51,10 @@ class OpenAiInferInput : public InferInput { Error Reset() override; /// See InferInput::AppendRaw() Error AppendRaw(const uint8_t* input, size_t input_byte_size) override; - /// Gets the size of data added into this input in bytes. - /// \param byte_size The size of data added in bytes. - /// \return Error object indicating success or failure. - Error ByteSize(size_t* byte_size) const; /// Resets the heads to start providing data from the beginning. Error PrepareForRequest(); /// Get the next chunk of data if available. - Error GetNext(const uint8_t** buf, size_t* input_bytes, bool* end_of_input); + std::string& DataString() { return data_str_; } private: explicit OpenAiInferInput( @@ -68,9 +64,9 @@ class OpenAiInferInput : public InferInput { std::vector shape_; size_t byte_size_{0}; - size_t bufs_idx_, buf_pos_; std::vector bufs_; std::vector buf_byte_sizes_; + std::string data_str_; }; }}}} // namespace triton::perfanalyzer::clientbackend::openai