From 63a72ad69baf8a0b1cfacfa337abb6718a9bbb33 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Thu, 7 Nov 2024 13:51:37 -0800 Subject: [PATCH] #0: Finalize transition of eltwise and matmul operations to SimpleShape (#14777) * #0: Finalize transition of eltwise and matmul to SimpleShape * #0: Shape fixup --- .../eltwise/unary/device/unary_device_operation.cpp | 5 ++--- .../eltwise/unary/device/unary_device_operation_types.hpp | 2 +- .../ccl/all_gather_matmul/device/all_gather_matmul_op.cpp | 2 +- ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp | 3 --- ttnn/cpp/ttnn/operations/matmul/matmul.cpp | 6 +++--- ttnn/cpp/ttnn/operations/matmul/matmul.hpp | 2 +- 6 files changed, 8 insertions(+), 12 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp index b8f7e3c1c10..255ca459504 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp @@ -149,7 +149,7 @@ void UnaryDeviceOperation::validate_on_program_cache_miss( shape_return_value_t UnaryDeviceOperation::compute_output_shapes( const operation_attributes_t&, const tensor_args_t& tensor_args) { - return {tensor_args.input.get_shape()}; + return {tensor_args.input.get_logical_shape()}; } tensor_return_value_t UnaryDeviceOperation::create_output_tensors( @@ -158,13 +158,12 @@ tensor_return_value_t UnaryDeviceOperation::create_output_tensors( return tensor_args.preallocated_output.value(); } - const auto output_shape = compute_output_shapes(args, tensor_args); - auto output_layout = Layout::TILE; if (args.output_memory_config.is_sharded()) { output_layout = tensor_args.input.get_layout(); } + const auto output_shape = tensor_args.input.shape(); return create_device_tensor( output_shape, args.output_dtype, output_layout, tensor_args.input.device(), args.output_memory_config); } diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp index f600d7317f5..95d100a9c85 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp @@ -27,6 +27,6 @@ struct tensor_args_t { using tensor_return_value_t = Tensor; -using shape_return_value_t = ttnn::Shape; +using shape_return_value_t = ttnn::SimpleShape; } // namespace ttnn::operations::unary diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp index 4fd3da1951b..c3ed75dcb6f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -163,7 +163,7 @@ std::vector all_gather_matmul( /* Matmul setup */ - bool user_run_batched = ttnn::operations::matmul::detail::is_input_batched(weight_tensor.get_shape()); + bool user_run_batched = ttnn::operations::matmul::detail::is_input_batched(weight_tensor.get_logical_shape()); std::optional user_core_coord; if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp index 93e36fee7f2..2e7144040f2 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp @@ -21,7 +21,6 @@ namespace operations { namespace matmul { using ttnn::operations::unary::UnaryWithParam; -using tt::tt_metal::LegacyShape; /* * GENERAL MATMUL AND BMM @@ -170,8 +169,6 @@ struct Matmul { const std::vector &input_tensors, const std::vector> &optional_input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector compute_output_shapes_dram_sharded( - const std::vector &input_tensors, uint32_t N_unpadded) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp index c2d14cddedc..5bc05ffedb3 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp @@ -18,7 +18,7 @@ namespace matmul { namespace detail { -bool is_input_batched(const ttnn::Shape& shape) { +bool is_input_batched(const ttnn::SimpleShape& shape) { auto is_batched = false; for (auto i = 0; i < shape.rank() - 2; ++i) { if (shape[i] > 1) { @@ -109,7 +109,7 @@ Tensor MatmulOperation::invoke( if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); } - bool user_run_batched = detail::is_input_batched(input_tensor_b.get_shape()); + bool user_run_batched = detail::is_input_batched(input_tensor_b.get_logical_shape()); return bound_matmul( input_tensor_a, input_tensor_b, @@ -147,7 +147,7 @@ Tensor LinearOperation::invoke( if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); } - bool b_is_batched = detail::is_input_batched(input_tensor_b.get_shape()); + bool b_is_batched = detail::is_input_batched(input_tensor_b.get_logical_shape()); TT_FATAL(!(b_is_batched && bias.has_value()), "Batched input not supported when bias exists (linear operation)."); return bound_matmul( diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp index eb450bd2896..1501b005471 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp @@ -22,7 +22,7 @@ namespace matmul { namespace detail { -bool is_input_batched(const ttnn::Shape& shape); +bool is_input_batched(const ttnn::SimpleShape& logical_shape); } // namespace detail