diff --git a/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py b/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py
index e29faee176d..c0a213322ee 100644
--- a/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py
+++ b/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py
@@ -78,6 +78,7 @@ def create_query_key_value(hidden_states, weight, bias, num_heads):
     query_key_value = ttnn.linear(
         hidden_states, weight, bias=bias, core_grid=(9, 12), memory_config=BLOOM_MEMORY_CONFIG, dtype=BLOOM_DTYPE
     )
+    ttnn.deallocate(hidden_states)

     query, key, value = split_query_key_value_and_split_heads(query_key_value, num_heads=num_heads)
     ttnn.deallocate(query_key_value)
@@ -161,8 +162,6 @@ def multi_head_attention(
     query_layer, key_layer, value_layer = create_query_key_value(
         hidden_states, query_key_value_weight, query_key_value_bias, num_heads=num_heads
     )
-    value_layer = ttnn.reallocate(value_layer)
-
     attention_scores = compute_attention_scores(query_layer, key_layer, alibi)
     attention_probs = compute_attention_probs(attention_scores, causal_mask)
     context_layer = compute_context_layer(attention_probs, value_layer)
@@ -186,6 +185,8 @@ def mlp(
         memory_config=BLOOM_MEMORY_CONFIG,
         dtype=BLOOM_DTYPE,
     )
+    ttnn.deallocate(hidden_states)
+
     ff2_output = ttnn.linear(
         ff1_output,
         dense_4h_to_h_weight,
@@ -211,12 +212,11 @@ def bloom(
         layout=ttnn.TILE_LAYOUT,
     )

-    # TODO(arakhmati): put hidden_states in L1
     hidden_states = ttnn.layer_norm(
         inputs_embeds,
         weight=parameters.transformer.word_embeddings_layernorm.weight,
         bias=parameters.transformer.word_embeddings_layernorm.bias,
-        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+        memory_config=BLOOM_MEMORY_CONFIG,
     )

     ttnn.deallocate(inputs_embeds)
@@ -238,10 +238,8 @@ def bloom(
             layer_parameters.self_attention.dense.bias,
             num_heads=num_heads,
         )
-        ttnn.deallocate(normalized_hidden_states)

-        # TODO(arakhmati): put attention_output in L1
-        attention_output = ttnn.add(attention_output, hidden_states, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        attention_output = ttnn.add(attention_output, hidden_states, memory_config=BLOOM_MEMORY_CONFIG)
         ttnn.deallocate(hidden_states)

         normalized_attention_output = ttnn.layer_norm(
@@ -258,10 +256,8 @@ def bloom(
             layer_parameters.mlp.dense_4h_to_h.weight,
             layer_parameters.mlp.dense_4h_to_h.bias,
         )
-        ttnn.deallocate(normalized_attention_output)

-        # TODO(arakhmati): put mlp_output in L1
-        mlp_output = ttnn.add(mlp_output, attention_output, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        mlp_output = ttnn.add(mlp_output, attention_output, memory_config=BLOOM_MEMORY_CONFIG)
         ttnn.deallocate(attention_output)
         hidden_states = mlp_output
diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/test_transpose.py
index ae3b00d0bfa..a843d669625 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/test_transpose.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/test_transpose.py
@@ -186,7 +186,7 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
     mem_config = ttl.tensor.MemoryConfig(
         ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, ttl.tensor.BufferType.L1, input_shard_spec
     )
-
+    # shape change also changes shard_spec as shard_shape is dependent on input_shape (resulting in CACHE MISS)
     transpose(
         input_shape,
         device,
@@ -195,7 +195,7 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
         input_mem_config=mem_config,
         output_mem_config=mem_config,
         input_dtype=input_dtype,
-        expected_program_cache_size=1,
+        expected_program_cache_size=2,
     )

     # changing shape
@@ -222,6 +222,7 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
         ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, ttl.tensor.BufferType.L1, input_shard_spec
     )

+    # shape change also changes shard_spec as shard_shape is dependent on input_shape (resulting in CACHE MISS)
     transpose(
         input_shape,
         device,
@@ -230,5 +231,5 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
         input_mem_config=mem_config,
         output_mem_config=mem_config,
         input_dtype=input_dtype,
-        expected_program_cache_size=1,
+        expected_program_cache_size=3,
     )
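The bumps in expected_program_cache_size above (1 -> 2 and 1 -> 3) follow from the rest of this patch: shard_spec now lives inside MemoryConfig, and the full MemoryConfig (including shard_spec) feeds the hashed operation attributes, so a shape change that alters the shard shape produces a new program hash and a deliberate cache miss. A minimal sketch of that effect in plain Python, using hypothetical stand-in types rather than the real ttl.tensor classes:

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class MemConfig:  # hypothetical stand-in for ttl.tensor.MemoryConfig
    memory_layout: str
    buffer_type: str
    shard_shape: Optional[Tuple[int, int]] = None  # stand-in for shard_spec


def program_hash(op_name: str, mem_config: MemConfig) -> int:
    # Before this patch only memory_layout was part of the key; now the whole config is.
    return hash((op_name, mem_config))


cache = {program_hash("transpose_wh", MemConfig("HEIGHT_SHARDED", "L1", (32, 64)))}
new_entry = program_hash("transpose_wh", MemConfig("HEIGHT_SHARDED", "L1", (32, 128)))
assert new_entry not in cache  # different shard shape -> new program, hence the larger cache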
diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp
index b5a53348346..ba90ddaa683 100644
--- a/tt_eager/tensor/tensor.cpp
+++ b/tt_eager/tensor/tensor.cpp
@@ -25,8 +25,8 @@ namespace tt {

 namespace tt_metal {

-Tensor::Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout layout, std::optional<ShardSpec> shard_spec)
-    : storage_(storage), shape_(shape), dtype_(dtype), layout_(layout), shard_spec_(shard_spec) {
+Tensor::Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout layout) :
+    storage_(storage), shape_(shape), dtype_(dtype), layout_(layout) {
     std::visit(
         [&] (auto&& storage) {
             using StorageType = std::decay_t<decltype(storage)>;
@@ -48,7 +48,6 @@ Tensor::Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layou
     );
 }

-
 Tensor::~Tensor() {
     this->deallocate();
 }
@@ -343,7 +342,7 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layo
         data_type, layout, memory_config, std::make_optional(shard_spec_buffer)
     );

-    return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, shard_spec);
+    return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
 }

 }  // namespace tt_metal
diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp
index e070dc80679..ad3aacddb9a 100644
--- a/tt_eager/tensor/tensor.hpp
+++ b/tt_eager/tensor/tensor.hpp
@@ -29,84 +29,93 @@ class Tensor {
     // ======================================================================================
     // Hi Level APIs
     // ======================================================================================
-    Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout layout, std::optional<ShardSpec> shard_spec = std::nullopt);
+    Tensor(const Storage &storage, const Shape &shape, DataType dtype, Layout layout);

-    Tensor(const Tensor &other) = default;
-    Tensor& operator=(const Tensor &other) = default;
+    Tensor(const Tensor &other) = default;
+    Tensor &operator=(const Tensor &other) = default;

-    Tensor(Tensor &&other) = default;
-    Tensor& operator=(Tensor &&other) = default;
+    Tensor(Tensor &&other) = default;
+    Tensor &operator=(Tensor &&other) = default;

-    ~Tensor();
+    ~Tensor();

-    void deallocate(bool force=false);
+    void deallocate(bool force = false);

-    Tensor to(Device *target_device, const MemoryConfig &mem_config={.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const;
+    Tensor to(
+        Device *target_device,
+        const MemoryConfig &mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const;

-    Tensor to(Layout target_layout) const;
+    Tensor to(Layout target_layout) const;

-    Tensor pad(const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) const;
+    Tensor pad(const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) const;

-    Tensor cpu() const;
+    Tensor cpu() const;

-    Tensor cpu_sharded() const;
+    Tensor cpu_sharded() const;

-    Tensor unpad(const Shape &output_tensor_start, const Shape &output_tensor_end) const;
+    Tensor unpad(const Shape &output_tensor_start, const Shape &output_tensor_end) const;

-    Tensor pad_to_tile(float pad_value) const;
+    Tensor pad_to_tile(float pad_value) const;

-    Tensor unpad_from_tile(const Shape &output_tensor_shape) const;
+    Tensor unpad_from_tile(const Shape &output_tensor_shape) const;

-    const std::string write_to_string(Layout print_layout = Layout::ROW_MAJOR, bool pretty_print = false) const;
-    void print(Layout print_layout=Layout::ROW_MAJOR, bool pretty_print=false) const;
+    const std::string write_to_string(Layout print_layout = Layout::ROW_MAJOR, bool pretty_print = false) const;
+    void print(Layout print_layout = Layout::ROW_MAJOR, bool pretty_print = false) const;

-    Tensor extract_shard(const CoreCoord & core) const;
-    Tensor extract_shard(const uint32_t & core_id) const;
+    Tensor extract_shard(const CoreCoord &core) const;
+    Tensor extract_shard(const uint32_t &core_id) const;

-    // ======================================================================================
-    // Low Level APIs
-    // ======================================================================================
-    Tensor reshape(int N, int C, int H, int W) const;
-    Tensor reshape(const Shape& new_shape) const;
+    // ======================================================================================
+    // Low Level APIs
+    // ======================================================================================
+    Tensor reshape(int N, int C, int H, int W) const;
+    Tensor reshape(const Shape &new_shape) const;

-    // ======================================================================================
-    // Getters
-    // ======================================================================================
-    const Storage& storage() const;
-    const Shape& shape() const { return this->shape_; }
-    DataType dtype() const { return this->dtype_; }
-    Layout layout() const { return this->layout_; }
-    const std::optional<ShardSpec>& shard_spec() const { return this->shard_spec_; }
+    // ======================================================================================
+    // Getters
+    // ======================================================================================
+    const Storage &storage() const;
+    const Shape &shape() const { return this->shape_; }
+    DataType dtype() const { return this->dtype_; }
+    Layout layout() const { return this->layout_; }

-    // ======================================================================================
-    // Extra Helper Functions
-    // ======================================================================================
+    // ======================================================================================
+    // Extra Helper Functions
+    // ======================================================================================

-    StorageType storage_type() const;
-    const Shape strides() const;
-    uint32_t volume() const;
+    StorageType storage_type() const;
+    const Shape strides() const;
+    uint32_t volume() const;

-    bool is_allocated() const;
+    bool is_allocated() const;

-    // TODO(arakhmati): clean up the methods below
-    Buffer* buffer() const { return std::get<DeviceStorage>(this->storage_).buffer.get(); }
-    Device *device() const { return this->buffer()->device(); }
-    const MemoryConfig memory_config() const { return std::get<DeviceStorage>(this->storage_).memory_config(); }
+    // TODO(arakhmati): clean up the methods below
+    Buffer *buffer() const { return std::get<DeviceStorage>(this->storage_).buffer.get(); }
+    Device *device() const { return this->buffer()->device(); }
+    const MemoryConfig memory_config() const {
+        return std::visit(
+            [](const auto &storage) -> MemoryConfig {
+                using T = std::decay_t<decltype(storage)>;
+                if constexpr (std::is_same_v<T, DeviceStorage>) {
+                    return storage.memory_config();
+                } else {
+                    TT_THROW("MemoryConfig can only be obtained for a tensor with DeviceStorage");
+                }
+            },
+            this->storage_);
+    }
+    const std::optional<ShardSpec> shard_spec() const { return this->memory_config().shard_spec; }

-    const bool is_sharded() const { return this->memory_config().is_sharded(); }
+    const bool is_sharded() const { return this->memory_config().is_sharded(); }

-    // Size in bytes of a single element held in tensor
-    uint32_t element_size() const;
+    // Size in bytes of a single element held in tensor
+    uint32_t element_size() const;

-    static constexpr auto attribute_names = std::make_tuple("storage", "shape", "dtype", "layout", "shard_spec");
-    const auto attribute_values() const {
-        return std::make_tuple(
-            std::cref(this->storage_),
-            std::cref(this->shape_),
-            std::cref(this->dtype_),
-            std::cref(this->layout_),
-            std::cref(this->shard_spec_));
-    }
+    static constexpr auto attribute_names = std::make_tuple("storage", "shape", "dtype", "layout");
+    const auto attribute_values() const {
+        return std::make_tuple(
+            std::cref(this->storage_), std::cref(this->shape_), std::cref(this->dtype_), std::cref(this->layout_));
+    }

     std::vector<uint32_t> host_page_ordering();

    private:
@@ -114,8 +123,6 @@ class Tensor {
     Shape shape_;
     DataType dtype_;
     Layout layout_;
-    std::optional<ShardSpec> shard_spec_;
-
 };

@@ -127,30 +134,3 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layo
 }  // namespace tt_metal

 }  // namespace tt
-
-
-namespace std {
-
-template <>
-struct hash<tt::tt_metal::Tensor> {
-    uint64_t operator()(const tt::tt_metal::Tensor &tensor) const {
-        if (std::holds_alternative(tensor.storage())) {
-            return tt::stl::hash::hash_objects(0,
-                typeid(tt::tt_metal::Tensor).hash_code(),
-                tensor.storage_type(),
-                std::holds_alternative(tensor.storage()),
-                tensor.shape(),
-                tensor.layout(),
-                tensor.dtype());
-        }
-        else {
-            return tt::stl::hash::hash_objects(0,
-                typeid(tt::tt_metal::Tensor).hash_code(),
-                tensor.storage_type(),
-                tensor.shape(),
-                tensor.layout(),
-                tensor.dtype());
-        }
-    }
-};
-}
diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp
index 308f17c5c31..c2aee9f9f22 100644
--- a/tt_eager/tensor/tensor_impl.hpp
+++ b/tt_eager/tensor/tensor_impl.hpp
@@ -574,7 +574,7 @@ inline Tensor to_device(const Tensor &tensor, Device *target_device, const Memor
         data_type, layout, memory_config, shard_spec_buffer_opt
     );

-    return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, memory_config.shard_spec);
+    return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
 }

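Taken together, the tensor.cpp, tensor.hpp and tensor_impl.hpp hunks above remove the duplicated shard_spec_ member from Tensor: the shard spec is now derived on demand from the device buffer's memory config, and requesting a memory config from a tensor that does not hold DeviceStorage is an error. A rough Python model of that ownership change (hypothetical classes, not the real tt_metal API):

from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass
class DeviceStorage:  # hypothetical stand-in: the buffer knows its own layout
    memory_layout: str
    buffer_type: str
    shard_shape: Optional[Tuple[int, int]]

    def memory_config(self):
        return {
            "memory_layout": self.memory_layout,
            "buffer_type": self.buffer_type,
            "shard_spec": self.shard_shape,
        }


@dataclass
class OwnedStorage:  # host-side data has no memory config at all
    data: list


@dataclass
class Tensor:
    storage: Union[DeviceStorage, OwnedStorage]

    def memory_config(self):
        if not isinstance(self.storage, DeviceStorage):
            raise RuntimeError("MemoryConfig can only be obtained for a tensor with DeviceStorage")
        return self.storage.memory_config()

    def shard_spec(self):
        # derived from the storage; no separate shard_spec_ member to keep in sync
        return self.memory_config()["shard_spec"]


t = Tensor(DeviceStorage("HEIGHT_SHARDED", "L1", (32, 64)))
assert t.shard_spec() == (32, 64)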
std::make_tuple("memory_layout", "buffer_type", "shard_spec"); const auto attribute_values() const { - return std::make_tuple(std::cref(this->memory_layout), std::cref(this->buffer_type)); + return std::make_tuple( + std::cref(this->memory_layout), std::cref(this->buffer_type), std::cref(this->shard_spec)); } - ~MemoryConfig(){;} }; bool operator==(const MemoryConfig& config_a, const MemoryConfig& config_b); @@ -194,17 +194,18 @@ struct DeviceStorage { DeviceBuffer buffer; const MemoryConfig memory_config() const { - const auto& buffer = this->buffer; + if (this->buffer.get() == nullptr) { + TT_THROW("MemoryConfig can only be obtained if the buffer is not null"); + } - std::optional shard_spec_opt = std::nullopt; - if(is_sharded(buffer->buffer_layout())){ - shard_spec_opt = buffer->shard_spec().tensor_shard_spec; + std::optional shard_spec = std::nullopt; + if (is_sharded(this->buffer->buffer_layout())) { + shard_spec = this->buffer->shard_spec().tensor_shard_spec; } return MemoryConfig{ - .memory_layout = buffer->buffer_layout(), - .buffer_type = buffer->buffer_type(), - .shard_spec = shard_spec_opt - }; + .memory_layout = this->buffer->buffer_layout(), + .buffer_type = this->buffer->buffer_type(), + .shard_spec = shard_spec}; } static constexpr auto attribute_names = std::make_tuple("memory_config"); @@ -282,70 +283,3 @@ bool operator!=(const ShardSpec& spec_a, const ShardSpec& spec_b); } // namespace tt_metal } // namespace tt - - -namespace std { - -template <> -struct hash { - uint64_t operator()(const tt::tt_metal::MemoryConfig &mem_config) const { - return tt::stl::hash::hash_objects(0, - typeid(tt::tt_metal::MemoryConfig).hash_code(), - mem_config.memory_layout, - mem_config.buffer_type - ); - } -}; - - -template <> -struct hash { - uint64_t operator()(const tt::tt_metal::DeviceStorage &storage) const { - return tt::stl::hash::hash_objects(0, - typeid(tt::tt_metal::DeviceStorage).hash_code(), - storage.buffer, - storage.memory_config() - ); - } -}; - - -template <> -struct hash { - uint64_t operator()(const tt::tt_metal::Padding::PadDimension &pad_dimension) const { - return tt::stl::hash::hash_objects(0, - typeid(tt::tt_metal::Padding::PadDimension).hash_code(), - pad_dimension.front, - pad_dimension.back - ); - } -}; - - -template <> -struct hash { - uint64_t operator()(const tt::tt_metal::Padding &padding) const { - - uint64_t hash = tt::stl::hash::hash_objects(0, typeid(tt::tt_metal::Padding).hash_code(), - padding.rank_, padding.pad_value_); - for (const auto& pad_dim : padding.pad_dimensions_) { - hash = tt::stl::hash::hash_objects(hash, pad_dim); - } - return hash; - } -}; -template <> -struct hash { - uint64_t operator()(const tt::tt_metal::Shape &shape) const { - - uint64_t hash = tt::stl::hash::hash_objects(0, typeid(tt::tt_metal::Shape).hash_code(), - shape.rank(), shape.padding()); - for(int idx=0; idx < shape.rank(); idx++){ - hash = tt::stl::hash::hash_objects(hash, shape[idx]); - } - return hash; - - } -}; - -} diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp index da82ef0d748..dc3fd50c36b 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp @@ -229,11 +229,11 @@ const operation::Hash EltwiseBinary::compute_program_hash( this->op_type, parallelization_strategy, input_tensor_a.dtype(), - input_tensor_a.memory_config().memory_layout, + input_tensor_a.memory_config(), 
diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp
index da82ef0d748..dc3fd50c36b 100644
--- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp
+++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp
@@ -229,11 +229,11 @@ const operation::Hash EltwiseBinary::compute_program_hash(
         this->op_type,
         parallelization_strategy,
         input_tensor_a.dtype(),
-        input_tensor_a.memory_config().memory_layout,
+        input_tensor_a.memory_config(),
         input_tensor_b.dtype(),
-        input_tensor_b.memory_config().memory_layout,
+        input_tensor_b.memory_config(),
         this->output_dtype,
-        this->output_mem_config.memory_layout,
+        this->output_mem_config,
         this->in_place);

     if (this->fused_activations.has_value()) {
diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
index 08651ee31f5..c04ecd35d18 100644
--- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
+++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -302,8 +302,13 @@ const operation::Hash EltwiseUnary::compute_program_hash(const std::vector<Tensor>& input_tensors) const {
         this->output_mem_config);
     for (const auto& unary_with_param_op : this->op_chain) {
         hash = tt::stl::hash::hash_objects(hash, unary_with_param_op.op_type);
         if (unary_with_param_op.param.has_value()) {
diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp
index d83c4af1580..e71d09ba5cc 100644
--- a/tt_eager/tt_dnn/op_library/run_operation.cpp
+++ b/tt_eager/tt_dnn/op_library/run_operation.cpp
@@ -122,7 +122,7 @@ constexpr auto decorate_operation(const Function& function) {
 #ifdef TTNN_ENABLE_LOGGING
         const auto elapsed_seconds = static_cast((end - start).count());
         tt::log_info(
-            tt::LogOp, "Operation {:100} finished in {:15} nanoseconds", operation.get_type_name(), elapsed_seconds);
+            tt::LogOp, "Finished Operation {:50} in {:15} nanoseconds", operation.get_type_name(), elapsed_seconds);
 #endif

         return output_tensors;
@@ -230,7 +230,7 @@ std::vector<Tensor> run_device_operation(
         const auto elapsed_seconds = static_cast((end - start).count());
         tt::log_info(
             tt::LogOp,
-            "Program {:100} finished in {:15} nanoseconds",
+            "Finished Program {:50} in {:15} nanoseconds",
             operation.get_type_name(),
             elapsed_seconds);
 #endif
diff --git a/tt_eager/tt_dnn/op_library/run_operation.hpp b/tt_eager/tt_dnn/op_library/run_operation.hpp
index 30f98293c79..a69f8a922d5 100644
--- a/tt_eager/tt_dnn/op_library/run_operation.hpp
+++ b/tt_eager/tt_dnn/op_library/run_operation.hpp
@@ -196,7 +196,7 @@ inline void log_operation(
     const std::vector<std::optional<const Tensor>>& optional_input_tensors = {}) {
     tt::log_debug(
         tt::LogOp,
-        "Operation: \"{}\" ({})",
+        "Launching Operation: \"{}\" ({})",
         operation.get_type_name(),
         detail::operation_type_to_string());

diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
index 79a03a58bd1..537c578465d 100644
--- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
+++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp
@@ -95,6 +95,7 @@ tt::stl::reflection::Attributes SplitFusedQKVAndSplitHeads::attributes() const {
     return {
         {"compute_with_storage_grid_size", this->compute_with_storage_grid_size.str()},
         {"output_mem_config", this->output_mem_config},
+        {"num_heads", this->num_heads},
     };
 }

diff --git a/tt_eager/tt_dnn/op_library/transpose/wh_multi_core/transpose_wh_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/transpose/wh_multi_core/transpose_wh_op_multi_core.cpp
index c63b51d9966..16851a677d1 100644
--- a/tt_eager/tt_dnn/op_library/transpose/wh_multi_core/transpose_wh_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/transpose/wh_multi_core/transpose_wh_op_multi_core.cpp
@@ -265,7 +265,7 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor &a,
     uint32_t num_cores_y = compute_with_storage_grid_size.y;
     CoreRange total_cores = {.start={0, 0}, .end={num_cores_x-1, num_cores_y-1}};

-    auto& shard_spec = a.shard_spec().value();
+    auto shard_spec = a.shard_spec().value();
     auto& all_cores = shard_spec.grid;
     uint32_t num_cores = all_cores.num_cores();

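The run_operation.cpp and run_operation.hpp hunks above (and the matching change to the torch-conversion log further down) only reword the timing logs: "Launching Operation ..." is emitted before dispatch and "Finished Operation/Program ... in ... nanoseconds" after, with the padded name field narrowed from 100 to 50 characters so the durations still line up. Roughly this pattern, sketched in Python with a print-based logger standing in for tt::log_info:

import time


def run_with_logging(name, fn, *args):
    # hypothetical wrapper; the real logic lives in decorate_operation / run_device_operation
    print(f'Launching Operation: "{name}"')
    start = time.perf_counter_ns()
    result = fn(*args)
    elapsed = time.perf_counter_ns() - start
    # the :50 width pads the operation name so durations align in the log
    print(f"Finished Operation {name:50} in {elapsed:15} nanoseconds")
    return result


run_with_logging("ttnn.add", lambda a, b: a + b, 1, 2)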
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
index 19b2f3d7e72..00e9e666e08 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp
@@ -345,7 +345,7 @@ Tensor convert_torch_tensor_to_tt_tensor(
         const auto elapsed_seconds = static_cast((end - start).count());
         tt::log_info(
             tt::LogOp,
-            "Operation {:100} finished in {:15} nanoseconds",
+            "Finished Operation {:50} in {:15} nanoseconds",
             op.get_type_name(),
             elapsed_seconds);
 #endif
diff --git a/tt_metal/impl/buffers/buffer.hpp b/tt_metal/impl/buffers/buffer.hpp
index 4b19ec97718..af16ef025ba 100644
--- a/tt_metal/impl/buffers/buffer.hpp
+++ b/tt_metal/impl/buffers/buffer.hpp
@@ -134,7 +134,11 @@ bool is_sharded(const TensorMemoryLayout & layout);

 class Buffer {
    public:
-    Buffer() : device_(nullptr) {}
+    Buffer() :
+        device_(nullptr),
+        buffer_type_(BufferType::DRAM),
+        buffer_layout_(TensorMemoryLayout::INTERLEAVED),
+        shard_parameters_(std::nullopt) {}

     Buffer(Device *device, uint64_t size, uint64_t page_size, const BufferType buffer_type,
            const TensorMemoryLayout buffer_layout=TensorMemoryLayout::INTERLEAVED,
diff --git a/tt_metal/tt_stl/reflection.hpp b/tt_metal/tt_stl/reflection.hpp
index a9834f29f36..1efb3ae1897 100644
--- a/tt_metal/tt_stl/reflection.hpp
+++ b/tt_metal/tt_stl/reflection.hpp
@@ -458,7 +458,7 @@ constexpr bool is_specialization_v = is_specialization::value;
 template <typename T, std::size_t N>
 inline hash_t hash_object(const std::array<T, N>& array) noexcept {
     if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-        fmt::print("Hashing std::array<{}, {}>\n", boost::core::demangle(typeid(T).name()), N);
+        fmt::print("Hashing std::array<{}, {}>: {}\n", boost::core::demangle(typeid(T).name()), N, array);
     }
     std::size_t hash = 0;
     [&array, &hash]<std::size_t... Ns>(std::index_sequence<Ns...>) {
@@ -475,7 +475,7 @@ inline hash_t hash_object(const std::array<T, N>& array) noexcept {
 template <typename... Ts>
 inline hash_t hash_object(const std::variant<Ts...>& variant) noexcept {
     if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-        fmt::print("Hashing std::variant\n");
+        fmt::print("Hashing std::variant: {}\n", variant);
     }
     return std::visit([](const auto& value) { return hash_object(value); }, variant);
 }
@@ -484,17 +484,17 @@ template <typename T>
 inline hash_t hash_object(const T& object) noexcept {
     if constexpr (std::numeric_limits<T>::is_integer) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing integer of type {}\n", boost::core::demangle(typeid(T).name()));
+            fmt::print("Hashing integer of type {}: {}\n", boost::core::demangle(typeid(T).name()), object);
         }
         return object;
     } else if constexpr (detail::is_std_hashable_v<T>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing {} using std::hash\n", boost::core::demangle(typeid(T).name()));
+            fmt::print("Hashing {} using std::hash: {}\n", boost::core::demangle(typeid(T).name()), object);
         }
         return std::hash<T>{}(object);
     } else if constexpr (std::is_same_v<T, tt::stl::reflection::Attributes>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing tt::stl::reflection::Attributes\n");
+            fmt::print("Hashing tt::stl::reflection::Attributes: {}\n", object);
         }
         auto hash = 0;
         for (auto&& [name, attribute] : object) {
@@ -503,19 +503,23 @@ inline hash_t hash_object(const T& object) noexcept {
         return hash;
     } else if constexpr (tt::stl::reflection::detail::supports_to_hash_v<T>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing struct {} using to_hash method\n", boost::core::demangle(typeid(T).name()));
+            fmt::print("Hashing struct {} using to_hash method: {}\n", boost::core::demangle(typeid(T).name()), object);
         }
         return object.to_hash();
     } else if constexpr (tt::stl::reflection::detail::supports_compile_time_attributes_v<T>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing struct {} using compile-time attributes\n", boost::core::demangle(typeid(T).name()));
+            fmt::print(
+                "Hashing struct {} using compile-time attributes: {}\n",
+                boost::core::demangle(typeid(T).name()),
+                object);
         }
         constexpr auto num_attributes = reflection::detail::get_num_attributes<T>();
-        std::size_t hash = 0;
-        [&object, &hash]<std::size_t... Ns>(std::index_sequence<Ns...>) {
+        std::size_t hash = hash_objects(0, typeid(T).hash_code());
+        const auto attribute_values = object.attribute_values();
+        [&object, &hash, &attribute_values]<std::size_t... Ns>(std::index_sequence<Ns...>) {
             (
-                [&object, &hash] {
-                    const auto& attribute = std::get<Ns>(object.attribute_values());
+                [&object, &hash, &attribute_values] {
+                    const auto& attribute = std::get<Ns>(attribute_values);
                     hash = hash_objects(hash, attribute);
                 }(),
                 ...);
@@ -523,12 +527,13 @@ inline hash_t hash_object(const T& object) noexcept {
         return hash;
     } else if constexpr (tt::stl::reflection::detail::supports_runtime_time_attributes_v<T>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing struct {} using run-time attributes\n", boost::core::demangle(typeid(T).name()));
+            fmt::print(
+                "Hashing struct {} using run-time attributes: {}\n", boost::core::demangle(typeid(T).name()), object);
         }
-        return hash_object(object.attributes());
+        return hash_objects(0, typeid(T).hash_code(), object.attributes());
     } else if constexpr (detail::is_specialization_v<T, std::vector>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing std::vector<{}>\n", boost::core::demangle(typeid(T).name()));
+            fmt::print("Hashing std::vector of type {}: {}\n", boost::core::demangle(typeid(T).name()), object);
         }
         auto hash = 0;
         for (const auto& element : object) {
@@ -537,7 +542,7 @@ inline hash_t hash_object(const T& object) noexcept {
         return hash;
     } else if constexpr (detail::is_specialization_v<T, std::optional>) {
         if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
-            fmt::print("Hashing std::optional<{}>\n", boost::core::demangle(typeid(T).name()));
+            fmt::print("Hashing std::optional of type {}: {}\n", boost::core::demangle(typeid(T).name()), object);
         }
         if (object.has_value()) {
             return hash_object(object.value());
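The reflection.hpp change above does two things to the generic attribute hashing behind the program cache: the debug traces now print the hashed value, and, more importantly, the hash is seeded with typeid(T).hash_code() on both the compile-time and run-time attribute paths, so two different operation types whose attribute values happen to be identical no longer collapse to the same key. A small Python sketch of that seeding idea (hypothetical helpers, not the tt::stl::hash API):

def hash_objects(seed, *objects):
    # hypothetical stand-in for tt::stl::hash::hash_objects
    h = seed
    for obj in objects:
        h = hash((h, obj))
    return h


def hash_op(op):
    # post-patch behaviour: start from the concrete type, then fold in the attributes
    return hash_objects(hash(type(op).__name__), *op.attribute_values())


class EltwiseUnary:
    def attribute_values(self):
        return ("L1", "bfloat16")


class EltwiseBinary:
    def attribute_values(self):
        return ("L1", "bfloat16")


# identical attribute values, but the type seed keeps the hashes apart
assert hash_op(EltwiseUnary()) != hash_op(EltwiseBinary())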
diff --git a/ttnn/tensor.py b/ttnn/tensor.py
index f0efcb7f26d..2fb1283ad41 100644
--- a/ttnn/tensor.py
+++ b/ttnn/tensor.py
@@ -668,9 +668,12 @@ def _torch_identity(input_tensor):

 @decorate_operation(torch_function=_torch_identity)
 def reallocate(input_tensor: Tensor) -> Tensor:
-    ttl_input_tensor = input_tensor.value
-    ttl_output_tensor = ttl.tensor.move(ttl_input_tensor)
-    return Tensor(ttl_output_tensor)
+    def impl(input_tensor):
+        ttl_input_tensor = input_tensor.value
+        ttl_output_tensor = ttl.tensor.move(ttl_input_tensor)
+        return Tensor(ttl_output_tensor)
+
+    return ttl.tensor.decorate_external_operation(impl, function_name="ttnn.reallocate")(input_tensor)


 @decorate_operation()
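Wrapping the body of ttnn.reallocate in ttl.tensor.decorate_external_operation gives the underlying ttl.tensor.move call a stable name ("ttnn.reallocate") in the same operation reports as other decorated calls; callers are unaffected. The typical usage pattern for deallocate/reallocate, sketched from the functional_bloom code earlier in this patch (illustrative only, argument lists trimmed, and it needs a device and real tensors to run):

import ttnn


def attention_projection(hidden_states, weight, bias):
    # free inputs as soon as the consumer has run, to keep L1 usage down
    qkv = ttnn.linear(hidden_states, weight, bias=bias)
    ttnn.deallocate(hidden_states)

    # after many deallocations L1 can fragment; reallocate (ttl.tensor.move under
    # the hood) copies the tensor into a fresh buffer, which can help later ops
    # find contiguous space
    qkv = ttnn.reallocate(qkv)
    return qkv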