diff --git a/tt_eager/tensor/CMakeLists.txt b/tt_eager/tensor/CMakeLists.txt index 5e60b1f5313..87dcfca87e3 100644 --- a/tt_eager/tensor/CMakeLists.txt +++ b/tt_eager/tensor/CMakeLists.txt @@ -1,6 +1,5 @@ set(TENSOR_SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/tensor_impl_wrapper.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/types.cpp diff --git a/tt_eager/tensor/module.mk b/tt_eager/tensor/module.mk index 0a52cc25046..f986138f91e 100644 --- a/tt_eager/tensor/module.mk +++ b/tt_eager/tensor/module.mk @@ -1,5 +1,4 @@ TENSOR_SRCS = \ - tt_eager/tensor/tensor_impl_wrapper.cpp \ tt_eager/tensor/tensor_impl.cpp \ tt_eager/tensor/tensor.cpp \ tt_eager/tensor/types.cpp \ diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index f619ca2a27f..63ace6ac5f7 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -2,24 +2,23 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tensor/tensor.hpp" + #include #include -#include "tensor/tensor.hpp" +#include "common/bfloat16.hpp" +#include "llrt/llrt.hpp" +#include "queue/queue.hpp" #include "tensor/tensor_impl.hpp" #include "tensor/tensor_impl_wrapper.hpp" #include "tensor/tensor_utils.hpp" -#include "common/bfloat16.hpp" -#include "llrt/llrt.hpp" #include "tensor/types.hpp" +#include "third_party/magic_enum/magic_enum.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/common/math.hpp" - -#include "tt_metal/tt_stl/reflection.hpp" - -#include "third_party/magic_enum/magic_enum.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" -#include "queue/queue.hpp" +#include "tt_metal/tt_stl/reflection.hpp" using namespace tt::constants; @@ -33,15 +32,15 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L deallocate_through_destructor(false) { ZoneScoped; std::visit( - [&] (auto&& storage) { + [&](auto&& storage) { using StorageType = std::decay_t; if constexpr (std::is_same_v) { this->tensor_attributes->tensor_populated = {true}; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { TT_ASSERT(storage.buffer->device() != nullptr); workers = {storage.buffer->device()}; - tensor_impl::validate_on_device_dtype_and_layout(storage.buffer->device(), shape.value(), dtype, layout); + tensor_impl::validate_on_device_dtype_and_layout( + storage.buffer->device(), shape.value(), dtype, layout); // Increment main thread ref count for all tensors on device this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly @@ -50,8 +49,7 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L this->tensor_attributes->main_thread_tensor = false; } this->tensor_attributes->tensor_populated = {true}; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { this->tensor_attributes->tensor_populated = {true}; } else if constexpr (std::is_same_v) { workers.reserve(storage.buffers.size()); @@ -76,7 +74,8 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L } else { raise_unsupported_storage(); } - }, storage); + }, + storage); } Tensor::Tensor(const Storage storage, const Shape shape, DataType dtype, Layout layout) : @@ -99,63 +98,92 @@ void Tensor::deallocate(bool force) { // Check if the attributes didn't get moved to another tensor. // If not, we can deallocate this tensor. 
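Note: the constructor above, and most of the functions that follow in this patch, dispatch on the Storage variant with std::visit and if constexpr. A minimal, self-contained sketch of that pattern is shown below; OwnedStorage and DeviceStorage here are simplified stand-ins for illustration, not the real tt-metal storage types.

#include <iostream>
#include <type_traits>
#include <variant>

struct OwnedStorage {};                       // simplified stand-in
struct DeviceStorage { int device_id = 0; };  // simplified stand-in
using Storage = std::variant<OwnedStorage, DeviceStorage>;

void describe(const Storage& storage) {
    std::visit(
        [](auto&& s) {
            using StorageType = std::decay_t<decltype(s)>;
            if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
                std::cout << "host-owned buffer\n";
            } else if constexpr (std::is_same_v<StorageType, DeviceStorage>) {
                std::cout << "device buffer on device " << s.device_id << "\n";
            }
        },
        storage);
}

int main() {
    describe(OwnedStorage{});
    describe(DeviceStorage{1});
}

The same shape of visitor lambda recurs throughout tensor.cpp whenever behaviour depends on where the tensor's buffer lives.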
std::visit( - [force, this](auto& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - if (this->tensor_attributes.use_count() == 1) { - std::visit([](auto&& buffer) { buffer.reset(); }, storage.buffer); + [force, this](auto& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + if (this->tensor_attributes.use_count() == 1) { + std::visit([](auto&& buffer) { buffer.reset(); }, storage.buffer); + } + } else if constexpr (std::is_same_v) { + if (this->workers.at(0)->in_main_thread() or not this->tensor_attributes->main_thread_tensor) { + if (not this->tensor_attributes->main_thread_tensor) { + TT_ASSERT( + not this->tensor_attributes->main_thread_ref_count, + "main_thread_ref_count for tensors created inside a worker thread must be 0"); } - } else if constexpr (std::is_same_v) { - if (this->workers.at(0)->in_main_thread() or not this->tensor_attributes->main_thread_tensor) { - if (not this->tensor_attributes->main_thread_tensor) { - TT_ASSERT(not this->tensor_attributes->main_thread_ref_count, "main_thread_ref_count for tensors created inside a worker thread must be 0"); - } - // If owned by the main thread, deallocate this tensor only from the main thread. If owned by worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a thread_local tensor - uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or not this->tensor_attributes->main_thread_tensor) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count; - if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { - this->tensor_attributes->deallocated = true; - this->workers.at(0)->push_work(std::make_shared>([force, attr = this->tensor_attributes] () mutable { - // Cross worker synchronization: If the tensor being deallocated is shared across workers (ex: all_gather op), - // wait until all workers are done with this tensor before deallocating. + // If owned by the main thread, deallocate this tensor only from the main thread. If owned by + // worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a + // thread_local tensor + uint32_t ref_count_to_use = + (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or + not this->tensor_attributes->main_thread_tensor) + ? this->tensor_attributes.use_count() + : this->tensor_attributes->main_thread_ref_count; + if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { + this->tensor_attributes->deallocated = true; + this->workers.at(0)->push_work(std::make_shared>( + [force, attr = this->tensor_attributes]() mutable { + // Cross worker synchronization: If the tensor being deallocated is shared across + // workers (ex: all_gather op), wait until all workers are done with this tensor + // before deallocating. 
bool num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor; if (num_threads_sharing_tensor) { while (num_threads_sharing_tensor) { - num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;; + num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor; + ; } } - std::visit([force, attr] (auto&& s) { - using type = std::decay_t; - if constexpr (std::is_same_v) { - if (force or s.buffer.use_count() == 1) { - DeallocateBuffer(*(s.buffer)); + std::visit( + [force, attr](auto&& s) { + using type = std::decay_t; + if constexpr (std::is_same_v) { + if (force or s.buffer.use_count() == 1) { + DeallocateBuffer(*(s.buffer)); + } + // Safe to reset this buf object since this is the last reference (in + // the main thread) to the tensor attr object holding this buffer. If + // any other tensor handles hold this buffer, it will not be deleted, + // until the last handle goes out of scope or is deallocated. + s.buffer.reset(); + } else if constexpr (std::is_same_v) { + // Manage Dynamic Storage (due to autoformat in async mode): Main thread + // sees this tensor as a device tensor, since worker has not updated + // storage time. When the worker executes the dealloc request, the + // storage type has been appropriately updated to Owned. + TT_ASSERT( + attr->dynamic_storage, + "Tensor storage type changed during runtime (device -> host), but " + "dynamic storage was not marked."); + std::visit([](auto&& buffer) { buffer.reset(); }, s.buffer); } - // Safe to reset this buf object since this is the last reference (in the main thread) to the tensor attr object holding this buffer. - // If any other tensor handles hold this buffer, it will not be deleted, until the last handle goes out of scope - // or is deallocated. - s.buffer.reset(); - } else if constexpr(std::is_same_v) { - // Manage Dynamic Storage (due to autoformat in async mode): Main thread sees this tensor as a device tensor, since worker has not updated - // storage time. When the worker executes the dealloc request, the storage type has been appropriately updated to Owned. - TT_ASSERT(attr->dynamic_storage, "Tensor storage type changed during runtime (device -> host), but dynamic storage was not marked."); - std::visit([] (auto&& buffer) { buffer.reset(); }, s.buffer); - } - }, attr->storage); + }, + attr->storage); })); - } - } else { - TT_FATAL(this->deallocate_through_destructor, "Device tensors created in the main thread cannot be explictly deallocated in worker threads."); } - } else if constexpr (std::is_same_v) { - if (force) { - TT_THROW("Cannot deallocate tensor with borrowed storage!"); - } - } else if constexpr (std::is_same_v) { - if (this->workers.at(0)->in_main_thread() or not this->tensor_attributes->main_thread_tensor) { - // If owned by the main thread, deallocate this tensor only from the main thread. If owned by worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a thread_local tensor - uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or not this->tensor_attributes->main_thread_tensor) ? 
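Note: the deallocation lambda above spins on num_sibling_workers_sharing_tensor before releasing the buffer, so a tensor shared across workers (for example by an all_gather) is only freed once every sibling worker is done with it. A simplified, hypothetical illustration of that wait-then-release step, with a plain shared_ptr standing in for the device buffer:

#include <atomic>
#include <memory>

struct TensorAttributes {  // simplified stand-in for the real attributes object
    std::atomic<int> num_sibling_workers_sharing_tensor{0};
    std::shared_ptr<int> buffer = std::make_shared<int>(0);  // stand-in device buffer
};

void deallocate_when_unshared(TensorAttributes& attr) {
    // Spin until no sibling worker still references the tensor.
    while (attr.num_sibling_workers_sharing_tensor.load() != 0) {
    }
    attr.buffer.reset();  // release the buffer once this is the last user
}

int main() {
    TensorAttributes attr;
    deallocate_when_unshared(attr);
}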
this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count; - if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { - this->tensor_attributes->deallocated = true; - auto dealloc_lambda = std::make_shared>([force, attr = this->tensor_attributes] (Device* worker) mutable { + } else { + TT_FATAL( + this->deallocate_through_destructor, + "Device tensors created in the main thread cannot be explictly deallocated in worker " + "threads."); + } + } else if constexpr (std::is_same_v) { + if (force) { + TT_THROW("Cannot deallocate tensor with borrowed storage!"); + } + } else if constexpr (std::is_same_v) { + if (this->workers.at(0)->in_main_thread() or not this->tensor_attributes->main_thread_tensor) { + // If owned by the main thread, deallocate this tensor only from the main thread. If owned by + // worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a + // thread_local tensor + uint32_t ref_count_to_use = + (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or + not this->tensor_attributes->main_thread_tensor) + ? this->tensor_attributes.use_count() + : this->tensor_attributes->main_thread_ref_count; + if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { + this->tensor_attributes->deallocated = true; + auto dealloc_lambda = std::make_shared>( + [force, attr = this->tensor_attributes](Device* worker) mutable { ZoneScopedN("ShardDeallocate"); auto& s = std::get(attr->storage); if (s.buffers.find(worker->id()) != s.buffers.end()) { @@ -166,27 +194,29 @@ void Tensor::deallocate(bool force) { } }); - for (auto worker : this->workers) { - worker->push_work(std::make_shared>([worker, dealloc_lambda] () mutable { - (*dealloc_lambda)(worker); - })); - } - } - } else { - TT_FATAL(this->deallocate_through_destructor, "Device tensors created in the main thread cannot be explictly deallocated in worker threads."); - } - } else if constexpr (std::is_same_v) { - if (this->tensor_attributes.use_count() == 1) { - // Same logic as above for host tensors - for (auto& current_buffer : storage.buffers) { - std::visit([](auto&& buffer) { buffer.reset(); }, current_buffer); + for (auto worker : this->workers) { + worker->push_work(std::make_shared>( + [worker, dealloc_lambda]() mutable { (*dealloc_lambda)(worker); })); } } } else { - raise_unsupported_storage(); + TT_FATAL( + this->deallocate_through_destructor, + "Device tensors created in the main thread cannot be explictly deallocated in worker " + "threads."); } - }, - this->tensor_attributes->storage); + } else if constexpr (std::is_same_v) { + if (this->tensor_attributes.use_count() == 1) { + // Same logic as above for host tensors + for (auto& current_buffer : storage.buffers) { + std::visit([](auto&& buffer) { buffer.reset(); }, current_buffer); + } + } + } else { + raise_unsupported_storage(); + } + }, + this->tensor_attributes->storage); } } @@ -195,7 +225,8 @@ void Tensor::perform_cleanup_for_async_mode() { // or move assignment operator if (this->tensor_attributes) { // Object has tensor_attributes that will be reassigned - if (this->workers.size() and this->workers.at(0)->in_main_thread() and this->workers.at(0)->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) { + if (this->workers.size() and this->workers.at(0)->in_main_thread() and + this->workers.at(0)->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) { // Operator called in main thread with async mode. Main thread Ref Count must be decremented. 
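Note: both deallocation branches above first pick which reference count gates the actual free: the shared_ptr use count for synchronous mode and worker-created (thread_local) tensors, and the explicit main-thread count otherwise. A hedged sketch of that selection, using simplified stand-in types:

#include <cstdint>
#include <memory>

enum class WorkExecutorMode { SYNCHRONOUS, ASYNCHRONOUS };

struct Attributes {  // simplified stand-in
    uint32_t main_thread_ref_count = 1;
    bool main_thread_tensor = true;
};

uint32_t ref_count_to_use(const std::shared_ptr<Attributes>& attr, WorkExecutorMode mode) {
    // Synchronous execution and worker-created tensors rely on the shared_ptr count;
    // asynchronous main-thread tensors use the explicit main-thread count.
    return (mode == WorkExecutorMode::SYNCHRONOUS || !attr->main_thread_tensor)
               ? static_cast<uint32_t>(attr.use_count())
               : attr->main_thread_ref_count;
}

int main() {
    auto attr = std::make_shared<Attributes>();
    bool may_deallocate = ref_count_to_use(attr, WorkExecutorMode::ASYNCHRONOUS) == 1;
    return may_deallocate ? 0 : 1;
}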
// This is the last tensor in the main thread holding these attributes. Deallocate the buffer // for this tensor. @@ -215,7 +246,8 @@ void Tensor::wait_for_tensor_data_populated() const { for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { while (true) { std::scoped_lock lock(this->tensor_attributes->populated_mutex); - if (this->tensor_attributes->tensor_populated.at(i)) break; + if (this->tensor_attributes->tensor_populated.at(i)) + break; } } } @@ -227,7 +259,8 @@ void Tensor::wait_for_tensor_metadata_populated() const { // Stall until this worker is done while (true) { std::scoped_lock lock(this->tensor_attributes->populated_mutex); - if (this->tensor_attributes->tensor_populated.at(0)) break; + if (this->tensor_attributes->tensor_populated.at(0)) + break; }; } @@ -239,8 +272,7 @@ void Tensor::set_populated(Device* worker) { for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { this->tensor_attributes->tensor_populated.at(i) = true; } - } - else { + } else { this->tensor_attributes->tensor_populated.at(worker->id()) = true; } } @@ -266,17 +298,22 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) { this->set_dtype(other.get_dtype()); this->set_layout(other.get_layout()); // Populate storage container with buffers + shapes - std::visit([this] (auto&& storage) { - using StorageType = std::decay_t; - if constexpr(std::is_same_v or std::is_same_v) { - std::get(this->tensor_attributes->storage).insert_buffer(storage.get_buffer()); - this->tensor_attributes->tensor_populated = {true}; - } else if constexpr(std::is_same_v or std::is_same_v) { - std::get(this->tensor_attributes->storage).buffers = storage.buffers; - std::get(this->tensor_attributes->storage).shapes = storage.shapes; - this->tensor_attributes->tensor_populated = std::vector(storage.buffers.size(), true); - } - }, other.get_storage()); // Non blocking storage query, since this is done for tensors that get created inside the worker thread + std::visit( + [this](auto&& storage) { + using StorageType = std::decay_t; + if constexpr (std::is_same_v or std::is_same_v) { + std::get(this->tensor_attributes->storage).insert_buffer(storage.get_buffer()); + this->tensor_attributes->tensor_populated = {true}; + } else if constexpr ( + std::is_same_v or + std::is_same_v) { + std::get(this->tensor_attributes->storage).buffers = storage.buffers; + std::get(this->tensor_attributes->storage).shapes = storage.shapes; + this->tensor_attributes->tensor_populated = std::vector(storage.buffers.size(), true); + } + }, + other.get_storage()); // Non blocking storage query, since this is done for tensors that get created inside the + // worker thread } std::vector Tensor::get_workers(bool blocking) const { @@ -290,36 +327,44 @@ std::vector Tensor::get_workers(bool blocking) const { this->wait_for_tensor_metadata_populated(); } - std::visit([this, blocking, &workers] (auto&& storage) { - using StorageType = std::decay_t; - // Assign workers only to device tensors - if constexpr (std::is_same_v) { - // Either explictly syncing or workers are pre-populated (this will happen for device tensors if using the correct APIs). - TT_FATAL(blocking or (this->workers.size() == 1), "Worker Handles for tensor must be populated or blocking = true must be set in get_workers()."); - if (this->workers.size() != 1) { - // Not populated - sync. - this->wait_for_tensor_data_populated(); - workers = {this->device()}; - } else { - // Already populated. 
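Note: wait_for_tensor_data_populated and wait_for_tensor_metadata_populated above block readers by polling per-worker flags under a mutex until the worker marks the tensor populated. A minimal stand-alone sketch of that polling loop; SharedState is a hypothetical stand-in for the tensor attributes:

#include <cstddef>
#include <mutex>
#include <vector>

struct SharedState {  // simplified stand-in for the tensor attributes
    std::mutex populated_mutex;
    std::vector<bool> tensor_populated;  // one flag per worker/shard, set by the worker
};

void wait_for_tensor_data_populated(SharedState& state) {
    for (std::size_t i = 0; i < state.tensor_populated.size(); i++) {
        while (true) {
            std::scoped_lock lock(state.populated_mutex);
            if (state.tensor_populated.at(i)) break;
        }
    }
}

int main() {
    SharedState state;
    state.tensor_populated = {true};  // a worker would flip this after writing the data
    wait_for_tensor_data_populated(state);
}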
- workers = this->workers; - } - } else if constexpr (std::is_same_v) { - // Either explictly syncing or workers are pre-populated (this will happen for device tensors if using the correct APIs). - TT_FATAL(blocking or (this->workers.size()), "Worker Handles for tensor must be populated or blocking = true must be set in get_workers()."); - if (not this->workers.size()) { - // Not populated - sync. - this->wait_for_tensor_data_populated(); - workers.reserve(storage.buffers.size()); - for (int i = 0; i < storage.ordered_device_ids.size(); ++i) { - auto device_id = storage.ordered_device_ids[i]; - workers.push_back(storage.buffers[device_id]->device()); + std::visit( + [this, blocking, &workers](auto&& storage) { + using StorageType = std::decay_t; + // Assign workers only to device tensors + if constexpr (std::is_same_v) { + // Either explictly syncing or workers are pre-populated (this will happen for device tensors if using + // the correct APIs). + TT_FATAL( + blocking or (this->workers.size() == 1), + "Worker Handles for tensor must be populated or blocking = true must be set in get_workers()."); + if (this->workers.size() != 1) { + // Not populated - sync. + this->wait_for_tensor_data_populated(); + workers = {this->device()}; + } else { + // Already populated. + workers = this->workers; + } + } else if constexpr (std::is_same_v) { + // Either explictly syncing or workers are pre-populated (this will happen for device tensors if using + // the correct APIs). + TT_FATAL( + blocking or (this->workers.size()), + "Worker Handles for tensor must be populated or blocking = true must be set in get_workers()."); + if (not this->workers.size()) { + // Not populated - sync. + this->wait_for_tensor_data_populated(); + workers.reserve(storage.buffers.size()); + for (int i = 0; i < storage.ordered_device_ids.size(); ++i) { + auto device_id = storage.ordered_device_ids[i]; + workers.push_back(storage.buffers[device_id]->device()); + } + } else { + workers = this->workers; } - } else { - workers = this->workers; } - } - }, this->tensor_attributes->storage); + }, + this->tensor_attributes->storage); return workers; } @@ -347,7 +392,7 @@ const Storage& Tensor::get_storage() const { return this->tensor_attributes->storage; } -Tensor Tensor::to(CommandQueue & queue, const MemoryConfig & mem_config) const { +Tensor Tensor::to(CommandQueue& queue, const MemoryConfig& mem_config) const { ZoneScoped; auto target_device = queue.device(); // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. @@ -358,14 +403,18 @@ Tensor Tensor::to(CommandQueue & queue, const MemoryConfig & mem_config) const { // Record main thread ref count for tensors before pushing to queue. 
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); - queue.device()->push_work([async_safe_tensor, device_tensor, mem_config, target_device] () mutable { + queue.device()->push_work([async_safe_tensor, device_tensor, mem_config, target_device]() mutable { if (async_safe_tensor.storage_type() == StorageType::DEVICE) { TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); device_tensor.populate_buffers_and_metadata(async_safe_tensor); - } - else { - tensor_impl::validate_on_device_dtype_and_layout(target_device, async_safe_tensor.get_legacy_shape(), async_safe_tensor.get_dtype(), async_safe_tensor.get_layout()); - auto local_tensor = tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config); + } else { + tensor_impl::validate_on_device_dtype_and_layout( + target_device, + async_safe_tensor.get_legacy_shape(), + async_safe_tensor.get_dtype(), + async_safe_tensor.get_layout()); + auto local_tensor = + tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, std::nullopt); // Populate device tensor device_tensor.populate_buffers_and_metadata(local_tensor); } @@ -373,11 +422,12 @@ Tensor Tensor::to(CommandQueue & queue, const MemoryConfig & mem_config) const { // Update main thread ref count for tensors after pushing to queue (update original tensor and returned tensor, // since both can be on device). device_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), device_tensor_ref_count); - async_safe_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), original_tensor_ref_count); + async_safe_tensor.tensor_attributes->update_main_thread_ref_count( + device_tensor.workers.at(0), original_tensor_ref_count); return device_tensor; } -Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const { +Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config) const { ZoneScoped; // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. Tensor async_safe_tensor = copy_borrowed_tensor_in_async_mode(target_device, *this); @@ -387,14 +437,18 @@ Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const { // Record main thread ref count for tensors before pushing to queue. 
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); - target_device->push_work([async_safe_tensor, device_tensor, mem_config, target_device] () mutable { + target_device->push_work([async_safe_tensor, device_tensor, mem_config, target_device]() mutable { if (async_safe_tensor.storage_type() == StorageType::DEVICE) { TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); device_tensor.populate_buffers_and_metadata(async_safe_tensor); - } - else { - tensor_impl::validate_on_device_dtype_and_layout(target_device, async_safe_tensor.get_legacy_shape(), async_safe_tensor.get_dtype(), async_safe_tensor.get_layout()); - auto local_tensor = tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config); + } else { + tensor_impl::validate_on_device_dtype_and_layout( + target_device, + async_safe_tensor.get_legacy_shape(), + async_safe_tensor.get_dtype(), + async_safe_tensor.get_layout()); + auto local_tensor = + tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, std::nullopt); // Populate device tensor device_tensor.populate_buffers_and_metadata(local_tensor); } @@ -402,22 +456,25 @@ Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const { // Update main thread ref count for tensors after pushing to queue (update original tensor and returned tensor, // since both can be on device). device_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), device_tensor_ref_count); - async_safe_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), original_tensor_ref_count); + async_safe_tensor.tensor_attributes->update_main_thread_ref_count( + device_tensor.workers.at(0), original_tensor_ref_count); return device_tensor; } -Tensor Tensor::to(DeviceMesh *device_mesh, const MemoryConfig &mem_config) const { +Tensor Tensor::to(DeviceMesh* device_mesh, const MemoryConfig& mem_config) const { ZoneScoped; return this->to(device_mesh->get_devices(), mem_config); } -Tensor Tensor::to(const std::vector& workers, const MemoryConfig &mem_config) const { +Tensor Tensor::to(const std::vector& workers, const MemoryConfig& mem_config) const { ZoneScoped; - TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); + TT_FATAL( + validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); // When broadcasting a single shard to all devices, we use all workers. 
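Note: Tensor::to(CommandQueue&) and Tensor::to(Device*) above follow the same asynchronous shape: create an empty tensor handle, push a lambda that captures it by value onto the worker queue, and return immediately while the worker later fills in buffers and metadata through the shared attributes. The sketch below shows that handle-sharing idea with deliberately simplified stand-ins (Worker, LightTensor and Attr are hypothetical, and the queue is drained synchronously for brevity):

#include <functional>
#include <memory>
#include <queue>
#include <string>

struct Worker {  // trivial stand-in for a device worker queue
    std::queue<std::function<void()>> work;
    void push_work(std::function<void()> fn) { work.push(std::move(fn)); }
    void run_all() { while (!work.empty()) { work.front()(); work.pop(); } }
};

struct Attr { std::string data; bool populated = false; };

struct LightTensor {  // simplified stand-in for Tensor
    std::shared_ptr<Attr> attr = std::make_shared<Attr>();
};

LightTensor to_device_async(Worker& worker, const LightTensor& host) {
    LightTensor device_tensor;  // empty handle, returned immediately
    worker.push_work([host, device_tensor]() mutable {
        device_tensor.attr->data = host.attr->data;  // "copy to device"
        device_tensor.attr->populated = true;        // readers block on this flag
    });
    return device_tensor;
}

int main() {
    Worker worker;
    LightTensor host;
    host.attr->data = "payload";
    LightTensor dev = to_device_async(worker, host);
    worker.run_all();  // a real worker thread would do this concurrently
    return dev.attr->populated ? 0 : 1;
}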
// When sending a MultiDeviceHost tensor to the cluster, send it only to devices for which shards exist auto workers_to_use = workers; - if (std::holds_alternative(this->get_storage()) or std::holds_alternative(this->get_storage())) { + if (std::holds_alternative(this->get_storage()) or + std::holds_alternative(this->get_storage())) { workers_to_use = std::vector(workers.begin(), workers.begin() + num_buffers_in_tensor(*this)); } Tensor device_tensor = Tensor(workers_to_use); @@ -426,20 +483,21 @@ Tensor Tensor::to(const std::vector& workers, const MemoryConfig &mem_c uint32_t num_workers = workers_to_use.size(); for (int worker_index = 0; worker_index < workers_to_use.size(); ++worker_index) { auto& worker = workers_to_use[worker_index]; - worker->push_work( - [worker, *this, device_tensor, mem_config, num_workers, worker_index] () mutable { - auto shard = get_shard_for_device(*this, worker, worker_index); - if (shard.storage_type() == StorageType::OWNED) { - shard = tensor_impl::to_device_wrapper(shard, worker, mem_config); - } - insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); - if (not worker->id()) { - device_tensor.set_shape(this->get_shape()); - device_tensor.set_dtype(this->get_dtype()); - device_tensor.set_layout(this->get_layout()); - } - if (num_workers > 1) device_tensor.set_populated(worker); - else device_tensor.set_populated(); + worker->push_work([worker, *this, device_tensor, mem_config, num_workers, worker_index]() mutable { + auto shard = get_shard_for_device(*this, worker, worker_index); + if (shard.storage_type() == StorageType::OWNED) { + shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); + } + insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); + if (not worker->id()) { + device_tensor.set_shape(this->get_shape()); + device_tensor.set_dtype(this->get_dtype()); + device_tensor.set_layout(this->get_layout()); + } + if (num_workers > 1) + device_tensor.set_populated(worker); + else + device_tensor.set_populated(); }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); @@ -456,13 +514,16 @@ Tensor Tensor::cpu(bool blocking) const { // tensor accessors will stall until tensor is populated. 
return *this; } - TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); + TT_FATAL( + validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); Tensor host_tensor({}, workers.size()); uint32_t original_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(); for (int worker_index = 0; worker_index < workers.size(); worker_index++) { auto target_device = workers[worker_index]; - target_device->push_work([host_tensor, blocking, target_device, *this, workers, worker_index] () mutable { - TT_ASSERT(this->storage_type() == StorageType::DEVICE or this->storage_type() == StorageType::MULTI_DEVICE, "Can only use worker queue for cpu call if tensor is on device."); + target_device->push_work([host_tensor, blocking, target_device, *this, workers, worker_index]() mutable { + TT_ASSERT( + this->storage_type() == StorageType::DEVICE or this->storage_type() == StorageType::MULTI_DEVICE, + "Can only use worker queue for cpu call if tensor is on device."); auto shard = get_shard_for_device(*this, target_device); shard = tensor_impl::to_host_wrapper(shard, blocking); insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); @@ -473,8 +534,7 @@ Tensor Tensor::cpu(bool blocking) const { } if (workers.size() == 1) { host_tensor.set_populated(); - } - else { + } else { host_tensor.set_populated(target_device); } }); @@ -491,21 +551,18 @@ Tensor Tensor::cpu(bool blocking) const { Tensor Tensor::cpu_sharded() const { ZoneScoped; - return tensor_impl::to_host_wrapper_sharded(*this); + return tensor_impl::to_host_sharded_wrapper(*this); } - -Tensor Tensor::extract_shard(const CoreCoord & core) const{ +Tensor Tensor::extract_shard(const CoreCoord& core) const { ZoneScoped; auto buffer_page_mapping = generate_buffer_page_mapping(*this->buffer()); auto core_id = buffer_page_mapping.core_to_core_id_.at(core); return this->extract_shard(core_id); } -Tensor Tensor::extract_shard(const uint32_t & core_id) const{ - - return tensor_impl::to_extract_shard_wrapper(*this, core_id); - +Tensor Tensor::extract_shard(const uint32_t& core_id) const { + return tensor_impl::extract_shard_wrapper(*this, core_id); } Tensor Tensor::to(Layout target_layout, Device* worker) const { @@ -515,8 +572,11 @@ Tensor Tensor::to(Layout target_layout, Device* worker) const { // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. 
Tensor async_safe_tensor = copy_borrowed_tensor_in_async_mode(worker, *this); Tensor tensor_modified_layout = Tensor({}, 1); - worker->push_work([async_safe_tensor, tensor_modified_layout, target_layout] () mutable { - TT_ASSERT(async_safe_tensor.storage_type() == StorageType::OWNED or async_safe_tensor.storage_type() == StorageType::BORROWED && "to(layout) must be called on host tensors with a single buffer when a single worker is specified"); + worker->push_work([async_safe_tensor, tensor_modified_layout, target_layout]() mutable { + TT_ASSERT( + async_safe_tensor.storage_type() == StorageType::OWNED or + async_safe_tensor.storage_type() == StorageType::BORROWED && + "to(layout) must be called on host tensors with a single buffer when a single worker is specified"); auto local_tensor = tensor_impl::to_layout_wrapper(async_safe_tensor, target_layout); // Populate modified layout tensor tensor_modified_layout.populate_buffers_and_metadata(local_tensor); @@ -524,7 +584,9 @@ Tensor Tensor::to(Layout target_layout, Device* worker) const { return tensor_modified_layout; } // Running without worker threads (non-async) - TT_ASSERT(this->storage_type() != StorageType::DEVICE or this->storage_type() != StorageType::MULTI_DEVICE && "Bring tensor to host before converting to target layout"); + TT_ASSERT( + this->storage_type() != StorageType::DEVICE or + this->storage_type() != StorageType::MULTI_DEVICE && "Bring tensor to host before converting to target layout"); return tensor_impl::to_layout_wrapper(*this, target_layout); } @@ -533,16 +595,22 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { if (device_mesh) { auto all_workers = device_mesh->get_devices(); auto workers = std::vector(all_workers.begin(), all_workers.begin() + num_buffers_in_tensor(*this)); - TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); + TT_FATAL( + validate_worker_modes(workers), + "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); Tensor tensor_modified_layout = Tensor({}, workers.size()); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work([*this, tensor_modified_layout, target_layout, worker, worker_index] () mutable { - TT_ASSERT(this->storage_type() == StorageType::MULTI_DEVICE_HOST && "to(layout) must be called on host tensors with MULTI_DEVICE_HOST_STORAGE when multiple workers are specified");; + worker->push_work([*this, tensor_modified_layout, target_layout, worker, worker_index]() mutable { + TT_ASSERT( + this->storage_type() == StorageType::MULTI_DEVICE_HOST && + "to(layout) must be called on host tensors with MULTI_DEVICE_HOST_STORAGE when multiple workers " + "are specified"); + ; auto shard = get_shard_for_device(*this, worker, worker_index); shard = tensor_impl::to_layout_wrapper(shard, target_layout); insert_buffer_and_shape_for_device(worker, shard, tensor_modified_layout, worker_index); - if (not (worker->id())) { + if (not(worker->id())) { tensor_modified_layout.set_shape(this->get_shape()); tensor_modified_layout.set_dtype(this->get_dtype()); tensor_modified_layout.set_layout(target_layout); @@ -553,7 +621,9 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { return tensor_modified_layout; } // Running without worker threads (non-async) - TT_ASSERT(this->storage_type() != StorageType::DEVICE or this->storage_type() != StorageType::MULTI_DEVICE && "Bring tensor to host before 
converting to target layout"); + TT_ASSERT( + this->storage_type() != StorageType::DEVICE or + this->storage_type() != StorageType::MULTI_DEVICE && "Bring tensor to host before converting to target layout"); return tensor_impl::to_layout_wrapper(*this, target_layout); } @@ -564,8 +634,7 @@ void Tensor::print() const { std::cout << write_to_string() << std::endl; } Tensor Tensor::pad(const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) const { ZoneScoped; TT_ASSERT( - this->storage_type() == StorageType::OWNED or - this->storage_type() == StorageType::MULTI_DEVICE_HOST or + this->storage_type() == StorageType::OWNED or this->storage_type() == StorageType::MULTI_DEVICE_HOST or this->storage_type() == StorageType::BORROWED && "Tensor must be on host for padding"); TT_ASSERT(this->get_layout() == Layout::ROW_MAJOR && "Tensor layout must be ROW_MAJOR for padding"); @@ -574,7 +643,7 @@ Tensor Tensor::pad(const Shape& output_tensor_shape, const Shape& input_tensor_s for (auto index = 0; index < input_shape.rank(); index++) { auto front = input_tensor_start[index]; auto back = output_tensor_shape[index] - (input_tensor_start[index] + input_shape[index]); - dimensions_pads.push_back(Padding::PadDimension{.front=front, .back=back}); + dimensions_pads.push_back(Padding::PadDimension{.front = front, .back = back}); } const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); auto output_shape_with_padding = Shape(output_tensor_shape, padding); @@ -582,7 +651,7 @@ Tensor Tensor::pad(const Shape& output_tensor_shape, const Shape& input_tensor_s return tensor_impl::pad_wrapper(*this, output_shape_with_padding, input_tensor_start, pad_value); } -Tensor Tensor::unpad(const Shape &output_tensor_start, const Shape &output_tensor_end) const { +Tensor Tensor::unpad(const Shape& output_tensor_start, const Shape& output_tensor_end) const { ZoneScoped; TT_ASSERT(this->get_layout() == Layout::ROW_MAJOR && "Tensor layout must be ROW_MAJOR for unpadding"); return tensor_impl::unpad_wrapper(*this, output_tensor_start, output_tensor_end); @@ -615,7 +684,7 @@ Tensor Tensor::pad_to_tile(float pad_value) const { return this->pad(Shape(shape, padded_shape), Shape{input_tensor_start}, pad_value); } -Tensor Tensor::unpad_from_tile(const Shape &output_tensor_shape) const { +Tensor Tensor::unpad_from_tile(const Shape& output_tensor_shape) const { ZoneScoped; for (auto index = 0; index < this->get_legacy_shape().rank() - 2; index++) { @@ -643,9 +712,7 @@ const bool Tensor::is_sharded() const { return is_tensor_on_device_or_multidevice(*this) ? 
this->memory_config().is_sharded() : false; } -uint32_t Tensor::element_size() const { - return tensor_impl::element_size_bytes_wrapper(this->get_dtype()); -} +uint32_t Tensor::element_size() const { return tensor_impl::element_size_bytes(this->get_dtype()); } Tensor Tensor::reshape(int N, int C, int H, int W) const { ZoneScoped; @@ -661,11 +728,12 @@ Tensor Tensor::reshape(const Shape& new_shape) const { this->volume(), tt::tt_metal::compute_volume(new_shape)); if (this->get_layout() == Layout::TILE) { - TT_ASSERT(new_shape[-2] % TILE_HEIGHT == 0 && new_shape[-1] % TILE_WIDTH == 0 && "Expected a multiple of 32 for H, W (or -1 evaluating to such) in Tensor::reshape()!"); + TT_ASSERT( + new_shape[-2] % TILE_HEIGHT == 0 && new_shape[-1] % TILE_WIDTH == 0 && + "Expected a multiple of 32 for H, W (or -1 evaluating to such) in Tensor::reshape()!"); } return std::visit( - [this, &new_shape](auto&& storage) -> Tensor - { + [this, &new_shape](auto&& storage) -> Tensor { using T = std::decay_t; const auto& tensor = *this; if constexpr (std::is_same_v) { @@ -688,33 +756,27 @@ Tensor Tensor::reshape(const Shape& new_shape) const { return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout()); } }, - this->get_storage() - ); + this->get_storage()); } bool Tensor::is_allocated() const { ZoneScoped; return std::visit( - [](auto&& storage) -> bool - { + [](auto&& storage) -> bool { using T = std::decay_t; if constexpr (std::is_same_v) { return std::visit([](auto&& buffer) -> bool { return buffer.is_allocated(); }, storage.buffer); - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return bool(storage.buffer) and storage.buffer->size() > 0; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return true; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { bool is_allocated = true; for (const auto& buffer : storage.buffers) { is_allocated &= std::visit([](auto&& buffer) -> bool { return buffer.is_allocated(); }, buffer); } return is_allocated; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { bool is_allocated = true; for (int i = 0; i < storage.ordered_device_ids.size(); ++i) { auto device_id = storage.ordered_device_ids[i]; @@ -722,16 +784,14 @@ bool Tensor::is_allocated() const { is_allocated &= bool(buffer) and buffer->size() > 0; } return is_allocated; - } - else { + } else { raise_unsupported_storage(); } }, - this->get_storage() - ); + this->get_storage()); } -std::vector Tensor::host_page_ordering(){ +std::vector Tensor::host_page_ordering() { auto buffer_page_mapping = generate_buffer_page_mapping(*this->buffer()); auto cores = buffer_page_mapping.all_cores_; auto shard_size = buffer()->shard_spec().size(); @@ -739,8 +799,8 @@ std::vector Tensor::host_page_ordering(){ std::vector ret_vec; ret_vec.reserve(num_pages); - for(int page_id = 0; page_id Tensor::host_page_ordering(){ StorageType Tensor::storage_type() const { return std::visit( - [] (auto&& storage) -> StorageType - { + [](auto&& storage) -> StorageType { using T = std::decay_t; if constexpr (std::is_same_v) { return StorageType::OWNED; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return StorageType::DEVICE; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return StorageType::BORROWED; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return StorageType::MULTI_DEVICE; - } - else 
if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return StorageType::MULTI_DEVICE_HOST; - } - else { + } else { raise_unsupported_storage(); } }, - this->get_storage() - ); + this->get_storage()); } namespace detail { @@ -785,15 +838,14 @@ const Shape compute_strides(const Shape& shape) { } return strides; } -} +} // namespace detail -const Shape Tensor::strides() const { - return detail::compute_strides(this->get_legacy_shape()); -} +const Shape Tensor::strides() const { return detail::compute_strides(this->get_legacy_shape()); } uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); } -Tensor create_device_tensor(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config) { +Tensor create_device_tensor( + const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) { ZoneScoped; if (memory_config.is_sharded()) { TT_ASSERT(memory_config.shard_spec.has_value()); @@ -808,21 +860,28 @@ Tensor create_device_tensor(const Shape& shape, DataType data_type, Layout layou other_dims *= shape[i]; } - auto element_size = tensor_impl::element_size_bytes_wrapper(data_type); + auto element_size = tensor_impl::element_size_bytes(data_type); auto page_shape = tensor_impl::get_sharded_page_shape(layout, data_type, shard_spec.shape); - std::array tensor2d_size = {other_dims/page_shape[0], width/page_shape[1]}; + std::array tensor2d_size = {other_dims / page_shape[0], width / page_shape[1]}; ShardSpecBuffer shard_spec_buffer(shard_spec, page_shape, tensor2d_size); uint32_t packed_size_in_bytes; - packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); - auto device_buffer = tensor_impl::allocate_buffer_on_device(packed_size_in_bytes, device, shape, - data_type, layout, memory_config, - std::make_optional(shard_spec_buffer) - ); + packed_size_in_bytes = + tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); + auto device_buffer = tensor_impl::allocate_buffer_on_device( + packed_size_in_bytes, + device, + shape, + data_type, + layout, + memory_config, + std::make_optional(shard_spec_buffer)); return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); } else { - uint32_t packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); - auto device_buffer = tensor_impl::allocate_buffer_on_device(packed_size_in_bytes, device, shape, data_type, layout, memory_config); + uint32_t packed_size_in_bytes = + tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); + auto device_buffer = tensor_impl::allocate_buffer_on_device( + packed_size_in_bytes, device, shape, data_type, layout, memory_config); return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); } } @@ -840,7 +899,8 @@ void* get_raw_host_data_ptr(const Tensor& tensor) { return dispatch_map.at(tensor.get_dtype())(tensor); } -void memcpy(CommandQueue& queue, void* dst, const Tensor& src, const std::optional transfer_size, bool blocking) { +void memcpy( + CommandQueue& queue, void* dst, const Tensor& src, const std::optional transfer_size, bool blocking) { TT_ASSERT(not transfer_size.has_value(), "transfer_size is not supported for memcpy right now!"); if (not is_device_tensor(src)) { TT_THROW("memcpy: src tensor must be on device"); @@ -901,21 +961,25 @@ void memcpy(Tensor& dst, 
const Tensor& src, const std::optional tra } } -Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config) { +Tensor allocate_tensor_on_device( + const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) { // Top level wrapper to asynchronously create a device tensor (single device) Tensor device_tensor = Tensor({device}); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); - device->push_work( - [shape, data_type, layout, device, memory_config, device_tensor] () mutable { - auto local_tensor = create_device_tensor(shape.value(), data_type, layout, device, memory_config); - device_tensor.populate_buffers_and_metadata(local_tensor); - } - ); + device->push_work([shape, data_type, layout, device, memory_config, device_tensor]() mutable { + auto local_tensor = create_device_tensor(shape.value(), data_type, layout, device, memory_config); + device_tensor.populate_buffers_and_metadata(local_tensor); + }); device_tensor.tensor_attributes->update_main_thread_ref_count(device, device_tensor_ref_count); return device_tensor; } -Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, Layout layout, DeviceMesh *device_mesh, const MemoryConfig& memory_config) { +Tensor allocate_tensor_on_device( + const ttnn::Shape& shape, + DataType data_type, + Layout layout, + DeviceMesh* device_mesh, + const MemoryConfig& memory_config) { // Top level wrapper to asynchronously create a device tensor (multi-device) Tensor device_tensor = Tensor(device_mesh->get_devices()); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); @@ -924,18 +988,16 @@ Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, L for (int worker_index = 0; worker_index < num_workers; ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work( - [shape, data_type, layout, worker, memory_config, device_tensor, worker_index] () mutable { - auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); - insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); - if (not worker->id()) { - device_tensor.set_shape(ttnn::Shape(shape)); - device_tensor.set_dtype(data_type); - device_tensor.set_layout(layout); - } - device_tensor.set_populated(worker); + worker->push_work([shape, data_type, layout, worker, memory_config, device_tensor, worker_index]() mutable { + auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); + insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); + if (not worker->id()) { + device_tensor.set_shape(ttnn::Shape(shape)); + device_tensor.set_dtype(data_type); + device_tensor.set_layout(layout); } - ); + device_tensor.set_populated(worker); + }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); return device_tensor; @@ -950,37 +1012,48 @@ void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id) { for (int worker_index = 0; worker_index < device_tensor.workers.size(); ++worker_index) { auto& worker = device_tensor.workers[worker_index]; - worker->push_work( - [cq_id, worker, worker_index, async_safe_tensor, device_tensor] () mutable { - TT_FATAL(async_safe_tensor.storage_type() == StorageType::BORROWED or 
async_safe_tensor.storage_type() == StorageType::OWNED or async_safe_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST, "write_tensor only supports host_tensor to device_tensor data transfer"); - TT_FATAL(device_tensor.storage_type() == StorageType::DEVICE or device_tensor.storage_type() == StorageType::MULTI_DEVICE, "write_tensor only supports host_tensor to device_tensor data transfer"); - TT_FATAL(async_safe_tensor.get_shape() == device_tensor.get_shape()); - TT_FATAL(async_safe_tensor.get_dtype() == device_tensor.get_dtype()); - TT_FATAL(async_safe_tensor.get_layout() == device_tensor.get_layout()); - std::visit( - [worker_index, worker, cq_id, &async_safe_tensor] (auto&& s) { - void* host_data = nullptr; - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - if (std::holds_alternative(async_safe_tensor.get_storage())) { - // Handle case when writing borrowed tensor single device tensor (only allowed for sync mode) - auto host_storage = std::get(async_safe_tensor.get_storage()); - std::visit([&host_data] (auto&& b) { host_data = b.data(); }, host_storage.buffer); - } else { - auto host_storage = std::get(async_safe_tensor.get_storage()); - std::visit([&host_data] (auto&& b) { host_data = b.begin(); }, host_storage.get_buffer()); - } - EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer(), host_data, false); - } else if constexpr (std::is_same_v) { - auto host_storage = std::get(async_safe_tensor.get_storage()); - std::visit([worker_index, &host_data] (auto&& b) { host_data = b.begin(); }, host_storage.get_buffer(worker_index)); - EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer_for_device(worker), host_data, false); + worker->push_work([cq_id, worker, worker_index, async_safe_tensor, device_tensor]() mutable { + TT_FATAL( + async_safe_tensor.storage_type() == StorageType::BORROWED or + async_safe_tensor.storage_type() == StorageType::OWNED or + async_safe_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST, + "write_tensor only supports host_tensor to device_tensor data transfer"); + TT_FATAL( + device_tensor.storage_type() == StorageType::DEVICE or + device_tensor.storage_type() == StorageType::MULTI_DEVICE, + "write_tensor only supports host_tensor to device_tensor data transfer"); + TT_FATAL(async_safe_tensor.get_shape() == device_tensor.get_shape()); + TT_FATAL(async_safe_tensor.get_dtype() == device_tensor.get_dtype()); + TT_FATAL(async_safe_tensor.get_layout() == device_tensor.get_layout()); + std::visit( + [worker_index, worker, cq_id, &async_safe_tensor](auto&& s) { + void* host_data = nullptr; + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + if (std::holds_alternative(async_safe_tensor.get_storage())) { + // Handle case when writing borrowed tensor single device tensor (only allowed for sync + // mode) + auto host_storage = std::get(async_safe_tensor.get_storage()); + std::visit([&host_data](auto&& b) { host_data = b.data(); }, host_storage.buffer); + } else { + auto host_storage = std::get(async_safe_tensor.get_storage()); + std::visit([&host_data](auto&& b) { host_data = b.begin(); }, host_storage.get_buffer()); } - }, device_tensor.get_storage()); - } - ); + EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer(), host_data, false); + } else if constexpr (std::is_same_v) { + auto host_storage = std::get(async_safe_tensor.get_storage()); + std::visit( + [worker_index, &host_data](auto&& b) { host_data = b.begin(); }, + host_storage.get_buffer(worker_index)); + EnqueueWriteBuffer( + 
worker->command_queue(cq_id), s.get_buffer_for_device(worker), host_data, false); + } + }, + device_tensor.get_storage()); + }); } - async_safe_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), host_tensor_ref_count); + async_safe_tensor.tensor_attributes->update_main_thread_ref_count( + device_tensor.workers.at(0), host_tensor_ref_count); device_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), device_tensor_ref_count); } diff --git a/tt_eager/tensor/tensor_impl.cpp b/tt_eager/tensor/tensor_impl.cpp index f2f31beaca7..f9bdcffb300 100644 --- a/tt_eager/tensor/tensor_impl.cpp +++ b/tt_eager/tensor/tensor_impl.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "tensor/tensor_impl.hpp" + #include "tensor/tensor_impl_wrapper.hpp" namespace tt { @@ -27,83 +28,83 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype) { return os; } - +uint32_t element_size_bytes(DataType dtype) { + switch (dtype) { + case DataType::BFLOAT16: return sizeof(bfloat16); + case DataType::FLOAT32: return sizeof(float); + case DataType::INT32: return sizeof(int32_t); + case DataType::UINT32: return sizeof(uint32_t); + case DataType::UINT16: return sizeof(uint16_t); + case DataType::BFLOAT8_B: return sizeof(std::byte); + case DataType::BFLOAT4_B: return sizeof(std::byte); + default: TT_THROW("Unsupported data type"); + } +} uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const Shape& shape) { uint32_t W = shape[-1]; uint32_t page_size = 0; switch (layout) { case Layout::ROW_MAJOR: { - uint32_t size_of_element = element_size_bytes_wrapper(dtype); + uint32_t size_of_element = element_size_bytes(dtype); page_size = W * size_of_element; - } - break; + } break; case Layout::TILE: { // TODO: Update to be generic for data type (issue 462) switch (dtype) { case DataType::BFLOAT16: { // Float is converted to bfloat16 before being written to device - uint32_t size_of_element = element_size_bytes_wrapper(DataType::BFLOAT16); + uint32_t size_of_element = element_size_bytes(DataType::BFLOAT16); page_size = constants::TILE_HW * size_of_element; - } - break; + } break; case DataType::FLOAT32: { - uint32_t size_of_element = element_size_bytes_wrapper(DataType::FLOAT32); + uint32_t size_of_element = element_size_bytes(DataType::FLOAT32); page_size = constants::TILE_HW * size_of_element; - } - break; + } break; case DataType::UINT32: case DataType::INT32: case DataType::UINT16: { - uint32_t size_of_element = element_size_bytes_wrapper(dtype); + uint32_t size_of_element = element_size_bytes(dtype); page_size = constants::TILE_HW * size_of_element; - } - break; + } break; case DataType::BFLOAT4_B: { page_size = constants::BFLOAT4_B_TILE_HW; - } - break; - case DataType::BFLOAT8_B: { + } break; + case DataType::BFLOAT8_B: { page_size = constants::BFLOAT8_B_TILE_HW; - } - break; - default: - TT_ASSERT(false && "Unsupported data type!"); + } break; + default: TT_ASSERT(false && "Unsupported data type!"); } TT_ASSERT(total_size_bytes % page_size == 0); - } - break; - default: - TT_ASSERT(false && "Unsupported layout to write to device"); + } break; + default: TT_ASSERT(false && "Unsupported layout to write to device"); } TT_ASSERT(page_size != 0); return page_size; } - - -std::array get_sharded_page_shape(Layout layout, DataType dtype, std::array shard_shape) { +std::array get_sharded_page_shape(Layout layout, DataType dtype, std::array shard_shape) { uint32_t page_size = 0; std::array page_shape = 
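Note: get_page_size above reduces to two cases: a ROW_MAJOR page is one row (width times element size), and a TILE page is one tile (TILE_HW times element size, or a fixed size for the block-float formats). A small worked example, assuming the usual 32x32 tile and a 2-byte bfloat16:

#include <cstdint>
#include <iostream>

int main() {
    constexpr uint32_t TILE_HW = 32 * 32;          // TILE_HEIGHT * TILE_WIDTH
    constexpr uint32_t bfloat16_size = 2;          // sizeof(bfloat16)
    uint32_t W = 64;                               // tensor width in elements
    uint32_t row_major_page = W * bfloat16_size;   // 128 bytes: one row per page
    uint32_t tile_page = TILE_HW * bfloat16_size;  // 2048 bytes: one 32x32 tile per page
    std::cout << row_major_page << " " << tile_page << "\n";
}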
{constants::TILE_HEIGHT, constants::TILE_WIDTH}; - //Physical limitation in FD for now + // Physical limitation in FD for now switch (layout) { case Layout::ROW_MAJOR: { - //TODO: Explore valid page shapes other than 1,W + // TODO: Explore valid page shapes other than 1,W page_shape = {1, shard_shape[1]}; - } - break; - case Layout::TILE: {;} - break; - default: - TT_ASSERT(false && "Unsupported layout to write to device"); + } break; + case Layout::TILE: { + ; + } break; + default: TT_ASSERT(false && "Unsupported layout to write to device"); } return page_shape; } -void validate_sharded_buffer_allocation(const Shape& shape, Layout layout, std::optional shard_params, const MemoryConfig& memory_config) { +void validate_sharded_buffer_allocation( + const Shape& shape, Layout layout, std::optional shard_params, const MemoryConfig& memory_config) { TT_ASSERT(shard_params.has_value(), "Shard params are required for sharded buffer and they were not initialized"); auto shard_spec = memory_config.shard_spec.value(); @@ -114,114 +115,154 @@ void validate_sharded_buffer_allocation(const Shape& shape, Layout layout, std:: uint32_t total_height = tt_metal::compute_volume(shape) / shape[-1]; uint32_t total_width = shape[-1]; if (memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - TT_ASSERT(total_width == shard_shape[1], fmt::format("Shard shape {} does not divide tensor shape {} correctly according to sharding scheme", shard_shape[1], total_width)); + TT_ASSERT( + total_width == shard_shape[1], + fmt::format( + "Shard shape {} does not divide tensor shape {} correctly according to sharding scheme", + shard_shape[1], + total_width)); uint32_t num_shards = div_up(total_height, shard_shape[0]); - TT_ASSERT(num_shards <= num_cores, fmt::format("Number of shards {} must match number of cores {}", num_shards, num_cores)); + TT_ASSERT( + num_shards <= num_cores, + fmt::format("Number of shards {} must match number of cores {}", num_shards, num_cores)); } else if (memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { - TT_ASSERT(total_height == shard_shape[0], "Shard shape does not divide tensor shape correctly according to sharding scheme"); + TT_ASSERT( + total_height == shard_shape[0], + "Shard shape does not divide tensor shape correctly according to sharding scheme"); uint32_t num_shards = div_up(total_width, shard_shape[1]); - TT_ASSERT(num_shards <= num_cores, fmt::format("Number of shards {} must match number of cores {}", num_shards, num_cores)); + TT_ASSERT( + num_shards <= num_cores, + fmt::format("Number of shards {} must match number of cores {}", num_shards, num_cores)); } else if (memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - TT_ASSERT(shard_spec.grid.ranges().size() == 1, "Shard grid must be one full rectangular grid for block sharded!"); + TT_ASSERT( + shard_spec.grid.ranges().size() == 1, "Shard grid must be one full rectangular grid for block sharded!"); uint32_t num_shards_along_height = div_up(total_height, shard_shape[0]); uint32_t num_shards_along_width = div_up(total_width, shard_shape[1]); // Additionally check that number of cores along height and width matches shard grid const CoreCoord shard_grid = shard_spec.grid.bounding_box().grid_size(); if (shard_spec.orientation == ShardOrientation::ROW_MAJOR) { - TT_ASSERT(num_shards_along_height <= shard_grid.y, fmt::format("Number of shards along height {} must match number of rows {} for row major orientation!", num_shards_along_height, shard_grid.y)); - 
TT_ASSERT(num_shards_along_width <= shard_grid.x, fmt::format("Number of shards along width {} must match number of columns {} for row major orientation!", num_shards_along_width, shard_grid.x)); + TT_ASSERT( + num_shards_along_height <= shard_grid.y, + fmt::format( + "Number of shards along height {} must match number of rows {} for row major orientation!", + num_shards_along_height, + shard_grid.y)); + TT_ASSERT( + num_shards_along_width <= shard_grid.x, + fmt::format( + "Number of shards along width {} must match number of columns {} for row major orientation!", + num_shards_along_width, + shard_grid.x)); } else { - TT_ASSERT(num_shards_along_height <= shard_grid.x, fmt::format("Number of shards along height {} must match number of columns {} for column major orientation!", num_shards_along_height, shard_grid.x)); - TT_ASSERT(num_shards_along_width <= shard_grid.y, fmt::format("Number of shards along width {} must match number of rows {} for column major orientation!", num_shards_along_width, shard_grid.y)); + TT_ASSERT( + num_shards_along_height <= shard_grid.x, + fmt::format( + "Number of shards along height {} must match number of columns {} for column major orientation!", + num_shards_along_height, + shard_grid.x)); + TT_ASSERT( + num_shards_along_width <= shard_grid.y, + fmt::format( + "Number of shards along width {} must match number of rows {} for column major orientation!", + num_shards_along_width, + shard_grid.y)); } } else { TT_FATAL(false, "Unsupported sharding scheme"); } if (layout == Layout::TILE) { - TT_ASSERT((shard_shape[0] % constants::TILE_HEIGHT == 0 && shard_shape[1] % constants::TILE_WIDTH == 0), "Shard shape must be tile sized"); + TT_ASSERT( + (shard_shape[0] % constants::TILE_HEIGHT == 0 && shard_shape[1] % constants::TILE_WIDTH == 0), + "Shard shape must be tile sized"); } else if (layout == Layout::ROW_MAJOR) { // Require alignment for now - // TT_ASSERT(shard_shape[1] * tensor_impl::element_size_bytes_wrapper(data_type) % ADDRESS_ALIGNMENT == 0); + // TT_ASSERT(shard_shape[1] * tensor_impl::element_size_bytes(data_type) % ADDRESS_ALIGNMENT == 0); } } namespace detail { -DeviceBuffer allocate_interleaved_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config) { +DeviceBuffer allocate_interleaved_buffer_on_device( + uint32_t buffer_size_bytes, + Device* device, + const Shape& shape, + DataType data_type, + Layout layout, + const MemoryConfig& memory_config) { uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, shape); return std::make_shared(device, buffer_size_bytes, page_size, memory_config.buffer_type); } -DeviceBuffer allocate_contiguous_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const MemoryConfig& memory_config) { +DeviceBuffer allocate_contiguous_buffer_on_device( + uint32_t buffer_size_bytes, Device* device, const MemoryConfig& memory_config) { return std::make_shared(device, buffer_size_bytes, buffer_size_bytes, memory_config.buffer_type); } - -DeviceBuffer allocate_sharded_buffer_on_device(uint32_t buffer_size_bytes, Device *device, - const Shape& shape, DataType data_type, Layout layout, - std::optional shard_params, - const MemoryConfig& memory_config) { +DeviceBuffer allocate_sharded_buffer_on_device( + uint32_t buffer_size_bytes, + Device* device, + const Shape& shape, + DataType data_type, + Layout layout, + std::optional shard_params, + const MemoryConfig& memory_config) { 
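Note: the height-sharded branch above checks that the shard width equals the tensor width and that ceil(total_height / shard_height) shards fit on the available cores; the width- and block-sharded branches apply the analogous checks along the other axes. A worked example of the height-sharded case with made-up sizes:

#include <cassert>
#include <cstdint>

constexpr uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    uint32_t total_height = 128, total_width = 256;  // e.g. a [1, 1, 128, 256] tensor
    uint32_t shard_h = 32, shard_w = 256;            // HEIGHT_SHARDED shard shape
    uint32_t num_cores = 4;                          // cores in the shard grid
    assert(total_width == shard_w);                  // width is never split when height sharding
    assert(div_up(total_height, shard_h) <= num_cores);  // 4 shards onto 4 cores
    return 0;
}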
validate_sharded_buffer_allocation(shape, layout, shard_params, memory_config); auto page_shape = shard_params.value().page_shape; - uint32_t size_of_element = element_size_bytes_wrapper(data_type); + uint32_t size_of_element = element_size_bytes(data_type); uint32_t page_size = page_shape[0] * page_shape[1] * size_of_element; - if(layout == Layout::TILE){ + if (layout == Layout::TILE) { page_size = get_page_size(data_type, layout, buffer_size_bytes, shape); } - return std::make_shared(device, buffer_size_bytes, page_size, - memory_config.buffer_type, - memory_config.memory_layout, - shard_params); -} - - + return std::make_shared( + device, buffer_size_bytes, page_size, memory_config.buffer_type, memory_config.memory_layout, shard_params); } +} // namespace detail - - -DeviceBuffer allocate_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, std::optional shard_spec) { +DeviceBuffer allocate_buffer_on_device( + uint32_t buffer_size_bytes, + Device* device, + const Shape& shape, + DataType data_type, + Layout layout, + const MemoryConfig& memory_config, + std::optional shard_spec) { if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { - return detail::allocate_interleaved_buffer_on_device(buffer_size_bytes, device, shape, data_type, layout, memory_config); - } - else if(memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::SINGLE_BANK){ + return detail::allocate_interleaved_buffer_on_device( + buffer_size_bytes, device, shape, data_type, layout, memory_config); + } else if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::SINGLE_BANK) { return detail::allocate_contiguous_buffer_on_device(buffer_size_bytes, device, memory_config); - } - else { - TT_ASSERT( memory_config.is_sharded() && "Incorrect Memory Layout"); - return detail::allocate_sharded_buffer_on_device(buffer_size_bytes, device, shape, data_type, layout, shard_spec, memory_config); + } else { + TT_ASSERT(memory_config.is_sharded() && "Incorrect Memory Layout"); + return detail::allocate_sharded_buffer_on_device( + buffer_size_bytes, device, shape, data_type, layout, shard_spec, memory_config); } } -void validate_on_device_dtype_and_layout(Device *device, const Shape& shape, DataType dtype, Layout layout) { +void validate_on_device_dtype_and_layout(Device* device, const Shape& shape, DataType dtype, Layout layout) { // TODO: Get supported layout and dtypes from device auto supported_dtype = [&dtype]() { TT_ASSERT( - ( - dtype == DataType::UINT32 || - dtype == DataType::INT32 || - dtype == DataType::FLOAT32 || - dtype == DataType::UINT16 || - dtype == DataType::BFLOAT16 || - dtype == DataType::BFLOAT8_B || - dtype == DataType::BFLOAT4_B - ), - "Only UINT32, INT32, FLOAT32, UINT16, BFLOAT16, BFLOAT8_B, or BFLOAT4_B dtypes are supported on device!" 
- ); + (dtype == DataType::UINT32 || dtype == DataType::INT32 || dtype == DataType::FLOAT32 || + dtype == DataType::UINT16 || dtype == DataType::BFLOAT16 || dtype == DataType::BFLOAT8_B || + dtype == DataType::BFLOAT4_B), + "Only UINT32, INT32, FLOAT32, UINT16, BFLOAT16, BFLOAT8_B, or BFLOAT4_B dtypes are supported on device!"); }; auto supported_layout = [&shape, &dtype, &layout]() { switch (dtype) { case DataType::UINT32: case DataType::INT32: - case DataType::FLOAT32: - break; + case DataType::FLOAT32: break; case DataType::UINT16: case DataType::BFLOAT16: if (layout == Layout::ROW_MAJOR) { - TT_ASSERT(shape[-1] % 2 == 0, "For ROW_MAJOR layout tensors with dtype BFLOAT16 or UINT16, tensor width must be divisible by 2 since data is packed as uint32_t when creating buffers on device!"); + TT_ASSERT( + shape[-1] % 2 == 0, + "For ROW_MAJOR layout tensors with dtype BFLOAT16 or UINT16, tensor width must be divisible by " + "2 since data is packed as uint32_t when creating buffers on device!"); } break; case DataType::BFLOAT8_B: @@ -229,79 +270,115 @@ void validate_on_device_dtype_and_layout(Device *device, const Shape& shape, Dat TT_ASSERT(layout == Layout::TILE, "Only TILE layout is supported for BFLOAT8_B dtype!"); break; default: - TT_ASSERT(false, "Only UINT32, INT32, FLOAT32, UINT16, BFLOAT16, BFLOAT8_B, or BFLOAT4_B dtypes are supported on device!"); + TT_ASSERT( + false, + "Only UINT32, INT32, FLOAT32, UINT16, BFLOAT16, BFLOAT8_B, or BFLOAT4_B dtypes are supported on " + "device!"); break; - } + } }; supported_dtype(); supported_layout(); } -Tensor pad_bfloat8_b(const Tensor &tensor, const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) { +Tensor pad_bfloat8_b( + const Tensor& tensor, const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) { // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and pad auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = + unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()).pad(output_tensor_shape, input_tensor_start, pad_value); + auto float_tensor = + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + .pad(output_tensor_shape, input_tensor_start, pad_value); // Convert back to BFLOAT8_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_packed_data = + pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT8_B, tensor.get_layout()); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + float_tensor.get_legacy_shape(), + DataType::BFLOAT8_B, + tensor.get_layout()); } -Tensor unpad_bfloat8_b(const Tensor &tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { +Tensor 
unpad_bfloat8_b(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and unpad auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = + unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()).unpad(output_tensor_start, output_tensor_end); + auto float_tensor = + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + .unpad(output_tensor_start, output_tensor_end); // Convert back to BFLOAT8_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_packed_data = + pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT8_B, tensor.get_layout()); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + float_tensor.get_legacy_shape(), + DataType::BFLOAT8_B, + tensor.get_layout()); } -Tensor pad_bfloat4_b(const Tensor &tensor, const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) { +Tensor pad_bfloat4_b( + const Tensor& tensor, const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) { // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and pad auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = + unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()).pad(output_tensor_shape, input_tensor_start, pad_value); + auto float_tensor = + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + .pad(output_tensor_shape, input_tensor_start, pad_value); // Convert back to BFLOAT4_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_packed_data = + pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT4_B, tensor.get_layout()); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + float_tensor.get_legacy_shape(), + DataType::BFLOAT4_B, + tensor.get_layout()); } -Tensor 
unpad_bfloat4_b(const Tensor &tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { +Tensor unpad_bfloat4_b(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and unpad auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = + unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()).unpad(output_tensor_start, output_tensor_end); + auto float_tensor = + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + .unpad(output_tensor_start, output_tensor_end); // Convert back to BFLOAT4_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_packed_data = + pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT4_B, tensor.get_layout()); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + float_tensor.get_legacy_shape(), + DataType::BFLOAT4_B, + tensor.get_layout()); } - } // namespace tensor_impl } // namespace tt_metal diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp index 7c8f71b25f1..c78a4cc7dac 100644 --- a/tt_eager/tensor/tensor_impl.hpp +++ b/tt_eager/tensor/tensor_impl.hpp @@ -10,7 +10,6 @@ #include "common/bfloat8.hpp" #include "tensor/host_buffer/functions.hpp" #include "tensor/tensor.hpp" -#include "tensor/tensor_impl_wrapper.hpp" #include "tensor/tensor_utils.hpp" #include "tensor/types.hpp" #include "tt_metal/detail/tt_metal.hpp" @@ -143,10 +142,7 @@ std::vector unpack_uint32_vec(std::vector& data_to_unpack) { } } -template -constexpr inline uint32_t element_size_bytes() { - return sizeof(T); -} +uint32_t element_size_bytes(DataType dtype); template constexpr inline uint32_t packed_buffer_size_bytes(uint32_t volume_unpacked_data) { @@ -161,6 +157,16 @@ constexpr inline uint32_t packed_buffer_size_bytes(uint32_t volume_unpack return (volume_unpacked_data / num_type_in_u32) * sizeof(uint32_t); } +template <> +constexpr inline uint32_t packed_buffer_size_bytes(uint32_t volume_unpacked_data) { + return packed_buffer_size_bytes(volume_unpacked_data); +} + +template <> +constexpr inline uint32_t packed_buffer_size_bytes(uint32_t volume_unpacked_data) { + return packed_buffer_size_bytes(volume_unpacked_data); +} + // ====================================================================================== // Layout converters // ====================================================================================== @@ -208,7 +214,8 @@ inline std::vector convert_layout_tile_to_row_major(const Shape& shape, const // Validators // ====================================================================================== void validate_on_device_dtype_and_layout(Device* 
device, const Shape& shape, DataType dtype, Layout layout); -void validate_sharded_buffer_allocation(const Shape& shape, Layout layout, std::optional shard_params, const MemoryConfig& memory_config); +void validate_sharded_buffer_allocation( + const Shape& shape, Layout layout, std::optional shard_params, const MemoryConfig& memory_config); // ----------------------------------------------------------------------------------------------------------------------------------------------- // =============================================================================================================================================== // High Level APIs @@ -393,6 +400,16 @@ inline Tensor to_host(const Tensor& tensor, bool blocking = true) { } } +template <> +inline Tensor to_host(const Tensor& tensor, bool blocking) { + return to_host(tensor, blocking); +} + +template <> +inline Tensor to_host(const Tensor& tensor, bool blocking) { + return to_host(tensor, blocking); +} + template inline Tensor to_host_sharded(const Tensor& tensor) { TT_ASSERT(tensor.is_allocated(), "Buffer must be allocated on device!"); @@ -411,6 +428,16 @@ inline Tensor to_host_sharded(const Tensor& tensor) { return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout()); } +template <> +inline Tensor to_host_sharded(const Tensor& tensor) { + return to_host_sharded(tensor); +} + +template <> +inline Tensor to_host_sharded(const Tensor& tensor) { + return to_host_sharded(tensor); +} + template inline Tensor to_device( const Tensor& tensor, @@ -447,6 +474,24 @@ inline Tensor to_device( return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); } +template <> +inline Tensor to_device( + const Tensor& tensor, + Device* target_device, + const MemoryConfig& memory_config, + std::optional> queue) { + return to_device(tensor, target_device, memory_config, queue); +} + +template <> +inline Tensor to_device( + const Tensor& tensor, + Device* target_device, + const MemoryConfig& memory_config, + std::optional> queue) { + return to_device(tensor, target_device, memory_config, queue); +} + template inline Tensor to_layout(const Tensor& tensor, Layout target_layout) { if (tensor.get_layout() == target_layout) { @@ -527,8 +572,7 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout); // .pad() and .unpad() // ====================================================================================== template -inline Tensor pad( - const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value) { +inline Tensor pad(const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value) { if (is_multi_device_tensor(tensor)) { return transform(tensor, [&](const Tensor& device_tensor) { return pad(device_tensor, output_shape, input_tensor_start, pad_value); @@ -540,13 +584,8 @@ inline Tensor pad( const auto input_strides = tensor.strides(); const auto input_data_type = tensor.get_dtype(); - auto pad = [&input_shape, - &input_strides, - &input_data_type, - &output_shape, - &input_tensor_start, - &pad_value_](const auto& input_buffer) { - + auto pad = [&input_shape, &input_strides, &input_data_type, &output_shape, &input_tensor_start, &pad_value_]( + const auto& input_buffer) { auto compute_stride = [](const Shape& shape, uint32_t index) { uint32_t stride = 1; for (auto i = index + 1; i < shape.rank(); i++) { @@ -563,13 +602,11 @@ inline Tensor pad( for (auto index = 0; index < output_shape.rank(); index++) { // 
Check if input tensor fits in output tensor given the input tensor start indices TT_ASSERT( - input_shape[index] + input_tensor_start[index] <= output_shape[index], - "Input tensor is out of bounds"); + input_shape[index] + input_tensor_start[index] <= output_shape[index], "Input tensor is out of bounds"); // Figure out pad size on each dim pad_size.push_back( - {input_tensor_start[index], - output_shape[index] - input_shape[index] - input_tensor_start[index]}); + {input_tensor_start[index], output_shape[index] - input_shape[index] - input_tensor_start[index]}); input_strides.push_back(compute_stride(input_shape, index)); output_strides.push_back(compute_stride(output_shape, index)); @@ -624,10 +661,20 @@ inline Tensor pad( return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout()); } -Tensor pad_bfloat8_b( - const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value); -Tensor pad_bfloat4_b( - const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value); +Tensor pad_bfloat8_b(const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value); +Tensor pad_bfloat4_b(const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value); + +template <> +inline Tensor pad( + const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value) { + return pad_bfloat8_b(tensor, output_shape, input_tensor_start, pad_value); +} + +template <> +inline Tensor pad( + const Tensor& tensor, const Shape& output_shape, const Shape& input_tensor_start, float pad_value) { + return pad_bfloat4_b(tensor, output_shape, input_tensor_start, pad_value); +} template inline Tensor unpad(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { @@ -695,6 +742,16 @@ inline Tensor unpad(const Tensor& tensor, const Shape& output_tensor_start, cons Tensor unpad_bfloat8_b(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end); Tensor unpad_bfloat4_b(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end); +template <> +inline Tensor unpad(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { + return unpad_bfloat8_b(tensor, output_tensor_start, output_tensor_end); +} + +template <> +inline Tensor unpad(const Tensor& tensor, const Shape& output_tensor_start, const Shape& output_tensor_end) { + return unpad_bfloat4_b(tensor, output_tensor_start, output_tensor_end); +} + // ====================================================================================== // Print // ====================================================================================== @@ -895,22 +952,28 @@ inline std::string to_string(const Tensor& tensor, std::optional origi if (dtype == DataType::BFLOAT8_B and original_dtype == std::nullopt) { // Convert to FLOAT32 tensor before printing auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = - unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = unpack_bfp8_tiles_into_float_vec( + input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()); + auto 
float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tensor.get_legacy_shape(), + DataType::FLOAT32, + tensor.get_layout()); return to_string(float_tensor, tensor.get_dtype()); } if (dtype == DataType::BFLOAT4_B and original_dtype == std::nullopt) { // Convert to FLOAT32 tensor before printing auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = - unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = unpack_bfp4_tiles_into_float_vec( + input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tensor.get_legacy_shape(), + DataType::FLOAT32, + tensor.get_layout()); return to_string(float_tensor, tensor.get_dtype()); } const auto buffer = owned_buffer::get_as(storage.buffer); @@ -941,6 +1004,16 @@ inline std::string to_string(const Tensor& tensor, std::optional origi tensor.get_storage()); } +template <> +inline std::string to_string(const Tensor& tensor, std::optional original_dtype) { + return to_string(tensor, original_dtype); +} + +template <> +inline std::string to_string(const Tensor& tensor, std::optional original_dtype) { + return to_string(tensor, original_dtype); +} + template Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id) { auto buffer = tensor.buffer(); @@ -955,6 +1028,16 @@ Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id) { return Tensor(OwnedStorage{output_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()); } +template <> +inline Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id) { + return extract_shard(tensor, core_id); +} + +template <> +inline Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id) { + return extract_shard(tensor, core_id); +} + template void* get_raw_host_data_ptr(const Tensor& tensor) { return std::visit( @@ -987,21 +1070,21 @@ void* get_raw_host_data_ptr(const Tensor& tensor) { } // Template Specialization for unpack_bfloat_tiles_into_float {bfp4,bfp8} -template +template inline std::vector unpack_bfloat_tiles_into_float_vec(const bfloat8_b&, Args&&... args) { return unpack_bfp8_tiles_into_float_vec(std::forward(args)...); } -template +template inline std::vector unpack_bfloat_tiles_into_float_vec(const bfloat4_b&, Args&&... args) { return unpack_bfp4_tiles_into_float_vec(std::forward(args)...); } // Template Specialization for pack_fp32_vec_as_bfp4_tiles {bfp4,bfp8} -template +template inline std::vector pack_fp32_vec_as_bfloat_tiles(const bfloat8_b&, Args&&... args) { return pack_fp32_vec_as_bfp8_tiles(std::forward(args)...); } -template +template inline std::vector pack_fp32_vec_as_bfloat_tiles(const bfloat4_b&, Args&&... 
args) { return pack_fp32_vec_as_bfp4_tiles(std::forward(args)...); } @@ -1020,14 +1103,13 @@ struct bfloat_enum { static constexpr DataType value = DataType::BFLOAT4_B; }; - template -Tensor to_layout_bfloat(const Tensor &tensor, Layout target_layout) { +Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { static_assert(std::is_same_v || std::is_same_v, "Invalid type T"); // TODO(arakhmati): do not convert to FLOA32 - if(tensor.get_layout() == target_layout) { + if (tensor.get_layout() == target_layout) { return tensor; } return std::visit( @@ -1038,13 +1120,20 @@ Tensor to_layout_bfloat(const Tensor &tensor, Layout target_layout) { for (int i = 0; i < storage.buffers.size(); i++) { // Convert to FLOAT32 tensor and change layout auto input_packed_data = owned_buffer::get_as(storage.buffers[i]).get(); - auto input_float_data = unpack_bfloat_tiles_into_float_vec(T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = unpack_bfloat_tiles_into_float_vec( + T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()).to(target_layout); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tensor.get_legacy_shape(), + DataType::FLOAT32, + tensor.get_layout()) + .to(target_layout); // Convert back to BFLOAT8_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfloat_tiles(T{}, output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_packed_data = pack_fp32_vec_as_bfloat_tiles( + T{}, output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); output_buffers.push_back(output_uint32_buffer); } @@ -1052,23 +1141,44 @@ Tensor to_layout_bfloat(const Tensor &tensor, Layout target_layout) { std::move(MultiDeviceHostStorage{storage.strategy, output_buffers, storage.shapes}), tensor.get_legacy_shape(), bfloat_enum::value, - target_layout - ); + target_layout); } else { // Convert to FLOAT32 tensor and change layout auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = unpack_bfloat_tiles_into_float_vec(T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); + auto input_float_data = unpack_bfloat_tiles_into_float_vec( + T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()).to(target_layout); + auto float_tensor = Tensor( + OwnedStorage{input_float_buffer}, + tensor.get_legacy_shape(), + DataType::FLOAT32, + tensor.get_layout()) + .to(target_layout); // Convert back to BFLOAT auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfloat_tiles(T{}, output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_packed_data = pack_fp32_vec_as_bfloat_tiles( + T{}, output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), tensor.get_legacy_shape(), 
bfloat_enum::value, target_layout); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + tensor.get_legacy_shape(), + bfloat_enum::value, + target_layout); } - }, tensor.get_storage()); + }, + tensor.get_storage()); +} + +template <> +inline Tensor to_layout(const Tensor& tensor, Layout target_layout) { + return to_layout_bfloat(tensor, target_layout); +} + +template <> +inline Tensor to_layout(const Tensor& tensor, Layout target_layout) { + return to_layout_bfloat(tensor, target_layout); } } // namespace tensor_impl diff --git a/tt_eager/tensor/tensor_impl_wrapper.cpp b/tt_eager/tensor/tensor_impl_wrapper.cpp deleted file mode 100644 index 29afd08d8aa..00000000000 --- a/tt_eager/tensor/tensor_impl_wrapper.cpp +++ /dev/null @@ -1,158 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "tensor/tensor_impl_wrapper.hpp" -#include "tensor/tensor_impl.hpp" - -#include "common/bfloat16.hpp" -#include "common/bfloat4.hpp" -#include "common/bfloat8.hpp" -#include - -namespace tt { - -namespace tt_metal { - -namespace tensor_impl { - -uint32_t element_size_bytes_wrapper(DataType dtype) { - const static std::map> element_size_bytes_map = { - {DataType::BFLOAT16, &element_size_bytes}, - {DataType::FLOAT32, &element_size_bytes}, - {DataType::INT32, &element_size_bytes}, - {DataType::UINT32, &element_size_bytes}, - {DataType::UINT16, &element_size_bytes}, - {DataType::BFLOAT8_B, &element_size_bytes}, - {DataType::BFLOAT4_B, &element_size_bytes}, - }; - return element_size_bytes_map.at(dtype)(); -} - -uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume_unpacked_data) { - const static std::map> packed_buffer_size_bytes_map = { - {DataType::BFLOAT16, &packed_buffer_size_bytes}, - {DataType::FLOAT32, &packed_buffer_size_bytes}, - {DataType::INT32, &packed_buffer_size_bytes}, - {DataType::UINT32, &packed_buffer_size_bytes}, - {DataType::BFLOAT8_B, &packed_buffer_size_bytes}, - {DataType::BFLOAT4_B, &packed_buffer_size_bytes}, - {DataType::UINT16, &packed_buffer_size_bytes}, - }; - return packed_buffer_size_bytes_map.at(dtype)(volume_unpacked_data); -} - -Tensor to_host_wrapper(const Tensor &tensor, bool blocking) { - const static std::map> to_host_map = { - {DataType::BFLOAT16, &to_host}, - {DataType::FLOAT32, &to_host}, - {DataType::INT32, &to_host}, - {DataType::UINT32, &to_host}, - {DataType::BFLOAT8_B, &to_host}, - {DataType::BFLOAT4_B, &to_host}, - {DataType::UINT16, &to_host}, - }; - return to_host_map.at(tensor.get_dtype())(tensor, blocking); -} - - -Tensor to_extract_shard_wrapper(const Tensor &tensor, const uint32_t & core_id) { - const static std::map> to_host_map = { - {DataType::BFLOAT16, &extract_shard}, - {DataType::FLOAT32, &extract_shard}, - {DataType::INT32, &extract_shard}, - {DataType::UINT32, &extract_shard}, - {DataType::BFLOAT8_B, &extract_shard}, - {DataType::BFLOAT4_B, &extract_shard}, - {DataType::UINT16, &extract_shard}, - }; - return to_host_map.at(tensor.get_dtype())(tensor, core_id); -} - -Tensor to_host_wrapper_sharded(const Tensor &tensor) { - const static std::map> to_host_map = { - {DataType::BFLOAT16, &to_host_sharded}, - {DataType::FLOAT32, &to_host_sharded}, - {DataType::INT32, &to_host_sharded}, - {DataType::UINT32, &to_host_sharded}, - {DataType::BFLOAT8_B, &to_host_sharded}, - {DataType::BFLOAT4_B, &to_host_sharded}, - {DataType::UINT16, &to_host_sharded}, - }; - return to_host_map.at(tensor.get_dtype())(tensor); -} - -Tensor to_device_wrapper(const Tensor 
&tensor, Device *target_device, const MemoryConfig &mem_config, std::optional< std::reference_wrapper > q) { - const static std::unordered_map> )>> - to_device_map = { - {DataType::BFLOAT16, &to_device}, - {DataType::FLOAT32, &to_device}, - {DataType::INT32, &to_device}, - {DataType::UINT32, &to_device}, - {DataType::BFLOAT8_B, &to_device}, - {DataType::BFLOAT4_B, &to_device}, - {DataType::UINT16, &to_device}, - }; - return to_device_map.at(tensor.get_dtype())(tensor, target_device, mem_config, q); -} - - -Tensor to_layout_wrapper(const Tensor &tensor, Layout target_layout) { - const static std::unordered_map> to_layout_map = { - {DataType::BFLOAT16, &to_layout}, - {DataType::FLOAT32, &to_layout}, - {DataType::INT32, &to_layout}, - {DataType::UINT32, &to_layout}, - {DataType::BFLOAT8_B, &to_layout_bfloat}, - {DataType::BFLOAT4_B, &to_layout_bfloat}, - {DataType::UINT16, &to_layout}, - }; - return to_layout_map.at(tensor.get_dtype())(tensor, target_layout); -} - -Tensor pad_wrapper(const Tensor &tensor, const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) { - const static std::unordered_map> - pad_map = { - {DataType::BFLOAT16, &pad}, - {DataType::FLOAT32, &pad}, - {DataType::INT32, &pad}, - {DataType::UINT32, &pad}, - {DataType::BFLOAT8_B, &pad_bfloat8_b}, - {DataType::BFLOAT4_B, &pad_bfloat4_b}, - {DataType::UINT16, &pad}, - }; - return pad_map.at(tensor.get_dtype())(tensor, output_tensor_shape, input_tensor_start, pad_value); -} - -Tensor unpad_wrapper(const Tensor &tensor, const Shape &output_tensor_start, const Shape &output_tensor_end) { - const static std::unordered_map> unpad_map = { - {DataType::BFLOAT16, &unpad}, - {DataType::FLOAT32, &unpad}, - {DataType::INT32, &unpad}, - {DataType::UINT32, &unpad}, - {DataType::BFLOAT8_B, &unpad_bfloat8_b}, - {DataType::BFLOAT4_B, &unpad_bfloat4_b}, - {DataType::UINT16, &unpad}, - }; - return unpad_map.at(tensor.get_dtype())(tensor, output_tensor_start, output_tensor_end); -} - -std::string to_string_wrapper(const Tensor &tensor) { - const static std::unordered_map)>> - to_string_map = { - {DataType::BFLOAT16, &to_string}, - {DataType::FLOAT32, &to_string}, - {DataType::INT32, &to_string}, - {DataType::UINT32, &to_string}, - {DataType::BFLOAT8_B, &to_string}, - {DataType::BFLOAT4_B, &to_string}, - {DataType::UINT16, &to_string}, - }; - return to_string_map.at(tensor.get_dtype())(tensor, std::nullopt); -} - -} // namespace tensor_impl - -} // namespace tt_metal - -} // namespace tt diff --git a/tt_eager/tensor/tensor_impl_wrapper.hpp b/tt_eager/tensor/tensor_impl_wrapper.hpp index 25f201a5324..8b50ea47b36 100644 --- a/tt_eager/tensor/tensor_impl_wrapper.hpp +++ b/tt_eager/tensor/tensor_impl_wrapper.hpp @@ -4,38 +4,48 @@ #pragma once -#include "tensor/tensor.hpp" #include "tensor/tensor_impl.hpp" -namespace tt { - -namespace tt_metal { - -namespace tensor_impl { - - -uint32_t element_size_bytes_wrapper(DataType dtype); - -uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume_unpacked_data); - -Tensor to_host_wrapper(const Tensor &tensor, bool blocking = true); - -Tensor to_host_wrapper_sharded(const Tensor &tensor); - -Tensor to_extract_shard_wrapper(const Tensor &tensor, const uint32_t & core_id); - -Tensor to_device_wrapper(const Tensor &tensor, Device *target_device, const MemoryConfig &mem_config, std::optional> queue = std::nullopt); - -Tensor to_layout_wrapper(const Tensor &tensor, Layout target_layout); - -Tensor pad_wrapper(const Tensor &tensor, const Shape &output_tensor_shape, const 
Shape &input_tensor_start, float pad_value); - -Tensor unpad_wrapper(const Tensor &tensor, const Shape &output_tensor_start, const Shape &output_tensor_end); - -std::string to_string_wrapper(const Tensor &tensor); - -} // namespace tensor_impl - -} // namespace tt_metal - -} // namespace tt +namespace tt::tt_metal::tensor_impl { + +// Utility to convert runtime DataType to compile-time constant and dispatch the function call +template +auto dispatch(DataType dtype, Func &&func, Args &&...args) { + switch (dtype) { + case DataType::BFLOAT16: return func.template operator()(static_cast(args)...); + case DataType::FLOAT32: return func.template operator()(static_cast(args)...); + case DataType::INT32: return func.template operator()(static_cast(args)...); + case DataType::UINT32: return func.template operator()(static_cast(args)...); + case DataType::UINT16: return func.template operator()(static_cast(args)...); + case DataType::BFLOAT8_B: return func.template operator()(static_cast(args)...); + case DataType::BFLOAT4_B: return func.template operator()(static_cast(args)...); + default: TT_THROW("Unsupported data type"); + } +} + +#define AS_LAMBDA(func) [](auto &&...args) { return func(std::forward(args)...); } + +#define WRAP_FUNCTION(func) \ + template \ + auto func##_wrapper(Args &&...args) { \ + return dispatch( \ + std::get<0>(std::forward_as_tuple(args...)).get_dtype(), AS_LAMBDA(func), std::forward(args)...); \ + } + +inline uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume_unpacked_data) { + return dispatch(dtype, AS_LAMBDA(packed_buffer_size_bytes), volume_unpacked_data); +} + +WRAP_FUNCTION(to_host) +WRAP_FUNCTION(extract_shard) +WRAP_FUNCTION(to_host_sharded) +WRAP_FUNCTION(to_device) +WRAP_FUNCTION(to_layout) +WRAP_FUNCTION(pad) +WRAP_FUNCTION(unpad) +WRAP_FUNCTION(to_string) + +#undef WRAP_FUNCTION +#undef AS_LAMBDA + +} // namespace tt::tt_metal::tensor_impl diff --git a/ttnn/cpp/ttnn/async_runtime.cpp b/ttnn/cpp/ttnn/async_runtime.cpp index fa523714679..627a5c20e25 100644 --- a/ttnn/cpp/ttnn/async_runtime.cpp +++ b/ttnn/cpp/ttnn/async_runtime.cpp @@ -3,116 +3,131 @@ // SPDX-License-Identifier: Apache-2.0 #include "async_runtime.hpp" + #include "tt_eager/tensor/tensor_impl.hpp" #include "tt_eager/tensor/tensor_impl_wrapper.hpp" namespace ttnn { - using DeviceBuffer = std::shared_ptr; - using queue_id = uint8_t; +using DeviceBuffer = std::shared_ptr; +using queue_id = uint8_t; - DeviceBuffer allocate_interleaved_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config) { - uint32_t page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value()); - return std::make_shared(device, buffer_size_bytes, page_size, memory_config.buffer_type); - } +DeviceBuffer allocate_interleaved_buffer_on_device( + uint32_t buffer_size_bytes, + Device* device, + const Shape& shape, + DataType data_type, + Layout layout, + const MemoryConfig& memory_config) { + uint32_t page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value()); + return std::make_shared(device, buffer_size_bytes, page_size, memory_config.buffer_type); +} + +DeviceBuffer allocate_contiguous_buffer_on_device( + uint32_t buffer_size_bytes, Device* device, const MemoryConfig& memory_config) { + return std::make_shared(device, buffer_size_bytes, buffer_size_bytes, memory_config.buffer_type); +} - DeviceBuffer 
allocate_contiguous_buffer_on_device(uint32_t buffer_size_bytes, Device *device, const MemoryConfig& memory_config) { - return std::make_shared(device, buffer_size_bytes, buffer_size_bytes, memory_config.buffer_type); +DeviceBuffer allocate_sharded_buffer_on_device( + uint32_t buffer_size_bytes, + Device* device, + const Shape& shape, + DataType data_type, + Layout layout, + std::optional shard_params, + const MemoryConfig& memory_config) { + tt::tt_metal::tensor_impl::validate_sharded_buffer_allocation(shape.value(), layout, shard_params, memory_config); + auto page_shape = shard_params.value().page_shape; + uint32_t size_of_element = tt::tt_metal::tensor_impl::element_size_bytes(data_type); + uint32_t page_size = page_shape[0] * page_shape[1] * size_of_element; + if (layout == Layout::TILE) { + page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value()); } - DeviceBuffer allocate_sharded_buffer_on_device(uint32_t buffer_size_bytes, Device *device, - const Shape& shape, DataType data_type, Layout layout, - std::optional shard_params, - const MemoryConfig& memory_config) { - tt::tt_metal::tensor_impl::validate_sharded_buffer_allocation(shape.value(), layout, shard_params, memory_config); - auto page_shape = shard_params.value().page_shape; - uint32_t size_of_element = tt::tt_metal::tensor_impl::element_size_bytes_wrapper(data_type); - uint32_t page_size = page_shape[0] * page_shape[1] * size_of_element; - if(layout == Layout::TILE){ - page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value()); - } + return std::make_shared( + device, buffer_size_bytes, page_size, memory_config.buffer_type, memory_config.memory_layout, shard_params); +} - return std::make_shared(device, buffer_size_bytes, page_size, - memory_config.buffer_type, - memory_config.memory_layout, - shard_params); +DeviceBuffer allocate_buffer_on_device( + uint32_t buffer_size_bytes, + types::Device* device, + const Shape& shape, + DataType data_type, + Layout layout, + const MemoryConfig& memory_config, + std::optional shard_spec) { + if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + return allocate_interleaved_buffer_on_device( + buffer_size_bytes, device, shape, data_type, layout, memory_config); + } else if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::SINGLE_BANK) { + return allocate_contiguous_buffer_on_device(buffer_size_bytes, device, memory_config); + } else { + return allocate_sharded_buffer_on_device( + buffer_size_bytes, device, shape, data_type, layout, shard_spec, memory_config); } +} - DeviceBuffer allocate_buffer_on_device(uint32_t buffer_size_bytes, types::Device* device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, std::optional shard_spec) { - if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { - return allocate_interleaved_buffer_on_device(buffer_size_bytes, device, shape, data_type, layout, memory_config); - } - else if(memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::SINGLE_BANK){ - return allocate_contiguous_buffer_on_device(buffer_size_bytes, device, memory_config); - } - else { - return allocate_sharded_buffer_on_device(buffer_size_bytes, device, shape, data_type, layout, shard_spec, memory_config); - } +void write_buffer( + queue_id cq_id, + Tensor& dst, + std::vector> src, + const std::optional transfer_size) { + uint32_t dst_ref_count = 
dst.tensor_attributes->record_main_thread_ref_count(); + for (const auto worker : dst.get_workers()) { + auto src_for_device = src.at(worker->id()); + worker->push_work([worker, src_for_device, dst, cq_id, transfer_size]() { + auto shard = tt::tt_metal::get_shard_for_device(dst, worker); + tt::tt_metal::memcpy(worker->command_queue(cq_id), shard, src_for_device.get(), transfer_size); + }); } + dst.tensor_attributes->update_main_thread_ref_count(dst.workers.at(0), dst_ref_count); +} - void write_buffer(queue_id cq_id, Tensor& dst, std::vector> src, const std::optional transfer_size) { - uint32_t dst_ref_count = dst.tensor_attributes->record_main_thread_ref_count(); - for (const auto worker : dst.get_workers()) { - auto src_for_device = src.at(worker->id()); - worker->push_work( - [worker, src_for_device, dst, cq_id, transfer_size] () { - auto shard = tt::tt_metal::get_shard_for_device(dst, worker); - tt::tt_metal::memcpy(worker->command_queue(cq_id), shard, src_for_device.get(), transfer_size); - }); - } - dst.tensor_attributes->update_main_thread_ref_count(dst.workers.at(0), dst_ref_count); +void read_buffer( + queue_id cq_id, + Tensor& src, + std::vector> dst, + const std::optional transfer_size, + size_t src_offset, + bool blocking) { + TT_ASSERT(src_offset == 0, "src_offset is not supported"); + uint32_t src_ref_count = src.tensor_attributes->record_main_thread_ref_count(); + for (const auto worker : src.get_workers()) { + auto dst_for_device = dst.at(worker->id()); + worker->push_work([worker, dst_for_device, src, cq_id, transfer_size, src_offset, blocking]() { + const auto& shard = tt::tt_metal::get_shard_for_device(src, worker); + tt::tt_metal::memcpy(worker->command_queue(cq_id), dst_for_device.get(), shard, transfer_size, blocking); + }); } - - void read_buffer(queue_id cq_id, Tensor& src, std::vector> dst, const std::optional transfer_size, size_t src_offset, bool blocking) { - TT_ASSERT(src_offset == 0, "src_offset is not supported"); - uint32_t src_ref_count = src.tensor_attributes->record_main_thread_ref_count(); - for (const auto worker : src.get_workers()) { - auto dst_for_device = dst.at(worker->id()); - worker->push_work( - [worker, dst_for_device, src, cq_id, transfer_size, src_offset, blocking] () { - const auto& shard = tt::tt_metal::get_shard_for_device(src, worker); - tt::tt_metal::memcpy(worker->command_queue(cq_id), dst_for_device.get(), shard, transfer_size, blocking); - }); - } - if (blocking) { - for (auto worker : src.get_workers()) { - worker->synchronize(); - } + if (blocking) { + for (auto worker : src.get_workers()) { + worker->synchronize(); } - src.tensor_attributes->update_main_thread_ref_count(src.workers.at(0), src_ref_count); } + src.tensor_attributes->update_main_thread_ref_count(src.workers.at(0), src_ref_count); +} - void queue_synchronize(CommandQueue& cq) { - // Ensure that all work pushed to async engine has been passed - // off to device CQ - cq.device()->synchronize(); - // Wait for device CQ to finish - Finish(cq); - } +void queue_synchronize(CommandQueue& cq) { + // Ensure that all work pushed to async engine has been passed + // off to device CQ + cq.device()->synchronize(); + // Wait for device CQ to finish + Finish(cq); +} - void event_synchronize(std::shared_ptr event) { - EventSynchronize(event); - } +void event_synchronize(std::shared_ptr event) { EventSynchronize(event); } - bool event_query(std::shared_ptr event) { - return EventQuery(event); - } +bool event_query(std::shared_ptr event) { return EventQuery(event); } - void 
wait_for_event(CommandQueue& cq, std::shared_ptr<Event> event) {
-        auto cq_id = cq.id();
-        auto cq_worker = cq.device();
-        cq_worker->push_work(
-            [cq_worker, cq_id, event] () {
-                EnqueueWaitForEvent(cq_worker->command_queue(cq_id), event);
-            });
-    }
+void wait_for_event(CommandQueue& cq, std::shared_ptr<Event> event) {
+    auto cq_id = cq.id();
+    auto cq_worker = cq.device();
+    cq_worker->push_work([cq_worker, cq_id, event]() { EnqueueWaitForEvent(cq_worker->command_queue(cq_id), event); });
+}
-    void record_event(CommandQueue& cq, std::shared_ptr<Event> event) {
-        auto cq_id = cq.id();
-        auto cq_worker = cq.device();
-        cq_worker->push_work(
-            [cq_worker, cq_id, event] () {
-                EnqueueRecordEvent(cq_worker->command_queue(cq_id), event);
-            });
-    }
+void record_event(CommandQueue& cq, std::shared_ptr<Event> event) {
+    auto cq_id = cq.id();
+    auto cq_worker = cq.device();
+    cq_worker->push_work([cq_worker, cq_id, event]() { EnqueueRecordEvent(cq_worker->command_queue(cq_id), event); });
+}
-} // namespace::ttnn
+} // namespace ttnn
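
Note for readers of the refactor above: the core of the change is the new dispatch utility in tensor_impl_wrapper.hpp, which replaces the deleted per-dtype std::map-of-function-pointer wrappers (element_size_bytes_wrapper, to_host_wrapper, pad_wrapper, and so on) with a single switch that turns the runtime DataType into a compile-time template argument. The sketch below is a standalone illustration of that pattern under simplified assumptions: DataType, the element types, and packed_buffer_size_bytes are local stand-ins rather than the real tt-metal definitions, and the real header builds the callable through its AS_LAMBDA/WRAP_FUNCTION macros instead of a hand-written functor.

// Standalone illustration of the dispatch pattern (simplified stand-ins, not the
// real tt-metal types): a runtime DataType tag is switched once into a
// compile-time element type, and the call is forwarded to a dtype-templated
// implementation.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>

enum class DataType { BFLOAT16, FLOAT32, UINT16, UINT32 };

// Convert the runtime tag into a template argument and invoke func<T>(args...).
template <typename Func, typename... Args>
auto dispatch(DataType dtype, Func&& func, Args&&... args) {
    switch (dtype) {
        // bfloat16 occupies two bytes, so uint16_t stands in for its storage type here.
        case DataType::BFLOAT16: return func.template operator()<uint16_t>(std::forward<Args>(args)...);
        case DataType::FLOAT32: return func.template operator()<float>(std::forward<Args>(args)...);
        case DataType::UINT16: return func.template operator()<uint16_t>(std::forward<Args>(args)...);
        case DataType::UINT32: return func.template operator()<uint32_t>(std::forward<Args>(args)...);
        default: throw std::runtime_error("Unsupported data type");
    }
}

// A dtype-templated helper, analogous to the templated functions in tensor_impl.hpp.
template <typename T>
uint32_t packed_buffer_size_bytes(uint32_t volume_unpacked_data) {
    return volume_unpacked_data * static_cast<uint32_t>(sizeof(T));
}

// Callable with a templated operator(); the real header generates this shape of
// callable with its AS_LAMBDA macro instead of a named struct.
struct PackedBufferSizeBytes {
    template <typename T>
    uint32_t operator()(uint32_t volume) const {
        return packed_buffer_size_bytes<T>(volume);
    }
};

// Runtime-facing wrapper, playing the role of packed_buffer_size_bytes_wrapper.
uint32_t packed_buffer_size_bytes_wrapper(DataType dtype, uint32_t volume) {
    return dispatch(dtype, PackedBufferSizeBytes{}, volume);
}

int main() {
    std::cout << packed_buffer_size_bytes_wrapper(DataType::FLOAT32, 1024) << "\n";   // prints 4096
    std::cout << packed_buffer_size_bytes_wrapper(DataType::BFLOAT16, 1024) << "\n";  // prints 2048
    return 0;
}

This is also why the diff adds explicit template specializations such as pad<bfloat8_b>, unpad<bfloat4_b>, to_layout<bfloat8_b>, and to_string<bfloat4_b> in tensor_impl.hpp: with those in place, the single dispatch switch can route the block-float dtypes to their packed-format implementations without keeping any runtime lookup table.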