Skip to content

Commit

Permalink
add request parameters (#405)
Browse files Browse the repository at this point in the history
* Initial parameter passing support

* Fix parameter ordering

* Remove commented code

* Remove unnecessary type in request parameter

* Fix includes and map assignment

* Update grpc request parameters to use only strings in PA
  • Loading branch information
debermudez authored Sep 29, 2023
1 parent ef09eba commit 1c89222
Show file tree
Hide file tree
Showing 30 changed files with 210 additions and 123 deletions.
9 changes: 9 additions & 0 deletions src/c++/library/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#ifdef TRITON_INFERENCE_SERVER_CLIENT_CLASS
Expand Down Expand Up @@ -153,6 +154,12 @@ class InferenceServerClient {
InferStat infer_stat_;
};

struct RequestParameter {
std::string name;
std::string value;
std::string type;
};

//==============================================================================
/// Structure to hold options for Inference Request.
///
Expand Down Expand Up @@ -221,6 +228,8 @@ struct InferOptions {
uint64_t client_timeout_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;
/// Additional parameters to pass to the model
std::unordered_map<std::string, RequestParameter> request_parameters;
};

//==============================================================================
Expand Down
18 changes: 18 additions & 0 deletions src/c++/library/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>

#include "common.h"

Expand Down Expand Up @@ -1408,6 +1409,23 @@ InferenceServerGrpcClient::PreRunProcessing(
options.server_timeout_);
}


for (auto& param : options.request_parameters) {
if (param.second.type == "string") {
(*infer_request_.mutable_parameters())[param.first].set_string_param(
param.second.value);
} else if (param.second.type == "int") {
(*infer_request_.mutable_parameters())[param.first].set_int64_param(
std::stoi(param.second.value));
} else if (param.second.type == "bool") {
bool val = false;
if (param.second.value == "true") {
val = true;
}
(*infer_request_.mutable_parameters())[param.first].set_bool_param(val);
}
}

int index = 0;
infer_request_.mutable_raw_input_contents()->Clear();
for (const auto input : inputs) {
Expand Down
12 changes: 12 additions & 0 deletions src/c++/perf_analyzer/client_backend/client_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ struct ModelStatistics {
uint64_t cache_miss_time_ns_;
};

///
/// Structure to hold Request parameter data for Inference Request.
///
struct RequestParameter {
std::string name;
std::string value;
std::string type;
};

//==============================================================================
/// Structure to hold options for Inference Request.
///
Expand Down Expand Up @@ -230,6 +239,9 @@ struct InferOptions {
bool sequence_end_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;

/// Additional parameters to pass to the model
std::unordered_map<std::string, RequestParameter> request_parameters_;
};

struct SslOptionsBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,14 @@ TritonClientBackend::ParseInferOptionsToTriton(
}
triton_options->triton_enable_empty_final_response_ =
options.triton_enable_empty_final_response_;

for (auto& map_entry : options.request_parameters_) {
auto rp = tc::RequestParameter();
rp.name = map_entry.second.name;
rp.value = map_entry.second.value;
rp.type = map_entry.second.type;
triton_options->request_parameters[map_entry.first] = rp;
}
}


Expand Down
19 changes: 4 additions & 15 deletions src/c++/perf_analyzer/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1601,21 +1601,10 @@ CLParser::ParseCommandLine(int argc, char** argv)
std::string value{values[1]};
std::string type{values[2]};

RequestParameter param;
if (type == "bool") {
param.type = RequestParameterType::BOOL;
param.bool_value = value == "true" ? true : false;
} else if (type == "int") {
param.type = RequestParameterType::INT;
param.int_value = std::stoll(value);
} else if (type == "string") {
param.type = RequestParameterType::STRING;
param.str_value = value;
} else {
Usage(
"Failed to parse --request-parameter. Unsupported type: '" +
type + "'.");
}
cb::RequestParameter param;
param.name = name;
param.value = value;
param.type = type;
params_->request_parameters[name] = param;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion src/c++/perf_analyzer/command_line_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct PerfAnalyzerParameters {
uint64_t measurement_window_ms = 5000;
bool using_concurrency_range = false;
Range<uint64_t> concurrency_range{1, 1, 1};
std::unordered_map<std::string, RequestParameter> request_parameters;
std::unordered_map<std::string, cb::RequestParameter> request_parameters;
uint64_t latency_threshold_ms = NO_LIMIT;
double stability_threshold = 0.1;
size_t max_trials = 10;
Expand Down
13 changes: 9 additions & 4 deletions src/c++/perf_analyzer/concurrency_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ ConcurrencyManager::Create(
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager)
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
{
std::unique_ptr<ConcurrencyManager> local_manager(new ConcurrencyManager(
async, streaming, batch_size, max_threads, max_concurrency,
shared_memory_type, output_shm_size, parser, factory));
shared_memory_type, output_shm_size, parser, factory,
request_parameters));

*manager = std::move(local_manager);

Expand All @@ -60,10 +63,12 @@ ConcurrencyManager::ConcurrencyManager(
const size_t max_threads, const size_t max_concurrency,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory)
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
: LoadManager(
async, streaming, batch_size, max_threads, shared_memory_type,
output_shm_size, parser, factory),
output_shm_size, parser, factory, request_parameters),
execute_(true), max_concurrency_(max_concurrency)
{
threads_config_.reserve(max_threads);
Expand Down
9 changes: 7 additions & 2 deletions src/c++/perf_analyzer/concurrency_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ class ConcurrencyManager : public LoadManager {
/// \param factory The ClientBackendFactory object used to create
/// client to the server.
/// \param manager Returns a new ConcurrencyManager object.
/// \param request_parameters Custom request parameters to send to the server
/// \return cb::Error object indicating success or failure.
static cb::Error Create(
const bool async, const bool streaming, const int32_t batch_size,
const size_t max_threads, const size_t max_concurrency,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager);
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters);

/// Adjusts the number of concurrent requests to be the same as
/// 'concurrent_request_count' (by creating or pausing threads)
Expand All @@ -100,7 +103,9 @@ class ConcurrencyManager : public LoadManager {
const size_t max_threads, const size_t max_concurrency,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory);
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters);

// The number of worker threads with non-zero concurrencies
size_t active_threads_;
Expand Down
13 changes: 9 additions & 4 deletions src/c++/perf_analyzer/custom_load_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ CustomLoadManager::Create(
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager)
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
{
std::unique_ptr<CustomLoadManager> local_manager(new CustomLoadManager(
async, streaming, request_intervals_file, batch_size,
measurement_window_ms, max_trials, max_threads, num_of_sequences,
shared_memory_type, output_shm_size, serial_sequences, parser, factory));
shared_memory_type, output_shm_size, serial_sequences, parser, factory,
request_parameters));

*manager = std::move(local_manager);

Expand All @@ -60,12 +63,14 @@ CustomLoadManager::CustomLoadManager(
const size_t max_threads, const uint32_t num_of_sequences,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory)
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
: RequestRateManager(
async, streaming, Distribution::CUSTOM, batch_size,
measurement_window_ms, max_trials, max_threads, num_of_sequences,
shared_memory_type, output_shm_size, serial_sequences, parser,
factory),
factory, request_parameters),
request_intervals_file_(request_intervals_file)
{
}
Expand Down
9 changes: 7 additions & 2 deletions src/c++/perf_analyzer/custom_load_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class CustomLoadManager : public RequestRateManager {
/// \param factory The ClientBackendFactory object used to create
/// client to the server.
/// \param manager Returns a new ConcurrencyManager object.
/// \param request_parameters Custom request parameters to send to the server
/// \return cb::Error object indicating success or failure.
static cb::Error Create(
const bool async, const bool streaming,
Expand All @@ -81,7 +82,9 @@ class CustomLoadManager : public RequestRateManager {
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager);
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameter);

/// Initializes the load manager with the provided file containing request
/// intervals
Expand All @@ -103,7 +106,9 @@ class CustomLoadManager : public RequestRateManager {
const size_t max_threads, const uint32_t num_of_sequences,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory);
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters);

cb::Error GenerateSchedule();

Expand Down
3 changes: 3 additions & 0 deletions src/c++/perf_analyzer/infer_data_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ InferDataManager::InitInferDataInput(
infer_input->AppendRaw(input_data.data_ptr, input_data.batch1_size));
}
}

AddInferDataParameters(infer_data);

return cb::Error::Success;
}

Expand Down
5 changes: 4 additions & 1 deletion src/c++/perf_analyzer/infer_data_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ class InferDataManager : public InferDataManagerBase {
public:
InferDataManager(
const size_t max_threads, const int32_t batch_size,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::shared_ptr<DataLoader>& data_loader)
: max_threads_(max_threads),
InferDataManagerBase(batch_size, parser, factory, data_loader)
InferDataManagerBase(
batch_size, request_parameters, parser, factory, data_loader)
{
}

Expand Down
5 changes: 5 additions & 0 deletions src/c++/perf_analyzer/infer_data_manager_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,10 @@ InferDataManagerBase::CreateInferInput(
return cb::InferInput::Create(infer_input, kind, name, dims, datatype);
}

void
InferDataManagerBase::AddInferDataParameters(InferData& infer_data)
{
infer_data.options_->request_parameters_ = request_parameters_;
}

}} // namespace triton::perfanalyzer
13 changes: 10 additions & 3 deletions src/c++/perf_analyzer/infer_data_manager_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ namespace triton { namespace perfanalyzer {
class InferDataManagerBase : public IInferDataManager {
public:
InferDataManagerBase(
const int32_t batch_size, const std::shared_ptr<ModelParser>& parser,
const int32_t batch_size,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::shared_ptr<DataLoader>& data_loader)
: batch_size_(batch_size), parser_(parser), factory_(factory),
data_loader_(data_loader), backend_kind_(factory->Kind())
: batch_size_(batch_size), request_parameters_(request_parameters),
parser_(parser), factory_(factory), data_loader_(data_loader),
backend_kind_(factory->Kind())
{
}

Expand All @@ -72,6 +76,7 @@ class InferDataManagerBase : public IInferDataManager {
std::shared_ptr<DataLoader> data_loader_;
std::unique_ptr<cb::ClientBackend> backend_;
cb::BackendKind backend_kind_;
std::unordered_map<std::string, cb::RequestParameter> request_parameters_;

/// Gets the input data for the specified input for the specified batch size
///
Expand Down Expand Up @@ -135,6 +140,8 @@ class InferDataManagerBase : public IInferDataManager {
virtual cb::Error InitInferDataOutput(
const std::string& name, InferData& infer_data) = 0;

void AddInferDataParameters(InferData& infer_data);

#ifndef DOCTEST_CONFIG_DISABLE
public:
InferDataManagerBase() = default;
Expand Down
Loading

0 comments on commit 1c89222

Please sign in to comment.