diff --git a/tests/tt_eager/module.mk b/tests/tt_eager/module.mk index f8d10942848..585df78c04d 100644 --- a/tests/tt_eager/module.mk +++ b/tests/tt_eager/module.mk @@ -27,6 +27,7 @@ TT_EAGER_TESTS += \ tests/tt_eager/ops/test_sfpu \ tests/tt_eager/tensors/test_copy_and_move \ tests/tt_eager/tensors/test_host_device_loopback \ + tests/tt_eager/tensors/test_raw_host_memory_pointer \ tests/tt_eager/tensors/test_sharded_loopback \ tests/tt_eager/integration_tests/test_bert \ diff --git a/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp b/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp new file mode 100644 index 00000000000..14bff594054 --- /dev/null +++ b/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp @@ -0,0 +1,188 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include + +#include "common/bfloat16.hpp" +#include "common/constants.hpp" +#include "tensor/owned_buffer.hpp" +#include "tensor/owned_buffer_functions.hpp" +#include "tensor/tensor.hpp" +#include "tensor/tensor_impl.hpp" +#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp" +#include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_numpy/functions.hpp" + +/* + +01 import torch +02: import numpy as np +03: +04: device = torch.device("cuda:0") +05: +06: // define tensors on the CPU +07: a_cpu = np.array([[1,2,3,4],[5,6,7,8]], dtype=np.bfloat16) +08: +09: // define tensors on the device with CPU tensors +10: a_dev = torch.from_numpy(a_cpu).to(device) +11: +12: c_dev = torch.sqrt(a_dev) +13: +14: print(c_dev[1][0]) +15: +16: d_cpu = np.array([[11,12,13,14],[15,16,17,18]]) +17: d_dev = d_cpu.to(device) +18: +19: e_dev = c_dev + d_dev +20: print(e_dev) + +*/ + +namespace numpy { + +template +struct ndarray { + Shape shape; + void* data; + + ndarray(Shape shape) : shape(shape), data(malloc(tt::tt_metal::compute_volume(shape) * sizeof(DataType))) {} + ~ndarray() { free(data); } + + std::size_t size() const { return tt::tt_metal::compute_volume(shape); } +}; + +} // namespace numpy + +void test_raw_host_memory_pointer() { + using tt::tt_metal::BorrowedStorage; + using tt::tt_metal::DataType; + using tt::tt_metal::OwnedStorage; + using tt::tt_metal::Shape; + using tt::tt_metal::Tensor; + using namespace tt::tt_metal::borrowed_buffer; + using namespace tt::tt_metal::owned_buffer; + + int device_id = 0; + tt::tt_metal::Device* device = tt::tt_metal::CreateDevice(device_id); + + Shape shape = {1, 1, tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH}; + + // Host tensor to print the output + Tensor tensor_for_printing = Tensor( + OwnedStorage{owned_buffer::create(tt::tt_metal::compute_volume(shape))}, + shape, + DataType::BFLOAT16, + Layout::TILE); + + /* Borrow Data from Numpy Start */ + // Create some + auto a_np_array = numpy::ndarray(shape); + void* a_np_array_data = a_np_array.data; + auto on_creation_callback = [] {}; + auto on_destruction_callback = [] {}; + Tensor a_cpu = Tensor( + BorrowedStorage{ + borrowed_buffer::Buffer(static_cast(a_np_array_data), a_np_array.size()), + on_creation_callback, + on_destruction_callback}, + shape, + DataType::BFLOAT16, + Layout::ROW_MAJOR); + /* Borrow Data from Numpy End */ + + /* Sanity Check Start */ + + // Test that borrowing numpy array's data works correctly + + // Set every value of tt Tensor to the same non-zero number + bfloat16 a_value = 4.0f; + + for (auto& element : borrowed_buffer::get_as(a_cpu)) { + element = a_value; + } + + // Check that numpy array's data is now set to that value + for (auto index = 0; index < a_np_array.size(); index++) { + auto a_np_array_element = static_cast(a_np_array_data)[index]; + TT_ASSERT(a_np_array_element == a_value); + } + /* Sanity Check End */ + + /* Run and Print Start */ + Tensor a_dev = a_cpu.to(device); + + Tensor c_dev = tt::tt_metal::sqrt(a_dev); + + tt::tt_metal::memcpy(tensor_for_printing, c_dev); + + // Check that cpu tensor has correct data + bfloat16 output_value = 1.99219f; // Not exactly 2.0f because of rounding errors + for (auto& element : owned_buffer::get_as(tensor_for_printing)) { + TT_ASSERT(element == output_value); + } + + tensor_for_printing.print(); + /* Run and Print End */ + + /* Alternative Way to Print Start */ + // Alternatively, we could allocate memory manually and create Tensors with BorrowedStorage on the fly to print the + // data + void* alternative_tensor_for_printing_data = malloc(tt::tt_metal::compute_volume(shape) * sizeof(bfloat16)); + Tensor alternative_tensor_for_printing = Tensor( + BorrowedStorage{ + borrowed_buffer::Buffer( + static_cast(alternative_tensor_for_printing_data), tt::tt_metal::compute_volume(shape)), + on_creation_callback, + on_destruction_callback}, + shape, + DataType::BFLOAT16, + Layout::TILE); + tt::tt_metal::memcpy(alternative_tensor_for_printing, c_dev); + alternative_tensor_for_printing.print(); + + for (auto& element : borrowed_buffer::get_as(alternative_tensor_for_printing)) { + TT_ASSERT(element == output_value); + } + + free(alternative_tensor_for_printing_data); + /* Alternative Way to Print End */ + + auto d_np_array = numpy::ndarray(shape); + void* d_np_array_data = d_np_array.data; + Tensor d_cpu = Tensor( + BorrowedStorage{ + borrowed_buffer::Buffer(static_cast(d_np_array_data), d_np_array.size()), + on_creation_callback, + on_destruction_callback}, + shape, + DataType::BFLOAT16, + Layout::ROW_MAJOR); + + bfloat16 d_value = 8.0f; + for (auto& element : borrowed_buffer::get_as(d_cpu)) { + element = d_value; + } + + Tensor d_dev = a_dev; + memcpy(d_dev, d_cpu); + + Tensor e_dev = tt::tt_metal::add(c_dev, d_dev); + + tt::tt_metal::memcpy(tensor_for_printing, e_dev); + tensor_for_printing.print(); + + for (auto& element : owned_buffer::get_as(tensor_for_printing)) { + TT_ASSERT(element == bfloat16(10.0f)); + } + + TT_FATAL(tt::tt_metal::CloseDevice(device)); +} + +int main() { + test_raw_host_memory_pointer(); + return 0; +} diff --git a/tt_eager/tensor/borrowed_buffer.hpp b/tt_eager/tensor/borrowed_buffer.hpp index 6772f29174b..527a3ad14eb 100644 --- a/tt_eager/tensor/borrowed_buffer.hpp +++ b/tt_eager/tensor/borrowed_buffer.hpp @@ -30,7 +30,10 @@ struct Buffer { inline const T* begin() const noexcept { return this->data_ptr_; } inline const T* end() const noexcept { return this->data_ptr_ + this->size(); } - private: + inline void* data() noexcept { return static_cast(this->data_ptr_); } + inline const void* data() const noexcept { return static_cast(this->data_ptr_); } + + private: T* data_ptr_; std::size_t size_; }; diff --git a/tt_eager/tensor/borrowed_buffer_functions.hpp b/tt_eager/tensor/borrowed_buffer_functions.hpp index a6bcef6643d..86c2b4fb39b 100644 --- a/tt_eager/tensor/borrowed_buffer_functions.hpp +++ b/tt_eager/tensor/borrowed_buffer_functions.hpp @@ -40,32 +40,30 @@ template Buffer get_as(Tensor& tensor) { validate_datatype(tensor); return std::visit( - [&] (auto&& storage) { + [](auto&& storage) -> Buffer { using StorageType = std::decay_t; if constexpr (std::is_same_v) { return get_as(storage.buffer); } else { - TT_THROW("Must be a BorrowedStorage"); + TT_THROW("Tensor must have BorrowedStorage"); } }, - tensor.storage() - ); + tensor.storage()); } template const Buffer get_as(const Tensor& tensor) { validate_datatype(tensor); return std::visit( - [] (auto&& storage) { + [](auto&& storage) -> Buffer { using StorageType = std::decay_t; if constexpr (std::is_same_v) { return get_as(storage.buffer); } else { - TT_THROW("Must be an BorrowedStorage"); + TT_THROW("Tensor must have BorrowedStorage"); } }, - tensor.storage() - ); + tensor.storage()); } } // namespace borrowed_buffer diff --git a/tt_eager/tensor/owned_buffer.hpp b/tt_eager/tensor/owned_buffer.hpp index 6d674cf3b87..d4b7b79ad98 100644 --- a/tt_eager/tensor/owned_buffer.hpp +++ b/tt_eager/tensor/owned_buffer.hpp @@ -36,7 +36,10 @@ struct Buffer { inline const std::vector& get() const { return *this->shared_vector_; } inline void reset() { this->shared_vector_.reset(); } - private: + inline void* data() noexcept { return static_cast(this->pointer_for_faster_access_); } + inline const void* data() const noexcept { return static_cast(this->pointer_for_faster_access_); } + + private: std::shared_ptr> shared_vector_; T* pointer_for_faster_access_; std::size_t size_; diff --git a/tt_eager/tensor/owned_buffer_functions.hpp b/tt_eager/tensor/owned_buffer_functions.hpp index 1a7367a5dc8..b27cf9305dd 100644 --- a/tt_eager/tensor/owned_buffer_functions.hpp +++ b/tt_eager/tensor/owned_buffer_functions.hpp @@ -55,7 +55,7 @@ Buffer get_as(Tensor& tensor) { if constexpr (std::is_same_v) { return get_as(storage.buffer); } else { - TT_THROW("Must be an owned storage"); + TT_THROW("Tensor must have OwnedStorage"); } }, tensor.storage() @@ -71,7 +71,7 @@ const Buffer get_as(const Tensor& tensor) { if constexpr (std::is_same_v) { return get_as(storage.buffer); } else { - TT_THROW("Must be a owned storage"); + TT_THROW("Tensor must have OwnedStorage"); } }, tensor.storage() diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index ba90ddaa683..75b586ce149 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -131,9 +131,11 @@ void Tensor::print(Layout print_layout, bool pretty_print) const { std::cout << write_to_string(print_layout, pretty_print); } -Tensor Tensor::pad(const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) const { +Tensor Tensor::pad(const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) const { ZoneScoped; - TT_ASSERT(this->storage_type() == StorageType::OWNED or this->storage_type() == StorageType::BORROWED && "Tensor must be on host for padding"); + TT_ASSERT( + this->storage_type() == StorageType::OWNED or + this->storage_type() == StorageType::BORROWED && "Tensor must be on host for padding"); TT_ASSERT(this->layout() == Layout::ROW_MAJOR && "Tensor layout must be ROW_MAJOR for padding"); auto input_shape = this->shape(); @@ -345,6 +347,30 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layo return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); } +void* get_raw_host_data_ptr(const Tensor& tensor) { + const static std::unordered_map> dispatch_map = { + {DataType::BFLOAT16, &tensor_impl::get_raw_host_data_ptr}, + {DataType::FLOAT32, &tensor_impl::get_raw_host_data_ptr}, + {DataType::UINT32, &tensor_impl::get_raw_host_data_ptr}, + {DataType::BFLOAT8_B, &tensor_impl::get_raw_host_data_ptr}, + {DataType::UINT16, &tensor_impl::get_raw_host_data_ptr}, + }; + return dispatch_map.at(tensor.dtype())(tensor); +} + +void memcpy(Tensor& dst, const Tensor& src) { + ZoneScoped; + + const static std::unordered_map> dispatch_map = { + {DataType::BFLOAT16, &tensor_impl::memcpy}, + {DataType::FLOAT32, &tensor_impl::memcpy}, + {DataType::UINT32, &tensor_impl::memcpy}, + {DataType::BFLOAT8_B, &tensor_impl::memcpy}, + {DataType::UINT16, &tensor_impl::memcpy}, + }; + dispatch_map.at(dst.dtype())(dst, src); +} + } // namespace tt_metal } // namespace tt diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index ad3aacddb9a..2bfd6282997 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -131,6 +131,11 @@ Tensor create_device_tensor(const Shape& shape, DataType dtype, Layout layout, D Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config); +// template +// void *get_host_buffer(const Tensor &tensor); +void *get_raw_host_data_ptr(const Tensor &tensor); +void memcpy(Tensor &dst, const Tensor &src); + } // namespace tt_metal } // namespace tt diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp index b98e3dcb6ed..6c918fe5fdb 100644 --- a/tt_eager/tensor/tensor_impl.hpp +++ b/tt_eager/tensor/tensor_impl.hpp @@ -128,7 +128,7 @@ constexpr inline uint32_t packed_buffer_size_bytes(uint32_t volume_unpacked_data template <> constexpr inline uint32_t packed_buffer_size_bytes(uint32_t volume_unpacked_data) { auto num_type_in_u32 = sizeof(uint32_t) / sizeof(bfloat16); - return (volume_unpacked_data/num_type_in_u32) * sizeof(uint32_t); + return (volume_unpacked_data / num_type_in_u32) * sizeof(uint32_t); } // ====================================================================================== @@ -347,7 +347,7 @@ std::string to_string_row_major_4D(const BufferType& buffer, const Shape& shape, if (shape[-2] > MAX_NUM_ELEMENTS_TO_PRINT[-2]) { ss << "..."; } - if(z < shape[1] - 1) + if (z < shape[1] - 1) ss << "]," << std::endl << std::endl; else ss << "]"; @@ -355,7 +355,7 @@ std::string to_string_row_major_4D(const BufferType& buffer, const Shape& shape, if (shape[-3] > MAX_NUM_ELEMENTS_TO_PRINT[-3]) { ss << "..."; } - if(w < shape[0] - 1) + if (w < shape[0] - 1) ss << "]," << std::endl << std::endl << std::endl; else ss << "]"; @@ -363,11 +363,10 @@ std::string to_string_row_major_4D(const BufferType& buffer, const Shape& shape, if (shape[-4] > MAX_NUM_ELEMENTS_TO_PRINT[-4]) { ss << "..."; } - ss << "], dtype=" << dtype << " )" << std::endl; + ss << "], dtype=" << dtype << " )" << std::endl; return ss.str(); } - template std::string to_string_row_major(const BufferType& buffer, const Shape& shape, DataType dtype) { if (shape.rank() == 0) { @@ -434,18 +433,19 @@ std::vector read_data_from_device(const Tensor &tensor, uint32_t size_in_byte } } -template typename BufferType> -inline void write_data_to_device_buffer(const BufferType& data_to_write, DeviceBuffer buffer, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config) { +template typename BufferType> +inline void write_data_to_device_buffer(const BufferType& host_buffer, Buffer& device_buffer) { ZoneScoped; // TODO(arakhmati): can we use generators in this function to go from `data_to_write` to `uint32_data`? // And effectively get rid of any additional allocation const char *TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { - EnqueueWriteBuffer(tt::tt_metal::detail::GetCommandQueue(buffer->device()), *buffer, std::begin(data_to_write), false); + EnqueueWriteBuffer( + tt::tt_metal::detail::GetCommandQueue(device_buffer.device()), device_buffer, host_buffer.data(), false); } else { - auto uint32_data = pack_vec_into_uint32_vec(data_to_write); - ::detail::WriteToBuffer(*buffer, uint32_data); + auto uint32_data = pack_vec_into_uint32_vec(host_buffer); + ::detail::WriteToBuffer(device_buffer, uint32_data); } } @@ -458,7 +458,7 @@ inline DeviceBuffer initialize_data_on_device(const BufferType& data_to_write auto device_buffer = allocate_buffer_on_device(packed_size_in_bytes, device, shape, data_type, layout, memory_config, shard_spec); - write_data_to_device_buffer(data_to_write, device_buffer, shape, data_type, layout, memory_config); + write_data_to_device_buffer(data_to_write, *device_buffer); return device_buffer; } @@ -885,6 +885,52 @@ Tensor extract_shard(const Tensor & tensor, const uint32_t & core_id){ } +template +void* get_raw_host_data_ptr(const Tensor& tensor) { + return std::visit( + [](auto&& storage) -> void* { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + auto buffer = owned_buffer::get_as(storage.buffer); + return buffer.data(); + } else if constexpr (std::is_same_v) { + if constexpr ( + std::is_same_v or std::is_same_v or + std::is_same_v or std::is_same_v) { + auto buffer = borrowed_buffer::get_as(storage.buffer); + return buffer.data(); + } else { + TT_THROW("Borrowed storage doesn't support this data type"); + } + } else if constexpr (std::is_same_v) { + TT_THROW("Device storage isn't supported"); + } else { + raise_unsupported_storage(); + } + }, + tensor.storage()); +} + +template +void memcpy(Tensor& dst, const Tensor& src) { + const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); + if (TT_METAL_SLOW_DISPATCH_MODE != nullptr) { + TT_THROW("SLOW_DISPATCH is not supported for memcpy!"); + } + + TT_ASSERT(dst.dtype() == src.dtype()); + TT_ASSERT(dst.layout() == src.layout()); + + if (is_cpu_tensor(dst) && is_device_tensor(src)) { + EnqueueReadBuffer( + tt::tt_metal::detail::GetCommandQueue(src.device()), *src.buffer(), get_raw_host_data_ptr(dst), true); + } else if (is_device_tensor(dst) && is_cpu_tensor(src)) { + EnqueueWriteBuffer( + tt::tt_metal::detail::GetCommandQueue(dst.device()), *dst.buffer(), get_raw_host_data_ptr(src), false); + } else { + TT_THROW("Unsupported memcpy"); + } +} } // namespace tensor_impl diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index 951f77d2443..ebc393bb75e 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -184,7 +184,11 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu return arch == tt::ARCH::WORMHOLE_B0; } + bool is_cpu_tensor(const Tensor& tensor) { + return tensor.storage_type() == StorageType::OWNED || tensor.storage_type() == StorageType::BORROWED; + } + bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } } } diff --git a/tt_eager/tensor/tensor_utils.hpp b/tt_eager/tensor/tensor_utils.hpp index 9482246e536..2acdfd7aabe 100644 --- a/tt_eager/tensor/tensor_utils.hpp +++ b/tt_eager/tensor/tensor_utils.hpp @@ -40,6 +40,9 @@ namespace tt_metal { bool is_arch_gs(const tt::ARCH& arch); bool is_arch_whb0(const tt::ARCH& arch); + + bool is_cpu_tensor(const Tensor& tensor); + bool is_device_tensor(const Tensor& tensor); } } diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index b87cca7edd6..64f57ba6b78 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -275,7 +275,6 @@ constexpr void raise_unsupported_storage() { static_assert(tt::stl::concepts::always_false_v, "Unsupported Storage"); } - bool operator==(const ShardSpec& spec_a, const ShardSpec& spec_b); bool operator!=(const ShardSpec& spec_a, const ShardSpec& spec_b); diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index 744fbcb17c7..ad3c0296fca 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -281,7 +281,7 @@ constexpr auto div_unary_sfpu = make_eltwise_asymmetric_binop_unary_with_param{}(input_tensor, fast_and_approx, output_mem_config); }