diff --git a/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h b/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h
index cef83bc8c..8b9e910f5 100644
--- a/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h
+++ b/torchrec/inference/include/torchrec/inference/SingleGPUExecutor.h
@@ -33,7 +33,8 @@ class SingleGPUExecutor {
       size_t numGpu,
       std::shared_ptr observer =
           std::make_shared(),
-      c10::Device resultDevice = c10::kCPU);
+      c10::Device resultDevice = c10::kCPU,
+      size_t numProcessThreads = 1u);
 
   // Moveable only
   SingleGPUExecutor(SingleGPUExecutor&& executor) noexcept = default;
@@ -48,12 +49,13 @@ class SingleGPUExecutor {
   std::shared_ptr manager_;
   const ExecInfos execInfos_;
   const size_t numGpu_;
+  const size_t numProcessThreads_;
   const c10::Device resultDevice_;
   std::shared_ptr observer_;
   folly::MPMCQueue> requests_;
+  std::unique_ptr processExecutor_;
   std::unique_ptr completionExecutor_;
   std::atomic roundRobinExecInfoNextIdx_;
-  std::thread processThread_;
 };
 
 } // namespace torchrec
diff --git a/torchrec/inference/src/SingleGPUExecutor.cpp b/torchrec/inference/src/SingleGPUExecutor.cpp
index bde1ee92b..83b502741 100644
--- a/torchrec/inference/src/SingleGPUExecutor.cpp
+++ b/torchrec/inference/src/SingleGPUExecutor.cpp
@@ -19,17 +19,23 @@ SingleGPUExecutor::SingleGPUExecutor(
     ExecInfos execInfos,
     size_t numGpu,
     std::shared_ptr observer,
-    c10::Device resultDevice)
+    c10::Device resultDevice,
+    size_t numProcessThreads)
     : manager_(manager),
       execInfos_(std::move(execInfos)),
       numGpu_(numGpu),
+      numProcessThreads_(numProcessThreads),
       resultDevice_(resultDevice),
       observer_(observer),
       requests_(kQUEUE_CAPACITY),
+      processExecutor_(
+          std::make_unique(numProcessThreads)),
       completionExecutor_(
           std::make_unique(execInfos_.size())),
-      roundRobinExecInfoNextIdx_(0u),
-      processThread_([&]() { process(); }) {
+      roundRobinExecInfoNextIdx_(0u) {
+  for (size_t i = 0; i < numProcessThreads_; ++i) {
+    processExecutor_->add([&]() { process(); });
+  }
   for (const auto& exec_info : execInfos_) {
     TORCHREC_CHECK(exec_info.interpIdx < manager_->allInstances().size());
   }
@@ -37,8 +43,10 @@ SingleGPUExecutor::SingleGPUExecutor(
 }
 
 SingleGPUExecutor::~SingleGPUExecutor() {
-  requests_.blockingWrite(nullptr);
-  processThread_.join();
+  for (size_t i = 0; i < numProcessThreads_; ++i) {
+    requests_.blockingWrite(nullptr);
+  }
+  processExecutor_->join();
   completionExecutor_->join();
 }
 