Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add request parameters #405

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/c++/library/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#ifdef TRITON_INFERENCE_SERVER_CLIENT_CLASS
Expand Down Expand Up @@ -153,6 +154,12 @@ class InferenceServerClient {
InferStat infer_stat_;
};

/// A user-supplied parameter that is forwarded to the server alongside an
/// inference request.
struct RequestParameter {
  std::string name;   // Parameter name; also used as the map key in
                      // InferOptions::request_parameters.
  std::string value;  // Parameter value, carried as text and converted by the
                      // client according to 'type'.
  std::string type;   // Type tag; the gRPC client recognizes "string", "int",
                      // and "bool" — other values are currently ignored.
};

//==============================================================================
/// Structure to hold options for Inference Request.
///
Expand Down Expand Up @@ -221,6 +228,8 @@ struct InferOptions {
uint64_t client_timeout_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;
/// Additional parameters to pass to the model.
std::unordered_map<std::string, RequestParameter> request_parameters;
};

//==============================================================================
Expand Down
18 changes: 18 additions & 0 deletions src/c++/library/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>

#include "common.h"

Expand Down Expand Up @@ -1408,6 +1409,23 @@ InferenceServerGrpcClient::PreRunProcessing(
options.server_timeout_);
}


// Attach user-supplied request parameters to the outgoing request protobuf.
for (const auto& param : options.request_parameters) {
  const auto& rp = param.second;
  if (rp.type == "string") {
    (*infer_request_.mutable_parameters())[param.first].set_string_param(
        rp.value);
  } else if (rp.type == "int") {
    // The proto field is a 64-bit integer; std::stoll parses the full int64
    // range, whereas std::stoi would throw std::out_of_range (or truncate)
    // for values beyond the 32-bit int range.
    (*infer_request_.mutable_parameters())[param.first].set_int64_param(
        std::stoll(rp.value));
  } else if (rp.type == "bool") {
    (*infer_request_.mutable_parameters())[param.first].set_bool_param(
        rp.value == "true");
  }
  // NOTE(review): parameters with any other 'type' are silently dropped;
  // consider surfacing an error so a misconfigured type is visible.
}

int index = 0;
infer_request_.mutable_raw_input_contents()->Clear();
for (const auto input : inputs) {
Expand Down
12 changes: 12 additions & 0 deletions src/c++/perf_analyzer/client_backend/client_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ struct ModelStatistics {
uint64_t cache_miss_time_ns_;
};

///
/// A user-supplied parameter attached to an inference request; the perf
/// analyzer backend converts these into the client library's parameter
/// representation before sending.
///
struct RequestParameter {
  std::string name;   // Parameter name; also used as the key of
                      // InferOptions::request_parameters_.
  std::string value;  // Parameter value, carried as text.
  std::string type;   // Type tag (e.g. "string", "int", "bool") — interpreted
                      // by the underlying client when building the request.
};

//==============================================================================
/// Structure to hold options for Inference Request.
///
Expand Down Expand Up @@ -230,6 +239,9 @@ struct InferOptions {
bool sequence_end_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;

/// Additional parameters to pass to the model.
std::unordered_map<std::string, RequestParameter> request_parameters_;
};

struct SslOptionsBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,14 @@ TritonClientBackend::ParseInferOptionsToTriton(
}
triton_options->triton_enable_empty_final_response_ =
options.triton_enable_empty_final_response_;

// Convert each backend-level request parameter into the Triton client
// library's representation, preserving the original map key.
for (const auto& entry : options.request_parameters_) {
  const auto& src = entry.second;
  tc::RequestParameter converted;
  converted.name = src.name;
  converted.value = src.value;
  converted.type = src.type;
  triton_options->request_parameters[entry.first] = converted;
}
}


Expand Down
19 changes: 4 additions & 15 deletions src/c++/perf_analyzer/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1601,21 +1601,10 @@ CLParser::ParseCommandLine(int argc, char** argv)
std::string value{values[1]};
std::string type{values[2]};

RequestParameter param;
if (type == "bool") {
param.type = RequestParameterType::BOOL;
param.bool_value = value == "true" ? true : false;
} else if (type == "int") {
param.type = RequestParameterType::INT;
param.int_value = std::stoll(value);
} else if (type == "string") {
param.type = RequestParameterType::STRING;
param.str_value = value;
} else {
Usage(
"Failed to parse --request-parameter. Unsupported type: '" +
type + "'.");
}
cb::RequestParameter param;
param.name = name;
param.value = value;
param.type = type;
params_->request_parameters[name] = param;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion src/c++/perf_analyzer/command_line_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct PerfAnalyzerParameters {
uint64_t measurement_window_ms = 5000;
bool using_concurrency_range = false;
Range<uint64_t> concurrency_range{1, 1, 1};
std::unordered_map<std::string, RequestParameter> request_parameters;
std::unordered_map<std::string, cb::RequestParameter> request_parameters;
uint64_t latency_threshold_ms = NO_LIMIT;
double stability_threshold = 0.1;
size_t max_trials = 10;
Expand Down
13 changes: 9 additions & 4 deletions src/c++/perf_analyzer/concurrency_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ ConcurrencyManager::Create(
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager)
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
{
std::unique_ptr<ConcurrencyManager> local_manager(new ConcurrencyManager(
async, streaming, batch_size, max_threads, max_concurrency,
shared_memory_type, output_shm_size, parser, factory));
shared_memory_type, output_shm_size, parser, factory,
request_parameters));

*manager = std::move(local_manager);

Expand All @@ -60,10 +63,12 @@ ConcurrencyManager::ConcurrencyManager(
const size_t max_threads, const size_t max_concurrency,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory)
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
: LoadManager(
async, streaming, batch_size, max_threads, shared_memory_type,
output_shm_size, parser, factory),
output_shm_size, parser, factory, request_parameters),
execute_(true), max_concurrency_(max_concurrency)
{
threads_config_.reserve(max_threads);
Expand Down
9 changes: 7 additions & 2 deletions src/c++/perf_analyzer/concurrency_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ class ConcurrencyManager : public LoadManager {
/// \param factory The ClientBackendFactory object used to create
/// client to the server.
/// \param manager Returns a new ConcurrencyManager object.
/// \param request_parameters Custom request parameters to send to the server
/// \return cb::Error object indicating success or failure.
static cb::Error Create(
const bool async, const bool streaming, const int32_t batch_size,
const size_t max_threads, const size_t max_concurrency,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager);
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters);

/// Adjusts the number of concurrent requests to be the same as
/// 'concurrent_request_count' (by creating or pausing threads)
Expand All @@ -100,7 +103,9 @@ class ConcurrencyManager : public LoadManager {
const size_t max_threads, const size_t max_concurrency,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory);
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters);

// The number of worker threads with non-zero concurrencies
size_t active_threads_;
Expand Down
13 changes: 9 additions & 4 deletions src/c++/perf_analyzer/custom_load_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ CustomLoadManager::Create(
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager)
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
{
std::unique_ptr<CustomLoadManager> local_manager(new CustomLoadManager(
async, streaming, request_intervals_file, batch_size,
measurement_window_ms, max_trials, max_threads, num_of_sequences,
shared_memory_type, output_shm_size, serial_sequences, parser, factory));
shared_memory_type, output_shm_size, serial_sequences, parser, factory,
request_parameters));

*manager = std::move(local_manager);

Expand All @@ -60,12 +63,14 @@ CustomLoadManager::CustomLoadManager(
const size_t max_threads, const uint32_t num_of_sequences,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory)
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
: RequestRateManager(
async, streaming, Distribution::CUSTOM, batch_size,
measurement_window_ms, max_trials, max_threads, num_of_sequences,
shared_memory_type, output_shm_size, serial_sequences, parser,
factory),
factory, request_parameters),
request_intervals_file_(request_intervals_file)
{
}
Expand Down
9 changes: 7 additions & 2 deletions src/c++/perf_analyzer/custom_load_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class CustomLoadManager : public RequestRateManager {
/// \param factory The ClientBackendFactory object used to create
/// client to the server.
/// \param manager Returns a new ConcurrencyManager object.
/// \param request_parameters Custom request parameters to send to the server
/// \return cb::Error object indicating success or failure.
static cb::Error Create(
const bool async, const bool streaming,
Expand All @@ -81,7 +82,9 @@ class CustomLoadManager : public RequestRateManager {
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
std::unique_ptr<LoadManager>* manager);
std::unique_ptr<LoadManager>* manager,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameter);

/// Initializes the load manager with the provided file containing request
/// intervals
Expand All @@ -103,7 +106,9 @@ class CustomLoadManager : public RequestRateManager {
const size_t max_threads, const uint32_t num_of_sequences,
const SharedMemoryType shared_memory_type, const size_t output_shm_size,
const bool serial_sequences, const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory);
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters);

cb::Error GenerateSchedule();

Expand Down
3 changes: 3 additions & 0 deletions src/c++/perf_analyzer/infer_data_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ InferDataManager::InitInferDataInput(
infer_input->AppendRaw(input_data.data_ptr, input_data.batch1_size));
}
}

AddInferDataParameters(infer_data);

return cb::Error::Success;
}

Expand Down
5 changes: 4 additions & 1 deletion src/c++/perf_analyzer/infer_data_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ class InferDataManager : public InferDataManagerBase {
public:
InferDataManager(
const size_t max_threads, const int32_t batch_size,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::shared_ptr<DataLoader>& data_loader)
: max_threads_(max_threads),
InferDataManagerBase(batch_size, parser, factory, data_loader)
InferDataManagerBase(
batch_size, request_parameters, parser, factory, data_loader)
{
}

Expand Down
5 changes: 5 additions & 0 deletions src/c++/perf_analyzer/infer_data_manager_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,10 @@ InferDataManagerBase::CreateInferInput(
return cb::InferInput::Create(infer_input, kind, name, dims, datatype);
}

/// Copies this manager's configured custom request parameters into the
/// inference options of 'infer_data', so they are sent with every request
/// built from that data.
void
InferDataManagerBase::AddInferDataParameters(InferData& infer_data)
{
  infer_data.options_->request_parameters_ = request_parameters_;
}

}} // namespace triton::perfanalyzer
13 changes: 10 additions & 3 deletions src/c++/perf_analyzer/infer_data_manager_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ namespace triton { namespace perfanalyzer {
class InferDataManagerBase : public IInferDataManager {
public:
InferDataManagerBase(
const int32_t batch_size, const std::shared_ptr<ModelParser>& parser,
const int32_t batch_size,
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters,
const std::shared_ptr<ModelParser>& parser,
const std::shared_ptr<cb::ClientBackendFactory>& factory,
const std::shared_ptr<DataLoader>& data_loader)
: batch_size_(batch_size), parser_(parser), factory_(factory),
data_loader_(data_loader), backend_kind_(factory->Kind())
: batch_size_(batch_size), request_parameters_(request_parameters),
parser_(parser), factory_(factory), data_loader_(data_loader),
backend_kind_(factory->Kind())
{
}

Expand All @@ -72,6 +76,7 @@ class InferDataManagerBase : public IInferDataManager {
std::shared_ptr<DataLoader> data_loader_;
std::unique_ptr<cb::ClientBackend> backend_;
cb::BackendKind backend_kind_;
std::unordered_map<std::string, cb::RequestParameter> request_parameters_;

/// Gets the input data for the specified input for the specified batch size
///
Expand Down Expand Up @@ -135,6 +140,8 @@ class InferDataManagerBase : public IInferDataManager {
virtual cb::Error InitInferDataOutput(
const std::string& name, InferData& infer_data) = 0;

void AddInferDataParameters(InferData& infer_data);

#ifndef DOCTEST_CONFIG_DISABLE
public:
InferDataManagerBase() = default;
Expand Down
Loading
Loading