diff --git a/src/c++/perf_analyzer/CMakeLists.txt b/src/c++/perf_analyzer/CMakeLists.txt index a0fdec47f..a6ae75cda 100644 --- a/src/c++/perf_analyzer/CMakeLists.txt +++ b/src/c++/perf_analyzer/CMakeLists.txt @@ -69,6 +69,8 @@ set( sequence_manager.cc profile_data_collector.cc profile_data_exporter.cc + periodic_concurrency_manager.cc + periodic_concurrency_worker.cc ) set( @@ -109,6 +111,8 @@ set( request_record.h profile_data_collector.h profile_data_exporter.h + periodic_concurrency_manager.h + periodic_concurrency_worker.h ) add_executable( diff --git a/src/c++/perf_analyzer/command_line_parser.h b/src/c++/perf_analyzer/command_line_parser.h index beea82020..a0706525c 100644 --- a/src/c++/perf_analyzer/command_line_parser.h +++ b/src/c++/perf_analyzer/command_line_parser.h @@ -130,7 +130,8 @@ struct PerfAnalyzerParameters { { return ( using_concurrency_range || using_old_options || - !(using_request_rate_range || using_custom_intervals)); + !(using_request_rate_range || using_custom_intervals || + is_using_periodic_concurrency_mode)); } // Sets the threshold for PA client overhead. @@ -148,6 +149,11 @@ struct PerfAnalyzerParameters { // The profile export file path. std::string profile_export_file{""}; + + bool is_using_periodic_concurrency_mode{false}; + + Range periodic_concurrency_range{1, 1, 1}; + uint64_t periodic_concurrency_request_period{10}; }; using PAParamsPtr = std::shared_ptr; diff --git a/src/c++/perf_analyzer/concurrency_manager.h b/src/c++/perf_analyzer/concurrency_manager.h index 4b3870899..15e211ca1 100644 --- a/src/c++/perf_analyzer/concurrency_manager.h +++ b/src/c++/perf_analyzer/concurrency_manager.h @@ -95,7 +95,6 @@ class ConcurrencyManager : public LoadManager { std::shared_ptr, std::shared_ptr); - private: ConcurrencyManager( const bool async, const bool streaming, const int32_t batch_size, const size_t max_threads, const size_t max_concurrency, @@ -103,6 +102,16 @@ class ConcurrencyManager : public LoadManager { const std::shared_ptr& parser, const std::shared_ptr& factory); + // The number of worker threads with non-zero concurrencies + size_t active_threads_; + + bool execute_; + + size_t max_concurrency_; + + std::vector> threads_config_; + + private: void InitManagerFinalize() override; // Pause all worker threads that are working on sequences @@ -118,14 +127,6 @@ class ConcurrencyManager : public LoadManager { // void ResumeSequenceWorkers(); - // The number of worker threads with non-zero concurrencies - size_t active_threads_; - - bool execute_; - - size_t max_concurrency_; - std::vector> threads_config_; - #ifndef DOCTEST_CONFIG_DISABLE friend TestConcurrencyManager; diff --git a/src/c++/perf_analyzer/concurrency_worker.cc b/src/c++/perf_analyzer/concurrency_worker.cc index 6fb6cee81..37a562f76 100644 --- a/src/c++/perf_analyzer/concurrency_worker.cc +++ b/src/c++/perf_analyzer/concurrency_worker.cc @@ -46,33 +46,34 @@ ConcurrencyWorker::Infer() // run inferencing until receiving exit signal to maintain server load. do { - HandleExecuteOff(); - - if (HandleNoConcurrency()) { - return; - } - - CreateContextsAsNecessary(); - - if (HandleExitConditions()) { - return; - } - - SendInferRequests(); - - if (HandleExitConditions()) { - return; - } - - WaitForResponses(); - - if (HandleExitConditions()) { - return; + if (RunInference()) { + break; } - } while (true); } +bool +ConcurrencyWorker::RunInference() +{ + HandleExecuteOff(); + if (HandleNoConcurrency()) { + return true; + } + CreateContextsAsNecessary(); + if (HandleExitConditions()) { + return true; + } + SendInferRequests(); + if (HandleExitConditions()) { + return true; + } + WaitForResponses(); + if (HandleExitConditions()) { + return true; + } + return false; +} + void ConcurrencyWorker::CreateCtxIdTracker() { diff --git a/src/c++/perf_analyzer/concurrency_worker.h b/src/c++/perf_analyzer/concurrency_worker.h index 746bef296..94cb90fbe 100644 --- a/src/c++/perf_analyzer/concurrency_worker.h +++ b/src/c++/perf_analyzer/concurrency_worker.h @@ -50,9 +50,11 @@ class NaggyMockConcurrencyWorker; class ConcurrencyWorker : public LoadWorker { public: struct ThreadConfig { - ThreadConfig(size_t thread_id) - : thread_id_(thread_id), concurrency_(0), seq_stat_index_offset_(0), - is_paused_(false) + ThreadConfig( + size_t thread_id, size_t concurrency = 0, + size_t seq_stat_index_offset = 0) + : thread_id_(thread_id), concurrency_(concurrency), + seq_stat_index_offset_(seq_stat_index_offset), is_paused_(false) { } @@ -91,7 +93,15 @@ class ConcurrencyWorker : public LoadWorker { { } - void Infer() override; + virtual void Infer() override; + + protected: + bool RunInference(); + + void CreateCtxIdTracker(); + + // Reserve vector size for contexts + void ReserveContexts(); private: const size_t max_concurrency_; @@ -101,11 +111,6 @@ class ConcurrencyWorker : public LoadWorker { std::shared_ptr thread_config_; - void CreateCtxIdTracker(); - - // Reserve vector size for contexts - void ReserveContexts(); - // Handle the case where execute_ is false void HandleExecuteOff(); diff --git a/src/c++/perf_analyzer/infer_context.cc b/src/c++/perf_analyzer/infer_context.cc index 16855ca53..100a8a1d8 100644 --- a/src/c++/perf_analyzer/infer_context.cc +++ b/src/c++/perf_analyzer/infer_context.cc @@ -258,6 +258,7 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result) return; } it->second.response_times_.push_back(std::chrono::system_clock::now()); + num_responses_++; if (is_null_response == true) { it->second.has_null_last_response_ = true; } @@ -267,6 +268,7 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result) return; } if (is_final_response) { + has_received_final_response_ = is_final_response; thread_stat_->request_records_.emplace_back( it->second.start_time_, it->second.response_times_, it->second.sequence_end_, it->second.delayed_, @@ -279,8 +281,13 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result) } } + if (worker_callback_) { + worker_callback_(id_); + } + if (is_final_response) { total_ongoing_requests_--; + num_responses_ = 0; if (async_callback_finalize_func_ != nullptr) { async_callback_finalize_func_(id_); diff --git a/src/c++/perf_analyzer/infer_context.h b/src/c++/perf_analyzer/infer_context.h index fb048546e..912cd2ca0 100644 --- a/src/c++/perf_analyzer/infer_context.h +++ b/src/c++/perf_analyzer/infer_context.h @@ -116,18 +116,28 @@ class InferContext { // object and have not returned uint GetNumOngoingRequests() { return total_ongoing_requests_; } + // Returns the number of responses for the current request + uint64_t GetNumResponsesForCurrentRequest() { return num_responses_; } + // Register a function that will get called after every async request returns void RegisterAsyncCallbackFinalize(std::function callback) { async_callback_finalize_func_ = callback; } + void RegisterWorkerCallback(std::function worker_callback) + { + worker_callback_ = worker_callback; + } + // TODO REFACTOR TMA-1043 this should be in memory class void SetNumActiveThreads(size_t num_threads) { num_active_threads_ = num_threads; } + bool HasReceivedFinalResponse() { return has_received_final_response_; } + protected: /// A helper function to issue inference request to the server. /// \param request_id The unique id to be associated with the request. @@ -191,6 +201,9 @@ class InferContext { std::reference_wrapper execute_{execute_placeholder_}; std::shared_ptr sequence_manager_{nullptr}; + uint64_t num_responses_{0}; + std::function worker_callback_{nullptr}; + bool has_received_final_response_{false}; #ifndef DOCTEST_CONFIG_DISABLE friend NaggyMockInferContext; diff --git a/src/c++/perf_analyzer/inference_profiler.h b/src/c++/perf_analyzer/inference_profiler.h index 6274e1bd8..913b23ded 100644 --- a/src/c++/perf_analyzer/inference_profiler.h +++ b/src/c++/perf_analyzer/inference_profiler.h @@ -43,6 +43,7 @@ #include "metrics_manager.h" #include "model_parser.h" #include "mpi_utils.h" +#include "periodic_concurrency_manager.h" #include "profile_data_collector.h" #include "request_rate_manager.h" @@ -306,6 +307,18 @@ class InferenceProfiler { return cb::Error::Success; } + cb::Error ProfilePeriodicConcurrencyMode() + { + auto& manager{dynamic_cast(*manager_)}; + std::vector request_records{manager.RunExperiment()}; + // FIXME - Refactor collector class to not need ID or window in the case of + // periodic concurrency mode + InferenceLoadMode id{1, 0.0}; + collector_->AddWindow(id, 0, UINT64_MAX); + collector_->AddData(id, std::move(request_records)); + return cb::Error::Success; + } + bool IncludeServerStats() { return include_server_stats_; } private: diff --git a/src/c++/perf_analyzer/perf_analyzer.cc b/src/c++/perf_analyzer/perf_analyzer.cc index 6ae375034..44ec520f2 100644 --- a/src/c++/perf_analyzer/perf_analyzer.cc +++ b/src/c++/perf_analyzer/perf_analyzer.cc @@ -27,6 +27,7 @@ #include "perf_analyzer.h" #include "perf_analyzer_exception.h" +#include "periodic_concurrency_manager.h" #include "report_writer.h" #include "request_rate_manager.h" @@ -159,6 +160,12 @@ PerfAnalyzer::CreateAnalyzerObjects() } std::unique_ptr manager; + params_->is_using_periodic_concurrency_mode = true; + params_->periodic_concurrency_range = { + std::stoi(std::getenv("MY_START")), std::stoi(std::getenv("MY_END")), + std::stoi(std::getenv("MY_STEP"))}; + params_->periodic_concurrency_request_period = + std::stoi(std::getenv("MY_REQUEST_PERIOD")); if (params_->targeting_concurrency()) { if ((parser_->SchedulerType() == pa::ModelParser::SEQUENCE) || @@ -209,6 +216,13 @@ PerfAnalyzer::CreateAnalyzerObjects() factory, &manager), "failed to create concurrency manager"); + } else if (params_->is_using_periodic_concurrency_mode) { + manager = std::make_unique( + params_->async, params_->streaming, params_->batch_size, + params_->max_threads, params_->max_concurrency, + params_->shared_memory_type, params_->output_shm_size, parser_, factory, + params_->periodic_concurrency_range, + params_->periodic_concurrency_request_period); } else if (params_->using_request_rate_range) { if ((params_->sequence_id_range != 0) && (params_->sequence_id_range < params_->num_of_sequences)) { @@ -370,6 +384,8 @@ PerfAnalyzer::Profile() err = profiler_->Profile( params_->concurrency_range.start, params_->concurrency_range.end, params_->concurrency_range.step, params_->search_mode, perf_statuses_); + } else if (params_->is_using_periodic_concurrency_mode) { + err = profiler_->ProfilePeriodicConcurrencyMode(); } else { err = profiler_->Profile( params_->request_rate_range[pa::SEARCH_RANGE::kSTART], @@ -393,7 +409,7 @@ PerfAnalyzer::Profile() void PerfAnalyzer::WriteReport() { - if (!perf_statuses_.size()) { + if (!perf_statuses_.size() || params_->is_using_periodic_concurrency_mode) { return; } diff --git a/src/c++/perf_analyzer/periodic_concurrency_manager.cc b/src/c++/perf_analyzer/periodic_concurrency_manager.cc new file mode 100644 index 000000000..1a5527b7b --- /dev/null +++ b/src/c++/perf_analyzer/periodic_concurrency_manager.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "periodic_concurrency_manager.h" + +namespace triton { namespace perfanalyzer { + +std::vector +PeriodicConcurrencyManager::RunExperiment() +{ + AddConcurrentRequests(concurrency_range_.start); + WaitForRequestsToFinish(); + return GetRequestRecords(); +} + +std::shared_ptr +PeriodicConcurrencyManager::MakeWorker( + std::shared_ptr thread_stat, + std::shared_ptr thread_config) +{ + uint32_t id = workers_.size(); + auto worker = std::make_shared( + id, thread_stat, thread_config, parser_, data_loader_, factory_, + on_sequence_model_, async_, max_concurrency_, using_json_data_, + streaming_, batch_size_, wake_signal_, wake_mutex_, active_threads_, + execute_, infer_data_manager_, sequence_manager_, request_period_, + period_completed_callback_, request_completed_callback_); + return worker; +}; + +void +PeriodicConcurrencyManager::AddConcurrentRequests( + uint64_t num_concurrent_requests) +{ + for (size_t i = 0; i < num_concurrent_requests; i++) { + AddConcurrentRequest(i); + } + num_incomplete_periods_ = num_concurrent_requests; +} + +void +PeriodicConcurrencyManager::AddConcurrentRequest(size_t seq_stat_index_offset) +{ + threads_stat_.emplace_back(std::make_shared()); + threads_config_.emplace_back( + std::make_shared( + threads_config_.size(), 1, seq_stat_index_offset)); + workers_.emplace_back( + MakeWorker(threads_stat_.back(), threads_config_.back())); + threads_.emplace_back(&IWorker::Infer, workers_.back()); + active_threads_++; +} + +void +PeriodicConcurrencyManager::PeriodCompletedCallback() +{ + std::lock_guard lock(period_completed_callback_mutex_); + num_incomplete_periods_--; + if (num_incomplete_periods_ == 0) { + steps_completed_++; + uint64_t num_requests_sent{steps_completed_ * concurrency_range_.step}; + if (num_requests_sent < concurrency_range_.end) { + AddConcurrentRequests(concurrency_range_.step); + } + } +} + +void +PeriodicConcurrencyManager::RequestCompletedCallback() +{ + std::lock_guard lock(request_completed_callback_mutex_); + num_completed_requests_++; + if (num_completed_requests_ == concurrency_range_.end) { + all_requests_completed_promise_.set_value(true); + } +} + +void +PeriodicConcurrencyManager::WaitForRequestsToFinish() +{ + std::future all_requests_completed_future{ + all_requests_completed_promise_.get_future()}; + all_requests_completed_future.get(); +} + +std::vector +PeriodicConcurrencyManager::GetRequestRecords() +{ + std::vector request_records{}; + for (const auto& thread_stat : threads_stat_) { + request_records.insert( + request_records.end(), thread_stat->request_records_.cbegin(), + thread_stat->request_records_.cend()); + } + return request_records; +} + +}} // namespace triton::perfanalyzer diff --git a/src/c++/perf_analyzer/periodic_concurrency_manager.h b/src/c++/perf_analyzer/periodic_concurrency_manager.h new file mode 100644 index 000000000..dca2797b7 --- /dev/null +++ b/src/c++/perf_analyzer/periodic_concurrency_manager.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include + +#include "concurrency_manager.h" +#include "periodic_concurrency_worker.h" + +namespace triton { namespace perfanalyzer { + +/// @brief Concurrency manager for periodically increasing concurrency by a step +/// amount based on the number of responses received (request period) by the +/// latest N (step or start concurrency for first-issued concurrent requests) +/// concurrent requests/workers. +class PeriodicConcurrencyManager : public ConcurrencyManager { + public: + PeriodicConcurrencyManager( + const bool async, const bool streaming, const int32_t batch_size, + const size_t max_threads, const size_t max_concurrency, + const SharedMemoryType shared_memory_type, const size_t output_shm_size, + const std::shared_ptr& parser, + const std::shared_ptr& factory, + const Range concurrency_range, const uint64_t request_period) + : ConcurrencyManager( + async, streaming, batch_size, max_threads, max_concurrency, + shared_memory_type, output_shm_size, parser, factory), + concurrency_range_(concurrency_range), request_period_(request_period) + { + } + + std::vector RunExperiment(); + + private: + std::shared_ptr MakeWorker( + std::shared_ptr thread_stat, + std::shared_ptr thread_config) + override; + + void AddConcurrentRequests(uint64_t num_concurrent_requests); + + void AddConcurrentRequest(size_t seq_stat_index_offset); + + void PeriodCompletedCallback(); + + void RequestCompletedCallback(); + + void WaitForRequestsToFinish(); + + std::vector GetRequestRecords(); + + Range concurrency_range_{1, 1, 1}; + uint64_t request_period_{0}; + uint64_t steps_completed_{0}; + uint64_t num_incomplete_periods_{0}; + uint64_t num_completed_requests_{0}; + std::mutex period_completed_callback_mutex_{}; + std::mutex request_completed_callback_mutex_{}; + std::promise all_requests_completed_promise_{}; + std::function period_completed_callback_{ + std::bind(&PeriodicConcurrencyManager::PeriodCompletedCallback, this)}; + std::function request_completed_callback_{ + std::bind(&PeriodicConcurrencyManager::RequestCompletedCallback, this)}; +}; + +}} // namespace triton::perfanalyzer diff --git a/src/c++/perf_analyzer/periodic_concurrency_worker.cc b/src/c++/perf_analyzer/periodic_concurrency_worker.cc new file mode 100644 index 000000000..9fbaee3cc --- /dev/null +++ b/src/c++/perf_analyzer/periodic_concurrency_worker.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "periodic_concurrency_worker.h" + +namespace triton { namespace perfanalyzer { + +void +PeriodicConcurrencyWorker::Infer() +{ + CreateCtxIdTracker(); + ReserveContexts(); + RunInference(); +} + +std::shared_ptr +PeriodicConcurrencyWorker::CreateInferContext() +{ + std::shared_ptr infer_context{std::make_shared( + id_, ctxs_.size(), async_, streaming_, on_sequence_model_, + using_json_data_, batch_size_, thread_stat_, data_loader_, parser_, + factory_, execute_, infer_data_manager_, sequence_manager_)}; + infer_context->RegisterWorkerCallback(worker_callback_); + return infer_context; +} + +void +PeriodicConcurrencyWorker::WorkerCallback(uint32_t infer_context_id) +{ + if (ctxs_.at(infer_context_id)->GetNumResponsesForCurrentRequest() == + request_period_) { + period_completed_callback_(); + } + if (ctxs_.at(infer_context_id)->HasReceivedFinalResponse()) { + request_completed_callback_(); + } +} + +}} // namespace triton::perfanalyzer diff --git a/src/c++/perf_analyzer/periodic_concurrency_worker.h b/src/c++/perf_analyzer/periodic_concurrency_worker.h new file mode 100644 index 000000000..7242219b9 --- /dev/null +++ b/src/c++/perf_analyzer/periodic_concurrency_worker.h @@ -0,0 +1,80 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include + +#include "concurrency_worker.h" + +namespace triton { namespace perfanalyzer { + +/// @brief Worker class for periodic concurrency mode. Issues one request only +/// and waits for all responses to come in. Notifies manager when N responses +/// (request period) have been received. Notifies manager when final response +/// has been received. +class PeriodicConcurrencyWorker : public ConcurrencyWorker { + public: + PeriodicConcurrencyWorker( + uint32_t id, std::shared_ptr thread_stat, + std::shared_ptr thread_config, + const std::shared_ptr parser, + std::shared_ptr data_loader, + const std::shared_ptr factory, + const bool on_sequence_model, const bool async, + const size_t max_concurrency, const bool using_json_data, + const bool streaming, const int32_t batch_size, + std::condition_variable& wake_signal, std::mutex& wake_mutex, + size_t& active_threads, bool& execute, + const std::shared_ptr& infer_data_manager, + std::shared_ptr sequence_manager, + uint64_t request_period, std::function period_completed_callback, + std::function request_completed_callback) + : ConcurrencyWorker( + id, thread_stat, thread_config, parser, data_loader, factory, + on_sequence_model, async, max_concurrency, using_json_data, + streaming, batch_size, wake_signal, wake_mutex, active_threads, + execute, infer_data_manager, sequence_manager), + request_period_(request_period), + period_completed_callback_(period_completed_callback), + request_completed_callback_(request_completed_callback) + { + } + + void Infer() override; + + std::shared_ptr CreateInferContext() override; + + void WorkerCallback(uint32_t infer_context_id); + + private: + uint64_t request_period_{0}; + std::function period_completed_callback_{nullptr}; + std::function request_completed_callback_{nullptr}; + std::function worker_callback_{std::bind( + &PeriodicConcurrencyWorker::WorkerCallback, this, std::placeholders::_1)}; +}; + +}} // namespace triton::perfanalyzer