From 5de5bc944e569e0145e7ef79ec6905e529790b1a Mon Sep 17 00:00:00 2001
From: Brian Liu
Date: Sun, 20 Oct 2024 19:48:06 -0700
Subject: [PATCH] #13127: Add compute_output_specs, which allows an op to
 return its output tensor layout

- Op must implement either compute_output_shapes + create_output_tensors OR
  compute_output_specs
  * If compute_output_specs is implemented, the infra handles output tensor creation
  * Also works for optional output tensors
  * (A usage sketch for op authors is appended after the patch)
- Update generic reduction, which carries over padding along N and C
  * pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py::test_min_max_for_dim_hw
  * pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py::test_sharded_reduce_h
- Update transpose op, which transposes and maintains padding
  * pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py
- Update slice op, which uses optional output tensors
  * pytest tests/ttnn/unit_tests/operations/test_slice.py::test_slice_output_tensor_tile
---
 ttnn/cpp/ttnn/operation.hpp                   | 83 ++++++++++++++++---
 .../data_movement/slice/device/slice_op.cpp   | 42 +++-------
 .../data_movement/slice/device/slice_op.hpp   |  3 +-
 .../transpose/device/transpose_op.cpp         | 69 ++++++---
 .../transpose/device/transpose_op.hpp         |  3 +-
 .../reduction/generic/device/reduce_op.cpp    | 47 +++++------
 .../reduction/generic/device/reduce_op.hpp    |  3 +-
 ttnn/cpp/ttnn/run_operation.cpp               | 35 +++++---
 ttnn/cpp/ttnn/run_operation.hpp               |  5 ++
 ttnn/cpp/ttnn/tensor/tensor.hpp               |  1 +
 10 files changed, 162 insertions(+), 129 deletions(-)

diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp
index e20faf18fdf..267571da734 100644
--- a/ttnn/cpp/ttnn/operation.hpp
+++ b/ttnn/cpp/ttnn/operation.hpp
@@ -287,6 +287,22 @@ constexpr bool implements_validate_with_output_tensors_and_optional_input_tensor
         const OptionalTensors&>;  // optional output_tensors
 }
 
+template <class T, class... Args>
+using has_compute_output_shapes_t = decltype(std::declval<T>().compute_output_shapes(std::declval<Args>()...));
+
+template <class T>
+constexpr bool implements_compute_output_shapes() {
+    return std::experimental::is_detected_v<has_compute_output_shapes_t, T, const Tensors&>;
+}
+
+template <class T, class... Args>
+using has_compute_output_specs_t = decltype(std::declval<T>().compute_output_specs(std::declval<Args>()...));
+
+template <class T>
+constexpr bool implements_compute_output_specs() {
+    return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&>;
+}
+
 template <class T, class... Args>
 using has_create_output_tensors_t = decltype(std::declval<T>().create_output_tensors(std::declval<Args>()...));
 
@@ -378,13 +394,41 @@ constexpr bool implements_get_parallelization_strategy() {
     return std::experimental::is_detected_v<has_get_parallelization_strategy_t, T, const Tensors&>;
 }
 
+template <typename ConcreteOperation>
+auto default_create_output_tensors(
+    const ConcreteOperation& operation,
+    const Tensors& input_tensors,
+    const OptionalTensors& optional_output_tensors) -> ProgramOutputTensors<ConcreteOperation> {
+    using OutputTensors = ProgramOutputTensors<ConcreteOperation>;
+    OutputTensors output_tensors;
+
+    if (!optional_output_tensors.empty() and optional_output_tensors[0].has_value()) {
+        output_tensors.reserve(optional_output_tensors.size());
+        for (const auto& optional_output_tensor : optional_output_tensors) {
+            TT_FATAL(optional_output_tensor.has_value(), "If using optional output tensors, all output tensors must have a value");
+            output_tensors.emplace_back(optional_output_tensor.value());
+        }
+        return output_tensors;
+    }
+    const auto& device = input_tensors.at(0).device();
+    const auto& output_specs = operation.compute_output_specs(input_tensors);
+    output_tensors.reserve(output_specs.size());
+    for (const auto& [output_shape, output_layout] : output_specs) {
+        output_tensors.emplace_back(create_device_tensor(
+            output_shape,
+            output_layout,
+            device));
+    }
+    return output_tensors;
+}
+
 }  // namespace detail
 
 template <typename OutputTensorsT = Tensors>
 struct DeviceOperation final {
     using storage_t = std::array;
     using OutputTensors = OutputTensorsT;
-    using ComputedShapes = std::variant<std::vector<tt::tt_metal::LegacyShape>, std::vector<ttnn::SimpleShape>>;
+    using ComputedShapes = std::variant<std::vector<tt::tt_metal::LegacyShape>, std::vector<ttnn::SimpleShape>, std::vector<ttnn::TensorSpec>>;
 
     inline const std::string get_type_name() const { return this->get_type_name_impl_(this->type_erased_storage); }
 
@@ -396,6 +440,7 @@ struct DeviceOperation final {
             this->type_erased_storage, input_tensors, optional_input_tensors, optional_output_tensors);
     }
 
+    // TODO: Rename into compute_output_specs in later PR
     inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const {
         return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors);
     }
@@ -502,14 +547,6 @@ struct DeviceOperation final {
                 static_assert(
                     tt::stl::concepts::always_false_v<T>,
                     "You cannot implement both validate and validate_with_output_tensors");
-            } else if constexpr (
-                (detail::implements_validate_with_output_tensors<T>() or
-                 detail::implements_validate_with_output_tensors_and_optional_input_tensors<T>()) and
-                not detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
-                static_assert(
-                    tt::stl::concepts::always_false_v<T>,
-                    "Operation doesn't implement create_output_tensors with ant optional output tensors argument "
-                    "when using validate_with_output_tensors");
             } else if constexpr (detail::implements_validate<T>() and not detail::implements_create_program<T>()) {
                 static_assert(
                     tt::stl::concepts::always_false_v<T>,
@@ -547,7 +584,19 @@ struct DeviceOperation final {
         compute_output_shapes_impl_{
             [](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes {
                 const auto& operation = *reinterpret_cast<const T*>(&storage);
-                return operation.compute_output_shapes(input_tensors);
+                if constexpr (detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
+                    static_assert(
+                        tt::stl::concepts::always_false_v<T>,
+                        "Operation cannot implement both compute_output_shapes and compute_output_specs");
+                } else if constexpr (detail::implements_compute_output_shapes<T>()) {
+                    return operation.compute_output_shapes(input_tensors);
+                } else if constexpr (detail::implements_compute_output_specs<T>()) {
+                    return operation.compute_output_specs(input_tensors);
+                } else {
+                    static_assert(
+                        tt::stl::concepts::always_false_v<T>,
+                        "Operation must implement either compute_output_shapes or compute_output_specs");
+                }
             }},
         create_output_tensors_impl_{
             [](const storage_t& storage,
                const Tensors& input_tensors,
                const OptionalTensors& output_tensors) -> const OutputTensors {
                 const auto& operation = *reinterpret_cast<const T*>(&storage);
                 if constexpr (detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
+                    static_assert(
+                        detail::implements_compute_output_shapes<T>(),
+                        "Operation must implement compute_output_shapes if it implements create_output_tensors");
                     return operation.create_output_tensors(input_tensors, output_tensors);
-                } else {
+                } else if constexpr (detail::implements_create_output_tensors<T>()) {
+                    static_assert(
+                        detail::implements_compute_output_shapes<T>(),
+                        "Operation must implement compute_output_shapes if it implements create_output_tensors");
                     return operation.create_output_tensors(input_tensors);
+                } else if constexpr (detail::implements_compute_output_specs<T>()) {
+                    return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
+
} else { + static_assert( + tt::stl::concepts::always_false_v, + "Operation must implement either create_output_tensors or compute_output_specs"); } }}, create_program_impl_{ diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp index 386a76fef93..8ded07491f0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp @@ -82,11 +82,11 @@ void SliceDeviceOperation::validate_with_output_tensors( TT_FATAL(this->slice_start[i] <= this->slice_end[i], "Error"); } if(!output_tensors.empty() && output_tensors[0].has_value()){ - const auto output_shape_required = this->compute_output_shapes(input_tensors)[0]; + const auto output_shape_required = std::get<0>(this->compute_output_specs(input_tensors)[0]); const auto& out_tensor = output_tensors[0].value(); - TT_FATAL(out_tensor.get_legacy_shape() == output_shape_required, "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape()); + TT_FATAL(out_tensor.get_padded_shape() == output_shape_required, "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_padded_shape()); } - auto output_tensor_shape = this->compute_output_shapes(input_tensors)[0]; + auto output_tensor_shape = std::get<0>(this->compute_output_specs(input_tensors)[0]); if (has_step) { // if all ones modify before passing in to function TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Strided slice is only supported for row major layout"); TT_FATAL(!input_tensor_a.is_sharded(), "Strided slice is not supported for sharded tensor"); @@ -117,40 +117,18 @@ void SliceDeviceOperation::validate_with_output_tensors( } } -std::vector SliceDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { - SmallVector out_shape; - auto rank = input_tensors[0].get_legacy_shape().rank(); - out_shape.reserve(rank); +std::vector SliceDeviceOperation::compute_output_specs(const std::vector &input_tensors) const { + const auto& input_tensor = input_tensors[0]; + SmallVector out_shape(input_tensor.get_logical_shape().rank()); auto output_dim_i = [this] (size_t i) { return (this->slice_end[i] - this->slice_start[i] + this->step[i] - 1) / this->step[i]; }; - for (uint32_t i = 0; i < rank; i++) { - out_shape.push_back(output_dim_i(i)); - } - tt::tt_metal::LegacyShape output_tensor_shape(out_shape); - return {output_tensor_shape}; -} - -std::vector SliceDeviceOperation::create_output_tensors( - const std::vector &input_tensors, const std::vector> &output_tensors) const { - if (!output_tensors.empty() && output_tensors[0].has_value()) { - return {output_tensors[0].value()}; - } - const auto &input_tensor_a = input_tensors.at(0); - const auto shapes = compute_output_shapes(input_tensors); - - if (input_tensor_a.is_sharded()) { - return {create_device_tensor( - shapes[0], - input_tensor_a.get_dtype(), - input_tensor_a.get_layout(), - input_tensor_a.device(), - this->output_mem_config)}; - } else { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), this->output_mem_config); + for (uint32_t i = 0; i < out_shape.size(); i++) { + out_shape[i] = output_dim_i(i); } + ttnn::SimpleShape output_tensor_shape(std::move(out_shape)); + return {ttnn::TensorSpec(output_tensor_shape, 
tt::tt_metal::TensorLayout(input_tensor.get_dtype(), PageConfig(input_tensor.get_layout()), this->output_mem_config))}; } operation::ProgramWithCallbacks SliceDeviceOperation::create_program( diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp index 5fdd6922cc9..a663db54a45 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp @@ -22,8 +22,7 @@ struct SliceDeviceOperation { void validate_with_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp index 2c420128583..6dad529d179 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp @@ -77,72 +77,57 @@ void Transpose::validate(const std::vector &input_tensors) const { } -std::vector Transpose::compute_output_shapes(const std::vector &input_tensors) const { +std::vector Transpose::compute_output_specs(const std::vector &input_tensors) const { const auto& input_tensor = input_tensors.at(0); - auto out_shape = input_tensor.get_legacy_shape(); - auto padding = out_shape.padding(); + + // TODO: Remove usage of input/output padded shape + // - Get output alignment from input alignment and output dtype, layout, mem_config + // - Get shard spec from output strides (logical shape + alignment)? 
+ auto output_shape = input_tensor.get_logical_shape(); + auto output_padded_shape = input_tensor.get_padded_shape(); + switch (this->dim){ case TransposeOpDim::CN: - std::swap(out_shape[0], out_shape[1]); - std::swap(padding[0], padding[1]); + std::swap(output_shape[0], output_shape[1]); + std::swap(output_padded_shape[0], output_padded_shape[1]); break; case TransposeOpDim::HC: - std::swap(out_shape[1], out_shape[2]); - std::swap(padding[1], padding[2]); + std::swap(output_shape[1], output_shape[2]); + std::swap(output_padded_shape[1], output_padded_shape[2]); break; case TransposeOpDim::WH: - std::swap(out_shape[2], out_shape[3]); - std::swap(padding[2], padding[3]); + std::swap(output_shape[2], output_shape[3]); + std::swap(output_padded_shape[2], output_padded_shape[3]); break; case TransposeOpDim::NH: - std::swap(out_shape[0], out_shape[2]); - std::swap(padding[0], padding[2]); + std::swap(output_shape[0], output_shape[2]); + std::swap(output_padded_shape[0], output_padded_shape[2]); break; case TransposeOpDim::NW: - std::swap(out_shape[0], out_shape[3]); - std::swap(padding[0], padding[3]); + std::swap(output_shape[0], output_shape[3]); + std::swap(output_padded_shape[0], output_padded_shape[3]); break; case TransposeOpDim::CW: - std::swap(out_shape[1], out_shape[3]); - std::swap(padding[1], padding[3]); + std::swap(output_shape[1], output_shape[3]); + std::swap(output_padded_shape[1], output_padded_shape[3]); break; } - return {tt::tt_metal::LegacyShape(out_shape, padding)}; -} - -std::vector Transpose::create_output_tensors(const std::vector &input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - // This is only for WH + auto output_mem_config = this->output_mem_config; if (this->output_mem_config.is_sharded()) { if (this->dim == TransposeOpDim::WH) { + const auto& input_padded_shape = input_tensor.get_padded_shape(); ShardSpec shard_spec = input_tensor.shard_spec().value(); - shard_spec.shape[0] = shard_spec.shape[0] / input_tensor.get_legacy_shape()[-2] * input_tensor.get_legacy_shape()[-1]; - shard_spec.shape[1] = input_tensor.get_legacy_shape()[-2]; - const auto output_shape = this->compute_output_shapes(input_tensors)[0]; - auto mem_config = this->output_mem_config; - mem_config.shard_spec = shard_spec; - return {create_device_tensor( - output_shape, - input_tensor.get_dtype(), - input_tensor.get_layout(), - input_tensor.device(), - mem_config)}; + shard_spec.shape[0] = shard_spec.shape[0] / input_padded_shape[-2] * input_padded_shape[-1]; + shard_spec.shape[1] = input_padded_shape[-2]; + output_mem_config.shard_spec = shard_spec; } else if (this->dim == TransposeOpDim::HC) { - const auto output_shape = this->compute_output_shapes(input_tensors)[0]; - auto mem_config = this->output_mem_config; - mem_config.shard_spec = input_tensor.shard_spec().value(); - return {create_device_tensor( - output_shape, - input_tensor.get_dtype(), - input_tensor.get_layout(), - input_tensor.device(), - mem_config)}; + output_mem_config.shard_spec = input_tensor.shard_spec().value(); } else { TT_ASSERT(false, "Unsupported sharding"); } } - return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config); + return {ttnn::TensorSpec(output_shape, TensorLayout::fromLegacyPaddedShape(input_tensor.get_dtype(), PageConfig(input_tensor.get_layout()), output_mem_config, ttnn::Shape(output_shape.view(), output_padded_shape.view())))}; } operation::ProgramWithCallbacks Transpose::create_program(const 
std::vector& input_tensors, std::vector &output_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.hpp index 776509a6c80..44979d114e0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.hpp @@ -23,8 +23,7 @@ struct Transpose { const MemoryConfig output_mem_config; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; + std::vector compute_output_specs(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; TransposeOpParallelizationStrategy get_parallelization_strategy(const std::vector &input_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.cpp index 0278c3da6c3..2fc0719ec97 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.cpp @@ -57,46 +57,39 @@ void Reduce::validate(const std::vector& input_tensors) const { } } -std::vector Reduce::compute_output_shapes(const std::vector& input_tensors) const { +std::vector Reduce::compute_output_specs(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); - auto output_shape = input_tensor.get_legacy_shape(); - auto padding = output_shape.padding(); + // TODO: Remove usage of input/output padded shape + // - Get output alignment from input alignment and output dtype, layout, mem_config + // - Get shard spec from output strides (logical shape + alignment)? + auto output_shape = input_tensor.get_logical_shape(); + auto output_padded_shape = input_tensor.get_padded_shape(); switch (this->dim) { case ReduceOpDim::H: - output_shape[2] = TILE_HEIGHT; - padding[2] = Padding::PadDimension{0, 31}; + output_shape[2] = 1; + output_padded_shape[2] = TILE_HEIGHT; break; case ReduceOpDim::W: - output_shape[3] = TILE_WIDTH; - padding[3] = Padding::PadDimension{0, 31}; + output_shape[3] = 1; + output_padded_shape[3] = TILE_WIDTH; break; case ReduceOpDim::HW: - output_shape[2] = TILE_HEIGHT; - output_shape[3] = TILE_WIDTH; - padding[2] = Padding::PadDimension{0, 31}; - padding[3] = Padding::PadDimension{0, 31}; + output_shape[2] = 1; + output_shape[3] = 1; + output_padded_shape[2] = TILE_HEIGHT; + output_padded_shape[3] = TILE_WIDTH; break; } - return {tt::tt_metal::LegacyShape(output_shape, padding)}; -} -std::vector Reduce::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - if (this->output_mem_config.is_sharded()) { - auto output_shape = this->compute_output_shapes(input_tensors).at(0); + auto output_mem_config = this->output_mem_config; + if (output_mem_config.is_sharded()) { auto shard_spec = input_tensor.shard_spec().value(); // TODO: This will segfault if input is not sharded... 
- // TODO: For reduction along H, the shard height is always 1 padded up to 32 (tile height) - // Need to clean this up to have new layout account for sharding with padding - shard_spec.shape[0] = tt_metal::compute_volume(output_shape) / output_shape[-1]; - auto mem_config = this->output_mem_config; - mem_config.shard_spec = shard_spec; - return { - create_device_tensor(output_shape, this->output_dtype, Layout::TILE, input_tensor.device(), mem_config)}; - } else { - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config); + shard_spec.shape[0] = output_padded_shape.volume() / output_padded_shape[-1]; + output_mem_config.shard_spec = shard_spec; } + + return {ttnn::TensorSpec(output_shape, TensorLayout::fromLegacyPaddedShape(this->output_dtype, PageConfig(Layout::TILE), output_mem_config, ttnn::Shape(output_shape.view(), output_padded_shape.view())))}; } operation::ProgramWithCallbacks Reduce::create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.hpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.hpp index 2f7dfbdabd6..c53598910a1 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/reduce_op.hpp @@ -30,8 +30,7 @@ struct Reduce { ttnn::DeviceComputeKernelConfig compute_kernel_config; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; + std::vector compute_output_specs(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; ReduceOpParallelizationStrategy get_parallelization_strategy(const std::vector& input_tensors) const; }; diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp index 40e0f3a44ab..3d03485a101 100644 --- a/ttnn/cpp/ttnn/run_operation.cpp +++ b/ttnn/cpp/ttnn/run_operation.cpp @@ -298,18 +298,31 @@ template OptionalTensors run_without_autoformat( uint8_t cq_id); std::vector extract_legacy_shapes( - const std::variant, std::vector>&& shapes, const std::function& layout_provider) { + const std::variant, std::vector, std::vector>&& shapes, const std::function& layout_provider, const bool use_tensor_layout_from_tensor_spec) { if (std::holds_alternative>(shapes)) { return std::get>(std::move(shapes)); + } else if (std::holds_alternative>(shapes)) { + const auto& simple_shapes = std::get>(shapes); + std::vector legacy_shapes; + legacy_shapes.reserve(simple_shapes.size()); + for (size_t idx = 0; idx < simple_shapes.size(); idx++) { + TensorLayout tensor_layout = layout_provider(idx); + legacy_shapes.emplace_back(simple_shapes[idx].view(), tensor_layout.compute_padded_shape(simple_shapes[idx]).view()); + } + return legacy_shapes; + } else if (std::holds_alternative>(shapes)) { + const auto& tensor_specs = std::get>(shapes); + std::vector legacy_shapes; + legacy_shapes.reserve(tensor_specs.size()); + for (size_t idx = 0; idx < tensor_specs.size(); idx++) { + const auto& [simple_shape, output_layout] = tensor_specs[idx]; + TensorLayout tensor_layout = use_tensor_layout_from_tensor_spec ? 
output_layout : layout_provider(idx); + legacy_shapes.emplace_back(simple_shape.view(), tensor_layout.compute_padded_shape(simple_shape).view()); + } + return legacy_shapes; + } else { + TT_THROW("extract_legacy_shapes only supports LegacyShape, SimpleShape, or TensorSpec"); } - const auto& simple_shapes = std::get>(shapes); - std::vector legacy_shapes; - legacy_shapes.reserve(simple_shapes.size()); - for (size_t idx = 0; idx < simple_shapes.size(); idx++) { - TensorLayout tensor_layout = layout_provider(idx); - legacy_shapes.emplace_back(simple_shapes[idx].view(), tensor_layout.compute_padded_shape(simple_shapes[idx]).view()); - } - return legacy_shapes; } // To be deprecated/removed in favor of new implementation where ops specifically request how to format inputs/outputss @@ -361,7 +374,7 @@ Tensors run_with_autoformat( auto output_shapes = extract_legacy_shapes(operation.compute_output_shapes(input_tensors), [&](size_t idx) { auto tensor = output_tensors[idx]; return TensorLayout(tensor.get_dtype(), Layout::TILE, tensor.memory_config()); - }); + }, /*use_tensor_layout_from_tensor_spec=*/ true); TT_ASSERT(output_tensors.size() == output_shapes.size()); @@ -424,7 +437,7 @@ Tensors run_with_autoformat( auto output_shapes = extract_legacy_shapes(operation.compute_output_shapes(input_tensors), [&](size_t idx) { auto tensor = output_tensors[idx]; return TensorLayout(tensor.get_dtype(), output_layouts[idx], tensor.memory_config()); - }); + }, /*use_tensor_layout_from_tensor_spec=*/ false); TT_ASSERT(output_tensors.size() == output_shapes.size()); TT_ASSERT(output_tensors.size() == output_layouts.size()); diff --git a/ttnn/cpp/ttnn/run_operation.hpp b/ttnn/cpp/ttnn/run_operation.hpp index 3286146dfd8..1305a7e4da5 100644 --- a/ttnn/cpp/ttnn/run_operation.hpp +++ b/ttnn/cpp/ttnn/run_operation.hpp @@ -17,6 +17,11 @@ namespace tt::tt_metal { namespace operation { using ttnn::operations::experimental::auto_format::FormatParams; + +// TODO: create_output_tensors should become a fully manual path with no dependency on infra +// - Pass output shapes directly +// - Move default values for output_dtype and output_mem_config inside ops +// - This function becomes just a regular helper function template auto generic_create_output_tensors( const ConcreteOperation& operation, diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 2aa903fad62..8291cba455d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -368,5 +368,6 @@ bool validate_worker_modes(const std::vector &workers); namespace ttnn { using Tensor = tt::tt_metal::Tensor; +using TensorSpec = std::pair; } // namespace ttnn
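
Usage sketch (not part of the patch): a minimal example of how a device op might adopt the new compute_output_specs path instead of compute_output_shapes + create_output_tensors, as described in the commit message. The op name ExampleScaleOp and its members are hypothetical; the sketch mirrors the slice/transpose changes above and assumes ttnn::TensorSpec pairs a logical shape with a tt::tt_metal::TensorLayout, as the construction sites in this patch suggest.

```cpp
// Hypothetical op (illustrative only): returns a full output spec and lets the
// operation infra allocate the output tensor via detail::default_create_output_tensors.
#include "ttnn/operation.hpp"
#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::example {

struct ExampleScaleOp {
    const tt::tt_metal::MemoryConfig output_mem_config;

    void validate(const std::vector<Tensor>& input_tensors) const {
        TT_FATAL(input_tensors.size() == 1, "ExampleScaleOp expects exactly one input tensor");
    }

    // One spec per output: logical shape plus layout (dtype, page config, memory config).
    // Because the op implements compute_output_specs and no create_output_tensors overload,
    // the infra either reuses user-provided optional output tensors or allocates device
    // tensors from these specs.
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const {
        const auto& input = input_tensors.at(0);
        return {ttnn::TensorSpec(
            input.get_logical_shape(),
            tt::tt_metal::TensorLayout(
                input.get_dtype(), tt::tt_metal::PageConfig(input.get_layout()), this->output_mem_config))};
    }

    tt::tt_metal::operation::ProgramWithCallbacks create_program(
        const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
};

}  // namespace ttnn::operations::example
```

Per the dispatch logic added to create_output_tensors_impl_, such an op falls through to detail::default_create_output_tensors, which wraps the caller's optional output tensors when all are provided, and otherwise creates device tensors from the returned specs.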