From cd0587b29ca629f9d3fa20faf677421339ce1b8a Mon Sep 17 00:00:00 2001 From: asaigal Date: Tue, 14 May 2024 17:44:09 +0000 Subject: [PATCH] #8264: Async Engine Optimizations - copy_borrowed_tensor_in_async_mode does not stall for device tensors anymore - Typechecking moved to compile time - work_executor optimizations: Pass shared ptrs down to workers, instead of lambda objects. Lock Free Queue is now statically initialized - launch_op optimization: lambda initialized outside multi-device for loop - Tensor deallocate optimization: Pass attribute ptr to lambda instead of passing entire tensor object - System Level Optimizations: Set process priority to 0. Bind CQ reader to core and use CV to toggle its state instead of calling sleep --- tt_eager/tensor/tensor.cpp | 52 +++++-------- tt_eager/tensor/tensor_utils.cpp | 72 ++++++++++------- tt_eager/tensor/types.hpp | 20 +++-- tt_eager/tt_dnn/op_library/run_operation.cpp | 72 ++++++++++------- tt_metal/impl/device/device.cpp | 4 + tt_metal/impl/device/device.hpp | 1 + tt_metal/impl/dispatch/command_queue.cpp | 42 +++++++--- tt_metal/impl/dispatch/command_queue.hpp | 6 ++ tt_metal/impl/dispatch/lock_free_queue.hpp | 60 +++++++++++---- tt_metal/impl/dispatch/work_executor.hpp | 81 ++++++++++++++------ 10 files changed, 271 insertions(+), 139 deletions(-) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 89fb187f5b6..2515e21d509 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -114,18 +114,16 @@ void Tensor::deallocate(bool force) { uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or not this->tensor_attributes->main_thread_tensor) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count; if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { this->tensor_attributes->deallocated = true; - // Record ref count before sending to worker - uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(); - this->workers.at(0)->push_work([force, *this] () mutable { + this->workers.at(0)->push_work(std::make_shared>([force, attr = this->tensor_attributes] () mutable { // Cross worker synchronization: If the tensor being deallocated is shared across workers (ex: all_gather op), // wait until all workers are done with this tensor before deallocating. - bool num_threads_sharing_tensor = this->tensor_attributes->num_sibling_workers_sharing_tensor; + bool num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor; if (num_threads_sharing_tensor) { while (num_threads_sharing_tensor) { - num_threads_sharing_tensor = this->tensor_attributes->num_sibling_workers_sharing_tensor;; + num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;; } } - std::visit([force, this] (auto&& s) { + std::visit([force, attr] (auto&& s) { using type = std::decay_t; if constexpr (std::is_same_v) { if (force or s.buffer.use_count() == 1) { @@ -138,13 +136,11 @@ void Tensor::deallocate(bool force) { } else if constexpr(std::is_same_v) { // Manage Dynamic Storage (due to autoformat in async mode): Main thread sees this tensor as a device tensor, since worker has not updated // storage time. When the worker executes the dealloc request, the storage type has been appropriately updated to Owned. 
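The deallocate change above captures the tensor_attributes shared_ptr in the queued lambda instead of copying the whole Tensor, so the main-thread ref-count bookkeeping around push_work is no longer needed. A minimal sketch of that capture pattern, using hypothetical Attributes and Worker types in place of tt-metal's tensor attributes and device work queue:

#include <functional>
#include <iostream>
#include <memory>
#include <queue>

// Hypothetical stand-ins for tensor_attributes and a device worker queue.
struct Attributes {
    bool deallocated = false;
    ~Attributes() { std::cout << "attributes released\n"; }
};

struct Worker {
    std::queue<std::shared_ptr<std::function<void()>>> work;
    void push_work(std::shared_ptr<std::function<void()>> fn) { work.push(std::move(fn)); }
    void run_all() {
        while (!work.empty()) { (*work.front())(); work.pop(); }
    }
};

int main() {
    Worker worker;
    auto attr = std::make_shared<Attributes>();
    // Capture only the attribute pointer, not the owning object:
    // the lambda shares ownership of the attributes until it runs.
    worker.push_work(std::make_shared<std::function<void()>>(
        [attr]() mutable {
            attr->deallocated = true;  // the real code frees device buffers here
            attr.reset();              // drop the worker's reference
        }));
    attr.reset();      // main thread lets go; attributes stay alive inside the queue
    worker.run_all();  // worker runs the deallocation and releases the last reference
}

Because the lambda shares ownership of the attributes, they stay alive until the worker has executed the deallocation, regardless of when the main thread drops its handle.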
- TT_ASSERT(this->tensor_attributes->dynamic_storage, "Tensor storage type changed during runtime (device -> host), but dynamic storage was not marked."); + TT_ASSERT(attr->dynamic_storage, "Tensor storage type changed during runtime (device -> host), but dynamic storage was not marked."); std::visit([] (auto&& buffer) { buffer.reset(); }, s.buffer); } - }, this->tensor_attributes->storage); - }); - // Update ref count after sending to worker - this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count); + }, attr->storage); + })); } } else { TT_FATAL(this->deallocate_through_destructor, "Device tensors created in the main thread cannot be explictly deallocated in worker threads."); @@ -155,32 +151,26 @@ void Tensor::deallocate(bool force) { } } else if constexpr (std::is_same_v) { if (this->workers.at(0)->in_main_thread() or not this->tensor_attributes->main_thread_tensor) { - if (not this->tensor_attributes->main_thread_tensor) { - TT_ASSERT(not this->tensor_attributes->main_thread_ref_count, "main_thread_ref_count for tensors created inside a worker thread must be 0"); - } // If owned by the main thread, deallocate this tensor only from the main thread. If owned by worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a thread_local tensor uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or not this->tensor_attributes->main_thread_tensor) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count; if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { this->tensor_attributes->deallocated = true; - // Record ref count before sending to workers - uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(); + auto dealloc_lambda = std::make_shared>([force, attr = this->tensor_attributes] (Device* worker) mutable { + ZoneScopedN("ShardDeallocate"); + auto& s = std::get(attr->storage); + if (s.buffers.find(worker->id()) != s.buffers.end()) { + if ((force or s.buffers.at(worker->id()).use_count() == 1)) { + DeallocateBuffer(*(s.buffers.at(worker->id()))); + } + s.buffers.at(worker->id()).reset(); + } + }); + for (auto worker : this->workers) { - worker->push_work([force, *this, worker] () mutable { - std::visit([force, worker] (auto&& s) { - using type = std::decay_t; - if constexpr (std::is_same_v) { - if (s.buffers.find(worker->id()) != s.buffers.end()) { - if (force or s.buffers.at(worker->id()).use_count() == 1) { - DeallocateBuffer(*(s.buffers.at(worker->id()))); - } - s.buffers.at(worker->id()).reset(); - } - } - }, this->tensor_attributes->storage); - }); + worker->push_work(std::make_shared>([worker, dealloc_lambda] () mutable { + (*dealloc_lambda)(worker); + })); } - // Update ref count after sending to workers - this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count); } } else { TT_FATAL(this->deallocate_through_destructor, "Device tensors created in the main thread cannot be explictly deallocated in worker threads."); diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index a275169749d..cd913e0675c 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -346,47 +346,63 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) { } else if (std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage()) || 
std::holds_alternative(tensor.get_storage())) { return 1; } else { - TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); + TT_FATAL(false, "num_buffers_in_tensor only supports multi-device or device tensors"); } } Tensor get_shard_for_device(const Tensor& tensor, Device* target_device, std::optional buffer_index) { - if (std::holds_alternative(tensor.get_storage())) { - auto device_storage = std::get(tensor.get_storage()); - auto shard_shape = device_storage.get_tensor_shape_for_device(target_device); - auto shard_buffer = device_storage.get_buffer_for_device(target_device); - return Tensor{DeviceStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; - } else if (std::holds_alternative(tensor.get_storage())) { - auto host_storage = std::get(tensor.get_storage()); - auto shard_shape = host_storage.get_tensor_shape(buffer_index.value()); - auto shard_buffer = host_storage.get_buffer(buffer_index.value()); - return Tensor{OwnedStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; - } else if (std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage())) { - return tensor; - } else { - TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); - } + ZoneScopedN("GetShardForDevice"); + Tensor shard = Tensor(); + auto& storage = tensor.get_storage(); + std::visit([target_device, buffer_index, &tensor, &shard] (auto&& s) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto shard_shape = s.get_tensor_shape_for_device(target_device); + auto shard_buffer = s.get_buffer_for_device(target_device); + shard = Tensor{DeviceStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; + } else if constexpr (std::is_same_v) { + auto shard_shape = s.get_tensor_shape(buffer_index.value()); + auto shard_buffer = s.get_buffer(buffer_index.value()); + shard = Tensor{OwnedStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; + } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + shard = tensor; + } else { + TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); + } + }, storage); + return shard; } void insert_buffer_and_shape_for_device(Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional buffer_index) { - if (std::holds_alternative(tensor_to_modify.tensor_attributes->storage)) { - std::get(tensor_to_modify.tensor_attributes->storage).insert_buffer_and_shape_for_device(buffer_index.value(), std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); - } else if (std::holds_alternative(tensor_to_modify.tensor_attributes->storage)) { - std::get(tensor_to_modify.tensor_attributes->storage).insert_buffer_and_shape_for_device(target_device, std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); - } else if (std::holds_alternative(tensor_to_modify.tensor_attributes->storage)) { - std::get(tensor_to_modify.tensor_attributes->storage).insert_buffer(std::get(shard.get_storage()).get_buffer()); - } else if (std::holds_alternative(tensor_to_modify.tensor_attributes->storage)) { - std::get(tensor_to_modify.tensor_attributes->storage).insert_buffer(std::get(shard.get_storage()).get_buffer()); - } else { - TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device"); - } + ZoneScopedN("InsertBufferAndShapeForDevice"); + std::visit([target_device, &shard, 
&tensor_to_modify, buffer_index] (auto&& s) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + s.insert_buffer_and_shape_for_device(buffer_index.value(), std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); + } else if constexpr (std::is_same_v) { + s.insert_buffer_and_shape_for_device(target_device, std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); + } else if constexpr (std::is_same_v) { + s.insert_buffer(std::get(shard.get_storage()).get_buffer()); + } else if constexpr (std::is_same_v) { + s.insert_buffer(std::get(shard.get_storage()).get_buffer()); + } else { + TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device"); + } + }, tensor_to_modify.tensor_attributes->storage); } + Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) { // When using async mode, tensors with borrowed storage cannot be passed to workers. // They need to be copied to owned storage before being passed to the worker. ZoneScopedN("ConvertBorrowedToOwned"); - if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and tensor.storage_type() == StorageType::BORROWED) { + // Tensor has workers (on device) or runtime mode is synchronous or tensor has multiple buffers. + // No need to check for borrowed storage. + if (worker->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or + tensor.get_workers().size() or + tensor.tensor_attributes->tensor_populated.size() > 1) return tensor; + + if (tensor.storage_type() == StorageType::BORROWED) { ZoneScopedN("CopyBorrowedStorage"); auto borrowed_buffer = std::get(tensor.get_storage()).buffer; Tensor owned_tensor; diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index c2b9bc29d40..0332081201a 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -366,22 +366,25 @@ struct MultiDeviceHostStorage { DistributedTensorConfig strategy; std::vector buffers; std::vector shapes; - std::mutex mtx; + mutable std::mutex mtx; MultiDeviceHostStorage() = default; MultiDeviceHostStorage(DistributedTensorConfig strategy_, std::vector buffers_, std::vector shapes_) : strategy(strategy_), buffers(buffers_), shapes(shapes_) {} MultiDeviceHostStorage(MultiDeviceHostStorage &&other) { + std::lock_guard lock(mtx); strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; } MultiDeviceHostStorage(const MultiDeviceHostStorage &other) { + std::lock_guard lock(mtx); strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; } MultiDeviceHostStorage &operator=(const MultiDeviceHostStorage &other) { + std::lock_guard lock(mtx); strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; @@ -389,6 +392,7 @@ struct MultiDeviceHostStorage { } MultiDeviceHostStorage &operator=( MultiDeviceHostStorage &&other) { + std::lock_guard lock(mtx); strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; @@ -410,13 +414,13 @@ struct MultiDeviceHostStorage { shapes[buffer_index] = shape; } - OwnedBuffer get_buffer(int buffer_index) { + OwnedBuffer get_buffer(int buffer_index) const { std::lock_guard lock(mtx); TT_FATAL(buffer_index < buffers.size(), "Buffer not found for buffer_index " + std::to_string(buffer_index)); return buffers[buffer_index];; } - Shape get_tensor_shape(int shape_index) { + Shape get_tensor_shape(int shape_index) const { std::lock_guard lock(mtx); TT_FATAL(shape_index < shapes.size(), "Buffer not found for device " + std::to_string(shape_index)); return shapes[shape_index]; @@ -443,12 +447,14 
@@ struct MultiDeviceHostStorage { std::unordered_map shapes_) : strategy(strategy_), ordered_device_ids(ordered_device_ids_), buffers(buffers_), shapes(shapes_) {} MultiDeviceStorage(MultiDeviceStorage &&other) { + std::lock_guard lock(mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; } MultiDeviceStorage(const MultiDeviceStorage &other) { + std::lock_guard lock(other.mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -456,6 +462,7 @@ struct MultiDeviceHostStorage { } MultiDeviceStorage &operator=(const MultiDeviceStorage &other) { + std::lock_guard lock(other.mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -464,6 +471,7 @@ struct MultiDeviceHostStorage { } MultiDeviceStorage &operator=( MultiDeviceStorage &&other) { + std::lock_guard lock(mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -497,18 +505,20 @@ struct MultiDeviceHostStorage { // Helper Functions - Getters and setters to get/modify storage attributes. These are needed to // preinitialize empty tensor handles and use/populate them in the worker threads. void insert_buffer_and_shape_for_device(Device* device, const DeviceBuffer buffer, const Shape shape) { + TT_FATAL(device == buffer->device(), "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); std::lock_guard lock(mtx); buffers.insert({device->id(), buffer}); shapes.insert({device->id(), shape}); } - DeviceBuffer get_buffer_for_device(Device* device) { + DeviceBuffer get_buffer_for_device(Device* device) const { std::lock_guard lock(mtx); TT_FATAL(buffers.find(device->id()) != buffers.end(), "Buffer not found for device " + std::to_string(device->id())); + TT_FATAL(buffers.at(device->id())->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); return buffers.at(device->id()); } - Shape get_tensor_shape_for_device(Device* device) { + Shape get_tensor_shape_for_device(Device* device) const { std::lock_guard lock(mtx); TT_FATAL(shapes.find(device->id()) != shapes.end(), "Shape not found for device " + std::to_string(device->id())); return shapes.at(device->id()); diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp index 9d33e590526..05f7747ad5d 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.cpp +++ b/tt_eager/tt_dnn/op_library/run_operation.cpp @@ -654,6 +654,7 @@ void launch_op( // Assert to ensure that worker threads are specified. ZoneScopedN("LaunchOp"); auto& workers = output_tensors.at(0).workers; + std::size_t workers_size = workers.size(); if (not enable_autoformat_device and workers.empty()) { // Run on the host output_tensors = op_func(input_tensors, optional_input_tensors, optional_output_tensors); @@ -665,47 +666,47 @@ void launch_op( } validate_worker_modes(workers); // Record ref counts for all tensors before pushing to worker queue. 
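The tensor_utils rewrites above replace chains of holds_alternative/get calls with a single std::visit whose branches are selected by if constexpr, so each branch is compiled only for its own storage alternative and the type match happens once per call. A small sketch of that dispatch style, with hypothetical OwnedStorage/DeviceStorage stand-ins rather than the real storage classes:

#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <variant>

// Hypothetical storage alternatives standing in for OwnedStorage / DeviceStorage.
struct OwnedStorage  { std::string data; };
struct DeviceStorage { int buffer_id; };
using Storage = std::variant<OwnedStorage, DeviceStorage>;

std::string describe(const Storage& storage) {
    std::string out;
    std::visit([&out](auto&& s) {
        using T = std::decay_t<decltype(s)>;
        if constexpr (std::is_same_v<T, OwnedStorage>) {
            out = "owned host buffer: " + s.data;            // compiled only for OwnedStorage
        } else if constexpr (std::is_same_v<T, DeviceStorage>) {
            out = "device buffer " + std::to_string(s.buffer_id);
        } else {
            // Mirrors the TT_FATAL fallback; unreachable for the two alternatives above.
            throw std::runtime_error("unsupported storage type");
        }
    }, storage);
    return out;
}

int main() {
    std::cout << describe(OwnedStorage{"tensor data"}) << "\n";
    std::cout << describe(DeviceStorage{3}) << "\n";
}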
- std::vector input_tensor_ref_count = {}; - std::vector optional_input_tensor_ref_count = {}; - std::vector output_tensor_ref_count = {}; - std::vector optional_output_tensor_ref_count = {}; + std::vector input_tensor_ref_count = std::vector(input_tensors.size()); + std::vector optional_input_tensor_ref_count = std::vector(optional_input_tensors.size()); + std::vector output_tensor_ref_count = std::vector(output_tensors.size()); + std::vector optional_output_tensor_ref_count = std::vector(optional_output_tensors.size());; - std::vector async_safe_input_tensors = {}; + std::vector async_safe_input_tensors = std::vector(input_tensors.size()); std::vector> async_safe_optional_input_tensors = {}; std::unordered_set cross_worker_input_tensor_idx = {}; std::unordered_set cross_worker_optional_input_tensor_idx = {}; // When running on a single device, input tensors can be using borrowed storage. If so, when running in async mode, // copy borrowed tensors to owned storage. for (int i = 0; i < input_tensors.size(); i++) { - async_safe_input_tensors.push_back(copy_borrowed_tensor_in_async_mode(workers.at(0), input_tensors.at(i))); - input_tensor_ref_count.push_back(async_safe_input_tensors[i].tensor_attributes->record_main_thread_ref_count()); + async_safe_input_tensors[i] = copy_borrowed_tensor_in_async_mode(workers.at(0), input_tensors.at(i)); + input_tensor_ref_count[i] = async_safe_input_tensors[i].tensor_attributes->record_main_thread_ref_count(); } for (int i = 0; i < optional_input_tensors.size(); i++) { if (optional_input_tensors[i].has_value()) { async_safe_optional_input_tensors.push_back(copy_borrowed_tensor_in_async_mode(workers.at(0), optional_input_tensors[i].value())); - optional_input_tensor_ref_count.push_back(async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count()); + optional_input_tensor_ref_count[i] = async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); } else { async_safe_optional_input_tensors.push_back(std::nullopt); - optional_input_tensor_ref_count.push_back(0); + optional_input_tensor_ref_count[i] = 0; } } for (int i = 0; i < output_tensors.size(); i++) { - output_tensor_ref_count.push_back(output_tensors[i].tensor_attributes->record_main_thread_ref_count()); + output_tensor_ref_count[i] = output_tensors[i].tensor_attributes->record_main_thread_ref_count(); } for (int i = 0; i < optional_output_tensors.size(); i++) { if (optional_output_tensors[i].has_value()) { - optional_output_tensor_ref_count.push_back(optional_output_tensors[i].value().tensor_attributes->record_main_thread_ref_count()); + optional_output_tensor_ref_count[i] = optional_output_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); } else { - optional_output_tensor_ref_count.push_back(0); + optional_output_tensor_ref_count[i] = 0; } } // Check if this op dispatch step relies on tensors from other workers. // If so, mark them in use by current worker. Tensors shared across workers // are only supported when each tensor is tied to a single device/worker // (example all-gather). - if (workers.size() == 1) { + if (workers_size == 1) { // Single worker per tensor and. 
for (int i = 0; i < async_safe_input_tensors.size(); i++) { if (async_safe_input_tensors.at(i).get_workers().size() and async_safe_input_tensors.at(i).get_workers().at(0) != workers.at(0)) { @@ -724,17 +725,19 @@ void launch_op( { ZoneScopedN("PushOpToWorkers"); - for (auto target_device : workers) { - target_device->push_work([target_device, workers, op_func, optional_output_tensors, async_safe_optional_input_tensors, inputs = async_safe_input_tensors, outputs = output_tensors, shared_input_idx = cross_worker_input_tensor_idx, shared_optional_input_idx = cross_worker_optional_input_tensor_idx] () mutable { - - std::vector input_shards = {}; - std::vector> optional_input_shards = {}; - std::vector> optional_output_shards = {}; - // Initialize all optional_outputs to std::nullopt - optional_output_shards.resize(optional_output_tensors.size()); - for (const auto& input : inputs) { - input_shards.push_back(get_shard_for_device(input, target_device)); + auto work_lambda = std::make_shared>([workers_size, op_func, optional_output_tensors, async_safe_optional_input_tensors, inputs = async_safe_input_tensors, outputs = output_tensors, shared_input_idx = cross_worker_input_tensor_idx, shared_optional_input_idx = cross_worker_optional_input_tensor_idx] (Device* target_device) mutable { + std::vector input_shards = std::vector(inputs.size(), Tensor()); + std::vector> optional_input_shards = {}; + std::vector> optional_output_shards = {}; + // Initialize all optional_outputs to std::nullopt + optional_output_shards.resize(optional_output_tensors.size()); + + { + ZoneScopedN("CreateShards"); + for (int i = 0; i < input_shards.size(); i++) { + input_shards[i] = get_shard_for_device(inputs[i], target_device); } + for (auto& input : async_safe_optional_input_tensors) { if (input.has_value()) { optional_input_shards.push_back(get_shard_for_device(input.value(), target_device)); @@ -743,24 +746,31 @@ void launch_op( optional_input_shards.push_back(std::nullopt); } } + for (std::size_t optional_output_idx = 0; optional_output_idx < optional_output_tensors.size(); optional_output_idx++) { if (optional_output_tensors[optional_output_idx].has_value()) { optional_output_shards[optional_output_idx] = get_shard_for_device(optional_output_tensors[optional_output_idx].value(), target_device); } } - auto local_tensors = op_func(input_shards, optional_input_shards, optional_output_shards); + } + + auto local_tensors = op_func(input_shards, optional_input_shards, optional_output_shards); + + { + ZoneScopedN("OpPostProcess"); // Release shared ownership of tensors belonging to other workers. 
// If the workers for this tensor are stalled to deallocate for (auto& shared_input : shared_input_idx) { inputs.at(shared_input).tensor_attributes->num_sibling_workers_sharing_tensor--; } + for (auto& shared_optional_input : shared_optional_input_idx) { async_safe_optional_input_tensors.at(shared_optional_input).value().tensor_attributes->num_sibling_workers_sharing_tensor--; } + for (int i = 0; i < local_tensors.size(); i++) { if (local_tensors.at(i).storage_type() == StorageType::OWNED) { TT_ASSERT(outputs.at(i).tensor_attributes->dynamic_storage, "launch_with_autoformat must be used if output tensor for op can be placed on host."); - TT_ASSERT(std::holds_alternative(outputs.at(i).tensor_attributes->storage), "All inputs and outputs to an op must be on device for multi-device tensors."); // Make this a host side tensor - Set storage = Owned and clear workers outputs.at(i).tensor_attributes->storage = OwnedStorage(); outputs.at(i).workers = {}; @@ -769,19 +779,25 @@ void launch_op( outputs.at(i).tensor_attributes->dynamic_storage = false; } insert_buffer_and_shape_for_device(target_device, local_tensors.at(i), outputs.at(i)); - if (not target_device->id() or workers.size() == 1) { + if (not target_device->id() or workers_size == 1) { outputs.at(i).set_shape(local_tensors.at(i).get_shape()); outputs.at(i).set_dtype(local_tensors.at(i).get_dtype()); outputs.at(i).set_layout(local_tensors.at(i).get_layout()); } - if (workers.size() == 1) { + if (workers_size == 1) { outputs.at(i).set_populated(); } else { outputs.at(i).set_populated(target_device); } } - }); + } + }); + + for (auto target_device : workers) { + target_device->push_work(std::make_shared>([target_device, work_lambda] () mutable { + (*work_lambda)(target_device); + })); } } diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 36e77d2acf5..5f56184d6a7 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -1859,6 +1859,10 @@ void Device::push_work(std::function&& work, bool blocking) { this->work_executor.push_work(work, blocking); } +void Device::push_work(std::shared_ptr> work, bool blocking) { + this->work_executor.push_work(work, blocking); +} + void Device::synchronize() { this->work_executor.synchronize(); } diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 92bc514756e..16f2f3fe936 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -230,6 +230,7 @@ class Device { // APIs to access this device's work executor void push_work(std::function&& work, bool blocking = false); + void push_work(std::shared_ptr> work, bool blocking = false); void synchronize(); void set_worker_mode(const WorkExecutorMode& mode); void enable_async(bool enable); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index a9a45955508..27b4e03d929 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -1168,6 +1168,8 @@ HWCommandQueue::HWCommandQueue(Device* device, uint32_t id) : this->exit_condition = false; std::thread completion_queue_thread = std::thread(&HWCommandQueue::read_completion_queue, this); this->completion_queue_thread = std::move(completion_queue_thread); + // Set the affinity of the completion queue reader. 
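In launch_op above, the per-op lambda (which captures every input and output tensor) is now heap-allocated once and shared across workers; each worker receives only a small trampoline that calls it with that worker's Device*. A rough sketch of the pattern, with a hypothetical Device type that queues shared_ptr-wrapped work items:

#include <functional>
#include <iostream>
#include <memory>
#include <queue>
#include <vector>

// Hypothetical stand-in for a tt-metal Device with a worker queue.
struct Device {
    int id;
    std::queue<std::shared_ptr<std::function<void()>>> work;
    void push_work(std::shared_ptr<std::function<void()>> fn) { work.push(std::move(fn)); }
    void drain() { while (!work.empty()) { (*work.front())(); work.pop(); } }
};

int main() {
    std::vector<Device> devices = {{0, {}}, {1, {}}, {2, {}}};
    std::vector<int> captured_state(1024, 42);  // pretend this is the captured tensor list

    // Heap-allocate the expensive-to-copy lambda once and share it.
    auto work_lambda = std::make_shared<std::function<void(Device*)>>(
        [state = std::move(captured_state)](Device* dev) {
            std::cout << "device " << dev->id << " sees " << state.size() << " elements\n";
        });

    // Each worker gets a cheap trampoline that captures only two pointers.
    for (auto& dev : devices) {
        Device* target = &dev;
        dev.push_work(std::make_shared<std::function<void()>>(
            [target, work_lambda]() { (*work_lambda)(target); }));
    }
    for (auto& dev : devices) dev.drain();
}

The captured tensors are copied into the lambda once instead of once per device, which is where the savings in the multi-device loop come from.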
+ set_device_thread_affinity(this->completion_queue_thread, device->id()); this->expected_num_workers_completed = 0; } @@ -1182,11 +1184,29 @@ HWCommandQueue::~HWCommandQueue() { TT_ASSERT( this->num_entries_in_completion_q == this->num_completed_completion_q_reads, "There shouldn't be any commands in flight after closing our completion queue thread. Num uncompleted commands: {}", this->num_entries_in_completion_q - this->num_completed_completion_q_reads); - this->exit_condition = true; + this->set_exit_condition(); this->completion_queue_thread.join(); } } +void HWCommandQueue::increment_num_entries_in_completion_q() { + // Increment num_entries_in_completion_q and inform reader thread + // that there is work in the completion queue to process + this->num_entries_in_completion_q++; + { + std::lock_guard lock(this->reader_thread_cv_mutex); + this->reader_thread_cv.notify_one(); + } +} + +void HWCommandQueue::set_exit_condition() { + this->exit_condition = true; + { + std::lock_guard lock(this->reader_thread_cv_mutex); + this->reader_thread_cv.notify_one(); + } +} + template void HWCommandQueue::enqueue_command(T& command, bool blocking) { command.process(); @@ -1202,7 +1222,6 @@ void HWCommandQueue::enqueue_read_buffer(std::shared_ptr buffer, void* d // Read buffer command is enqueued in the issue region and device writes requested buffer data into the completion region void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blocking) { ZoneScopedN("HWCommandQueue_read_buffer"); - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(this->device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); CoreType dispatch_core_type = dispatch_core_manager::get(this->device->num_hw_cqs()).get_dispatch_core_type(this->device->id()); @@ -1232,9 +1251,8 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin this->issued_completion_q_reads.push( detail::ReadBufferDescriptor(buffer, padded_page_size, dst, unpadded_dst_offset, num_pages_to_read, src_page_index, linear_page_copy) ); - this->num_entries_in_completion_q++; - this->enqueue_command(command, false); + this->increment_num_entries_in_completion_q(); } } if (blocking) { @@ -1251,9 +1269,8 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin this->issued_completion_q_reads.push( detail::ReadBufferDescriptor(buffer, padded_page_size, dst, unpadded_dst_offset, pages_to_read, src_page_index) ); - this->num_entries_in_completion_q++; - this->enqueue_command(command, blocking); + this->increment_num_entries_in_completion_q(); if (not blocking) { // should this be unconditional? 
std::shared_ptr event = std::make_shared(); this->enqueue_record_event(event); @@ -1475,7 +1492,7 @@ void HWCommandQueue::enqueue_record_event(std::shared_ptr event, bool cle this->trace_ctx->num_completion_q_reads++; } else { this->issued_completion_q_reads.push(detail::ReadEventDescriptor(event->event_id)); - this->num_entries_in_completion_q++; + this->increment_num_entries_in_completion_q(); } } @@ -1511,7 +1528,7 @@ void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { } else if constexpr (std::is_same_v) { read_descriptor.set_global_offset(event_id); this->issued_completion_q_reads.push(read_descriptor); - this->num_entries_in_completion_q++; + this->increment_num_entries_in_completion_q(); num_events++; } }, @@ -1708,6 +1725,10 @@ void HWCommandQueue::read_completion_queue() { chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(this->device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); while (true) { + { + std::unique_lock lock(this->reader_thread_cv_mutex); + this->reader_thread_cv.wait(lock, [this] {return this->num_entries_in_completion_q > this->num_completed_completion_q_reads or this->exit_condition;}); + } if (this->num_entries_in_completion_q > this->num_completed_completion_q_reads) { uint32_t num_events_to_read = this->num_entries_in_completion_q - this->num_completed_completion_q_reads; for (uint32_t i = 0; i < num_events_to_read; i++) { @@ -1746,7 +1767,6 @@ void HWCommandQueue::read_completion_queue() { } else if (this->exit_condition) { return; } - std::this_thread::sleep_for(std::chrono::microseconds(10)); } } @@ -1760,13 +1780,13 @@ void HWCommandQueue::finish() { while (this->num_entries_in_completion_q > this->num_completed_completion_q_reads) { if (DPrintServerHangDetected()) { // DPrint Server hang. Mark state and early exit. Assert in main thread. - this->exit_condition = true; this->dprint_server_hang = true; + this->set_exit_condition(); return; } else if (tt::watcher_server_killed_due_to_error()) { // Illegal NOC txn killed watcher. Mark state and early exit. Assert in main thread. 
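The completion-queue reader above now blocks on a condition variable and is woken by increment_num_entries_in_completion_q / set_exit_condition, instead of polling with a 10 µs sleep. A compressed, self-contained sketch of that producer/consumer handshake (the names are placeholders, not the HWCommandQueue API):

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>

std::atomic<uint32_t> issued{0}, completed{0};
std::atomic<bool> exit_condition{false};
std::mutex cv_mutex;
std::condition_variable reader_cv;

void reader() {
    while (true) {
        {
            std::unique_lock<std::mutex> lock(cv_mutex);
            // Sleep until there is unread work or we are told to exit.
            reader_cv.wait(lock, [] { return issued > completed or exit_condition; });
        }
        if (issued > completed) {
            completed++;  // stand-in for draining one completion-queue entry
            std::cout << "processed entry " << completed << "\n";
        } else if (exit_condition) {
            return;
        }
    }
}

void enqueue_read() {
    issued++;                                  // publish the work first...
    std::lock_guard<std::mutex> lock(cv_mutex);
    reader_cv.notify_one();                    // ...then wake the reader
}

int main() {
    std::thread t(reader);
    for (int i = 0; i < 3; i++) enqueue_read();
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    exit_condition = true;
    { std::lock_guard<std::mutex> lock(cv_mutex); reader_cv.notify_one(); }
    t.join();
}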
- this->exit_condition = true; this->illegal_noc_txn_hang = true; + this->set_exit_condition(); return; } } diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index 7a413e0f8b3..6acfe40a447 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -494,6 +495,9 @@ class HWCommandQueue { Device* device; + std::condition_variable reader_thread_cv; + std::mutex reader_thread_cv_mutex; + CoreType get_dispatch_core_type(); void copy_into_user_space(const detail::ReadBufferDescriptor &read_buffer_descriptor, chip_id_t mmio_device_id, uint16_t channel); @@ -512,6 +516,8 @@ class HWCommandQueue { void enqueue_trace(const uint32_t trace_id, bool blocking); void finish(); void terminate(); + void increment_num_entries_in_completion_q(); + void set_exit_condition(); friend void EnqueueTraceImpl(CommandQueue& cq, uint32_t trace_id, bool blocking); friend void EnqueueProgramImpl(CommandQueue& cq, std::variant < std::reference_wrapper, std::shared_ptr > program, bool blocking); friend void EnqueueReadBufferImpl(CommandQueue& cq, std::variant, std::shared_ptr > buffer, void* dst, bool blocking); diff --git a/tt_metal/impl/dispatch/lock_free_queue.hpp b/tt_metal/impl/dispatch/lock_free_queue.hpp index 6909fcb0bb3..f6c6cd15422 100644 --- a/tt_metal/impl/dispatch/lock_free_queue.hpp +++ b/tt_metal/impl/dispatch/lock_free_queue.hpp @@ -23,7 +23,7 @@ class LockFreeQueue { std::atomic head; std::atomic tail; - Node* pop_head() { + inline Node* pop_head() { Node* oldHead = head.load(); if (oldHead == tail.load()) { return nullptr; // Queue is empty @@ -31,33 +31,67 @@ class LockFreeQueue { head.store(oldHead->next); return oldHead; } + // Statically allocated ring buffer containing + // node objects, which contain handles to data + // and another node object to traverse ring buffer. + const static uint32_t ring_buffer_size = 8192; + Node ring_buffer[ring_buffer_size]; public: // Optional - Set these if the worker and parent thread state needs to be tracked std::atomic worker_thread_id = 0; std::atomic parent_thread_id = 0; - LockFreeQueue() : head(new Node), tail(head.load()) {} + LockFreeQueue() + { + // Initialize ring buffer for traversal. Each node points to the subsequent node, except for the last one, which points to the head. + for (int node_idx = 0; node_idx < ring_buffer_size; node_idx++) { + (node_idx < ring_buffer_size - 1) ? ring_buffer[node_idx].next = (&ring_buffer[node_idx + 1]) : ring_buffer[node_idx].next = &(ring_buffer[0]); + } + // Initialize head and tail ptrs to start of ring buffer. + this->head = ring_buffer; + this->tail = ring_buffer; + } + LockFreeQueue(LockFreeQueue&& other) { + Node ring_buffer = other.ring_buffer; head.store(other.head.load()); tail.store(other.tail.load()); worker_thread_id.store(other.worker_thread_id.load()); parent_thread_id.store(other.parent_thread_id.load()); } - void push(const T& value) { - std::shared_ptr newData(std::make_shared(value)); - Node* newNode = new Node; - tail.load()->data = newData; - tail.load()->next = newNode; - tail.store(newNode); + + inline void push(const T& value) { + // Legacy Push API allowing copy by value + // for object T. + + // Stall condition: this push will update the tail (wptr) + // to match the location of head (rptr). The current push can + // thus overwrite data that's being read. Stall until head + // has progressed (data has been read). 
+ while(tail.load()->next == head.load()) {}; + tail.load()->data = std::make_shared(value); + tail.store(tail.load()->next); } - std::shared_ptr pop() { + inline void push(std::shared_ptr value) { + // Latest Push API, passing ptrs around. + // Usually faster, since no data-copies. + + // Stall condition: this push will update the tail (wptr) + // to match the location of head (rptr). The current push can + // thus overwrite data that's being read. Stall until head + // has progressed (data has been read). + while(tail.load()->next == head.load()) {}; + tail.load()->data = value; + tail.store(tail.load()->next); + } + + inline std::shared_ptr pop() { Node* oldHead = pop_head(); - if (!oldHead) { - TT_THROW("Queue is empty"); - } std::shared_ptr result(oldHead->data); - delete oldHead; + // Does not actually delete the node. + // Just reset its data handle to null to mark the node as empty for reuse. + (oldHead->data).reset(); return result; } diff --git a/tt_metal/impl/dispatch/work_executor.hpp b/tt_metal/impl/dispatch/work_executor.hpp index 1b8da029b93..19e075281da 100644 --- a/tt_metal/impl/dispatch/work_executor.hpp +++ b/tt_metal/impl/dispatch/work_executor.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -28,6 +29,30 @@ enum class WorkerState { IDLE = 2, }; +inline void set_device_thread_affinity(std::thread& thread_, int managed_device_id) { + // Bind a device worker/reader thread to a CPU core, determined using round-robin. + static int num_online_cores = sysconf(_SC_NPROCESSORS_ONLN); + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(managed_device_id % num_online_cores, &cpuset); + int rc = pthread_setaffinity_np(thread_.native_handle(), sizeof(cpu_set_t), &cpuset); + if (rc) { + log_warning(tt::LogMetal, "Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: {}", rc); + } +} + +inline void set_process_priority(int requested_priority) { + // Get priority for calling process + int process_priority = getpriority(PRIO_PROCESS, 0); + log_debug(tt::LogMetal, "Initial Process Priority: {}", process_priority); + if (process_priority == requested_priority) return; + // Set priority for calling process to user specified value + int rc = setpriority(PRIO_PROCESS, 0, requested_priority); + if (rc) { + log_warning(tt::LogMetal, "Unable to set process priority to {}, error code: {}", requested_priority, rc); + } +} + class WorkExecutor { // In asynchronous mode, each device has a worker thread that processes all host <--> cluster commands for this device. // Commands are pushed to the worker queue and picked up + executed asynchronously.
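LockFreeQueue above keeps its 8192 nodes in a statically allocated ring and makes push spin only when the write pointer would catch up with the read pointer. The sketch below shows the same single-producer/single-consumer idea in index form rather than the patch's linked-node form; SpscRing, its capacity, and its memory orderings are illustrative, not the patch's class:

#include <atomic>
#include <cstddef>
#include <iostream>
#include <memory>
#include <thread>

// A minimal single-producer / single-consumer ring of shared_ptr slots.
template <typename T, std::size_t N = 8>
class SpscRing {
    std::shared_ptr<T> slots[N];
    std::atomic<std::size_t> head{0};  // next slot to read
    std::atomic<std::size_t> tail{0};  // next slot to write
public:
    void push(std::shared_ptr<T> value) {
        std::size_t t = tail.load(std::memory_order_relaxed);
        std::size_t next = (t + 1) % N;
        // Stall while full so the writer never overwrites unread data.
        while (next == head.load(std::memory_order_acquire)) {}
        slots[t] = std::move(value);
        tail.store(next, std::memory_order_release);
    }
    std::shared_ptr<T> pop() {
        std::size_t h = head.load(std::memory_order_relaxed);
        if (h == tail.load(std::memory_order_acquire)) return nullptr;  // empty
        std::shared_ptr<T> out = std::move(slots[h]);  // leaves the slot reset, like the patch
        head.store((h + 1) % N, std::memory_order_release);
        return out;
    }
};

int main() {
    SpscRing<int> q;
    std::thread producer([&] { for (int i = 0; i < 100; i++) q.push(std::make_shared<int>(i)); });
    int seen = 0;
    while (seen < 100) {
        if (auto v = q.pop()) { seen++; }
    }
    producer.join();
    std::cout << "drained " << seen << " items\n";
}

Preallocating the slots removes the per-push new/delete of nodes, which is the main cost the patch is targeting.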
@@ -37,6 +62,7 @@ class WorkExecutor { LockFreeQueue> worker_queue; WorkExecutor(int device_id) : managed_device_id(device_id) { + set_process_priority(0); if (this->worker_queue_mode == WorkExecutorMode::ASYNCHRONOUS) { this->start_worker(); } @@ -73,27 +99,43 @@ class WorkExecutor { inline void push_work(const std::function& work_executor, bool blocking = false) { ZoneScopedN("PushWork"); - if (this->worker_queue_mode == WorkExecutorMode::ASYNCHRONOUS) { - if (std::hash{}(std::this_thread::get_id()) == worker_queue.parent_thread_id.load()) { - // Push function executor to worker queue - this->worker_queue.push(work_executor); - { - std::lock_guard lock(this->cv_mutex); - cv.notify_one(); - } - if (blocking) { - this->synchronize(); - } - } else { - TT_ASSERT(std::hash{}(std::this_thread::get_id()) == worker_queue.worker_thread_id.load(), "Only main thread or worker thread can push to device worker queue."); - work_executor(); + if (std::hash{}(std::this_thread::get_id()) == worker_queue.parent_thread_id.load()) { + // Parent thread id is non-zero (using async mode) and parent is calling push_work. + // Push function executor to worker queue + this->worker_queue.push(work_executor); + { + std::lock_guard lock(this->cv_mutex); + cv.notify_one(); + } + if (blocking) { + this->synchronize(); } } else { - // Synchronous execution: Run function right away. + // Either push work is called from worker itself or async mode is not being used. work_executor(); } } + inline void push_work(std::shared_ptr> work_executor, bool blocking = false) { + // Latest push API, passing ptrs around for work container. Usually faster, since no data-copies. + ZoneScopedN("PushWork"); + if (std::hash{}(std::this_thread::get_id()) == worker_queue.parent_thread_id.load()) { + // Parent thread id is non-zero (using async mode) and parent is calling push_work. + // Push function executor to worker queue + this->worker_queue.push(work_executor); + { + std::lock_guard lock(this->cv_mutex); + cv.notify_one(); + } + if (blocking) { + this->synchronize(); + } + } else { + // Either push work is called from worker itself or async mode is not being used. + (*work_executor)(); + } + } + inline void synchronize() { if (this->worker_queue_mode == WorkExecutorMode::ASYNCHRONOUS and std::hash{}(std::this_thread::get_id()) == worker_queue.parent_thread_id.load()) { // Blocking = wait for queue flushed. Only main thread can explcitly insert a synchronize, otherwise we have a deadlock. @@ -138,14 +180,7 @@ class WorkExecutor { this->worker_thread = std::thread(&WorkExecutor::run_worker, this); this->worker_queue.worker_thread_id = std::hash{}(this->worker_thread.get_id()); // Bind a worker tied to a device to a specific CPU core in round robin fashion. Thread affinity == Better Perf. - static int num_online_cores = sysconf(_SC_NPROCESSORS_ONLN); - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(managed_device_id % num_online_cores, &cpuset); - int rc = pthread_setaffinity_np(worker_thread.native_handle(), sizeof(cpu_set_t), &cpuset); - if (rc) { - log_warning(tt::LogMetal, "Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: {}", rc); - } + set_device_thread_affinity(this->worker_thread, this->managed_device_id); } inline void stop_worker() {
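set_device_thread_affinity and set_process_priority above are thin wrappers over pthread_setaffinity_np and get/setpriority. The following Linux-only sketch shows roughly how such helpers are used, with fprintf/perror standing in for the tt::log_* calls and a trivial worker body in place of run_worker:

// Linux-only sketch: pin a worker thread to a core chosen round-robin by device id
// and keep the process at the default niceness (0), mirroring the patch's helpers.
#include <pthread.h>
#include <sched.h>
#include <sys/resource.h>
#include <unistd.h>
#include <cstdio>
#include <thread>

static void set_thread_affinity(std::thread& t, int device_id) {
    static int num_cores = sysconf(_SC_NPROCESSORS_ONLN);
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(device_id % num_cores, &cpuset);
    if (int rc = pthread_setaffinity_np(t.native_handle(), sizeof(cpu_set_t), &cpuset)) {
        std::fprintf(stderr, "could not bind thread to core: %d\n", rc);  // non-fatal
    }
}

static void set_process_priority(int requested) {
    if (getpriority(PRIO_PROCESS, 0) == requested) return;
    if (setpriority(PRIO_PROCESS, 0, requested) != 0) {
        std::perror("setpriority");  // raising priority may require privileges
    }
}

int main() {
    set_process_priority(0);
    std::thread worker([] { /* device worker loop would run here */ });
    set_thread_affinity(worker, /*device_id=*/3);
    worker.join();
}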