
Commit

Pass endpoint to openai client
tgerdesnv committed Mar 4, 2024
1 parent 867003b commit ddc3266
Showing 7 changed files with 42 additions and 33 deletions.
20 changes: 10 additions & 10 deletions src/c++/perf_analyzer/client_backend/client_backend.cc
@@ -116,8 +116,8 @@ BackendToGrpcType(const GrpcCompressionAlgorithm compression_algorithm)
 //
 Error
 ClientBackendFactory::Create(
-    const BackendKind kind, const std::string& url, const ProtocolType protocol,
-    const SslOptionsBase& ssl_options,
+    const BackendKind kind, const std::string& url, const std::string& endpoint,
+    const ProtocolType protocol, const SslOptionsBase& ssl_options,
     const std::map<std::string, std::vector<std::string>> trace_options,
     const GrpcCompressionAlgorithm compression_algorithm,
     std::shared_ptr<Headers> http_headers,
@@ -128,9 +128,10 @@ ClientBackendFactory::Create(
     std::shared_ptr<ClientBackendFactory>* factory)
 {
   factory->reset(new ClientBackendFactory(
-      kind, url, protocol, ssl_options, trace_options, compression_algorithm,
-      http_headers, triton_server_path, model_repository_path, verbose,
-      metrics_url, input_tensor_format, output_tensor_format));
+      kind, url, endpoint, protocol, ssl_options, trace_options,
+      compression_algorithm, http_headers, triton_server_path,
+      model_repository_path, verbose, metrics_url, input_tensor_format,
+      output_tensor_format));
   return Error::Success;
 }

@@ -139,7 +140,7 @@ ClientBackendFactory::CreateClientBackend(
     std::unique_ptr<ClientBackend>* client_backend)
 {
   RETURN_IF_CB_ERROR(ClientBackend::Create(
-      kind_, url_, protocol_, ssl_options_, trace_options_,
+      kind_, url_, endpoint_, protocol_, ssl_options_, trace_options_,
       compression_algorithm_, http_headers_, verbose_, triton_server_path,
       model_repository_path_, metrics_url_, input_tensor_format_,
       output_tensor_format_, client_backend));
@@ -157,8 +158,8 @@ ClientBackendFactory::Kind()
 //
 Error
 ClientBackend::Create(
-    const BackendKind kind, const std::string& url, const ProtocolType protocol,
-    const SslOptionsBase& ssl_options,
+    const BackendKind kind, const std::string& url, const std::string& endpoint,
+    const ProtocolType protocol, const SslOptionsBase& ssl_options,
     const std::map<std::string, std::vector<std::string>> trace_options,
     const GrpcCompressionAlgorithm compression_algorithm,
     std::shared_ptr<Headers> http_headers, const bool verbose,
@@ -177,10 +178,9 @@ ClientBackend::Create(
         &local_backend));
   }
 #ifdef TRITON_ENABLE_PERF_ANALYZER_OPENAI
-  // TODO -- I think this needs endpoint to be passed in?
   else if (kind == OPENAI) {
     RETURN_IF_CB_ERROR(openai::OpenAiClientBackend::Create(
-        url, protocol, http_headers, verbose, &local_backend));
+        url, endpoint, protocol, http_headers, verbose, &local_backend));
   }
 #endif  // TRITON_ENABLE_PERF_ANALYZER_OPENAI
 #ifdef TRITON_ENABLE_PERF_ANALYZER_TFS
15 changes: 10 additions & 5 deletions src/c++/perf_analyzer/client_backend/client_backend.h
@@ -268,6 +268,7 @@ class ClientBackendFactory {
   /// Create a factory that can be used to construct Client Backends.
   /// \param kind The kind of client backend to create.
   /// \param url The inference server url and port.
+  /// \param endpoint The endpoint on the inference server to send requests to
   /// \param protocol The protocol type used.
   /// \param ssl_options The SSL options used with client backend.
   /// \param compression_algorithm The compression algorithm to be used
@@ -290,7 +291,8 @@ class ClientBackendFactory {
   /// \return Error object indicating success or failure.
   static Error Create(
       const BackendKind kind, const std::string& url,
-      const ProtocolType protocol, const SslOptionsBase& ssl_options,
+      const std::string& endpoint, const ProtocolType protocol,
+      const SslOptionsBase& ssl_options,
       const std::map<std::string, std::vector<std::string>> trace_options,
       const GrpcCompressionAlgorithm compression_algorithm,
       std::shared_ptr<Headers> http_headers,
@@ -309,16 +311,17 @@
  private:
   ClientBackendFactory(
       const BackendKind kind, const std::string& url,
-      const ProtocolType protocol, const SslOptionsBase& ssl_options,
+      const std::string& endpoint, const ProtocolType protocol,
+      const SslOptionsBase& ssl_options,
       const std::map<std::string, std::vector<std::string>> trace_options,
       const GrpcCompressionAlgorithm compression_algorithm,
       const std::shared_ptr<Headers> http_headers,
       const std::string& triton_server_path,
       const std::string& model_repository_path, const bool verbose,
       const std::string& metrics_url, const TensorFormat input_tensor_format,
       const TensorFormat output_tensor_format)
-      : kind_(kind), url_(url), protocol_(protocol), ssl_options_(ssl_options),
-        trace_options_(trace_options),
+      : kind_(kind), url_(url), endpoint_(endpoint), protocol_(protocol),
+        ssl_options_(ssl_options), trace_options_(trace_options),
         compression_algorithm_(compression_algorithm),
         http_headers_(http_headers), triton_server_path(triton_server_path),
         model_repository_path_(model_repository_path), verbose_(verbose),
@@ -329,6 +332,7 @@ class ClientBackendFactory {
 
   const BackendKind kind_;
   const std::string url_;
+  const std::string endpoint_;
   const ProtocolType protocol_;
   const SslOptionsBase& ssl_options_;
   const std::map<std::string, std::vector<std::string>> trace_options_;
@@ -361,7 +365,8 @@ class ClientBackend {
  public:
   static Error Create(
       const BackendKind kind, const std::string& url,
-      const ProtocolType protocol, const SslOptionsBase& ssl_options,
+      const std::string& endpoint, const ProtocolType protocol,
+      const SslOptionsBase& ssl_options,
       const std::map<std::string, std::vector<std::string>> trace_options,
       const GrpcCompressionAlgorithm compression_algorithm,
       std::shared_ptr<Headers> http_headers, const bool verbose,
6 changes: 3 additions & 3 deletions src/c++/perf_analyzer/client_backend/openai/openai_client.cc
@@ -68,9 +68,9 @@ ChatCompletionRequest::SendResponse(bool is_final, bool is_null)
 }
 
 ChatCompletionClient::ChatCompletionClient(
-    const std::string& url, bool verbose, const HttpSslOptions& ssl_options)
-    : HttpClient(
-          std::string(url + "/v1/chat/completions"), verbose, ssl_options)
+    const std::string& url, const std::string& endpoint, bool verbose,
+    const HttpSslOptions& ssl_options)
+    : HttpClient(std::string(url + "/" + endpoint), verbose, ssl_options)
 {
 }
 
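Note: the request target is now composed from the server URL plus the caller-supplied endpoint instead of a hard-coded /v1/chat/completions path. A minimal sketch of the new constructor in use (host and endpoint values here are hypothetical, not taken from this diff):

// With url = "http://localhost:9000" and endpoint = "v1/chat/completions",
// HttpClient is initialized with "http://localhost:9000/v1/chat/completions",
// the same target the old hard-coded path produced, now configurable.
ChatCompletionClient client(
    "http://localhost:9000", "v1/chat/completions", /*verbose=*/false);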
6 changes: 4 additions & 2 deletions src/c++/perf_analyzer/client_backend/openai/openai_client.h
@@ -67,7 +67,7 @@ class ChatCompletionResult : public InferResult {
   {
     if ((http_code_ >= 400) && (http_code_ <= 599)) {
       return Error(
-          "OpenAI response returns HTTP code" + std::to_string(http_code_));
+          "OpenAI response returns HTTP code " + std::to_string(http_code_));
     }
     return Error::Success;
   }
@@ -139,6 +139,7 @@ class ChatCompletionClient : public HttpClient {
   /// \param server_url The inference server name, port, optional
   /// scheme and optional base path in the following format:
   /// <scheme://>host:port/<base-path>.
+  /// \param endpoint The name of the endpoint to send requests to
   /// \param verbose If true generate verbose output when contacting
   /// the inference server.
   /// \param ssl_options Specifies the settings for configuring
@@ -148,7 +149,8 @@ class ChatCompletionClient : public HttpClient {
   /// These options will be ignored if the server_url does not
   /// expose `https://` scheme.
   ChatCompletionClient(
-      const std::string& server_url, bool verbose = false,
+      const std::string& server_url, const std::string& endpoint,
+      bool verbose = false,
       const HttpSslOptions& ssl_options = HttpSslOptions());
 
   /// Simplified AsyncInfer() where the request body is expected to be
src/c++/perf_analyzer/client_backend/openai/openai_client_backend.cc
@@ -35,9 +35,9 @@ namespace openai {
 
 Error
 OpenAiClientBackend::Create(
-    const std::string& url, const ProtocolType protocol,
-    std::shared_ptr<Headers> http_headers, const bool verbose,
-    std::unique_ptr<ClientBackend>* client_backend)
+    const std::string& url, const std::string& endpoint,
+    const ProtocolType protocol, std::shared_ptr<Headers> http_headers,
+    const bool verbose, std::unique_ptr<ClientBackend>* client_backend)
 {
   if (protocol == ProtocolType::GRPC) {
     return Error(
@@ -47,7 +47,7 @@ OpenAiClientBackend::Create(
       new OpenAiClientBackend(http_headers));
 
   openai_client_backend->http_client_.reset(
-      new ChatCompletionClient(url, verbose));
+      new ChatCompletionClient(url, endpoint, verbose));
 
   *client_backend = std::move(openai_client_backend);
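Note: the backend remains HTTP-only; requesting gRPC still fails fast, and the endpoint is forwarded untouched into the ChatCompletionClient. A hedged usage sketch (the cb namespace alias and all argument values are assumptions, not shown in this diff):

// Assumes cb aliases the client backend namespace, as in perf_analyzer.cc.
std::shared_ptr<cb::Headers> http_headers;  // empty headers for the sketch
std::unique_ptr<cb::ClientBackend> backend;
cb::Error err = cb::openai::OpenAiClientBackend::Create(
    "http://localhost:9000", "v1/chat/completions", cb::ProtocolType::HTTP,
    http_headers, /*verbose=*/false, &backend);
// Passing cb::ProtocolType::GRPC instead would return an Error.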
src/c++/perf_analyzer/client_backend/openai/openai_client_backend.h
@@ -56,6 +56,7 @@ class OpenAiClientBackend : public ClientBackend {
   /// Create an OpenAI client backend which can be used to interact with the
   /// server.
   /// \param url The inference server url and port.
+  /// \param endpoint The endpoint on the inference server to send requests to
   /// \param protocol The protocol type used.
   /// \param http_headers Map of HTTP headers. The map key/value indicates
   /// the header name/value.
@@ -64,9 +65,9 @@ class OpenAiClientBackend : public ClientBackend {
   /// object.
   /// \return Error object indicating success or failure.
   static Error Create(
-      const std::string& url, const ProtocolType protocol,
-      std::shared_ptr<Headers> http_headers, const bool verbose,
-      std::unique_ptr<ClientBackend>* client_backend);
+      const std::string& url, const std::string& endpoint,
+      const ProtocolType protocol, std::shared_ptr<Headers> http_headers,
+      const bool verbose, std::unique_ptr<ClientBackend>* client_backend);
 
   /// See ClientBackend::AsyncInfer()
   Error AsyncInfer(
13 changes: 7 additions & 6 deletions src/c++/perf_analyzer/perf_analyzer.cc
@@ -77,12 +77,13 @@ PerfAnalyzer::CreateAnalyzerObjects()
   std::shared_ptr<cb::ClientBackendFactory> factory;
   FAIL_IF_ERR(
       cb::ClientBackendFactory::Create(
-          params_->kind, params_->url, params_->protocol, params_->ssl_options,
-          params_->trace_options, params_->compression_algorithm,
-          params_->http_headers, params_->triton_server_path,
-          params_->model_repository_path, params_->extra_verbose,
-          params_->metrics_url, params_->input_tensor_format,
-          params_->output_tensor_format, &factory),
+          params_->kind, params_->url, params_->endpoint, params_->protocol,
+          params_->ssl_options, params_->trace_options,
+          params_->compression_algorithm, params_->http_headers,
+          params_->triton_server_path, params_->model_repository_path,
+          params_->extra_verbose, params_->metrics_url,
+          params_->input_tensor_format, params_->output_tensor_format,
+          &factory),
       "failed to create client factory");
 
   FAIL_IF_ERR(
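Note: with this final hop the endpoint travels from the parsed parameters all the way down to the HTTP client. Tracing only the calls changed in this commit:

// Call chain after this commit, with the endpoint threaded at every hop:
//   ClientBackendFactory::Create(kind, url, endpoint, ...)
//     -> ClientBackendFactory::CreateClientBackend()
//        -> ClientBackend::Create(kind_, url_, endpoint_, ...)
//           -> openai::OpenAiClientBackend::Create(url, endpoint, ...)
//              -> ChatCompletionClient(url, endpoint, ...)
//                 -> HttpClient(url + "/" + endpoint, ...)
// How params_->endpoint itself is populated is outside this diff.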
