Skip to content

Commit

Permalink
#0: Finalize transition of eltwise and matmul operations to SimpleShape (#14777)
Browse files Browse the repository at this point in the history

* #0: Finalize transition of eltwise and matmul to SimpleShape

* #0: Shape fixup
  • Loading branch information
sminakov-tt authored Nov 7, 2024
1 parent 0b5ca69 commit 63a72ad
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(

shape_return_value_t UnaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
return {tensor_args.input.get_shape()};
return {tensor_args.input.get_logical_shape()};
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
Expand All @@ -158,13 +158,12 @@ tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
return tensor_args.preallocated_output.value();
}

const auto output_shape = compute_output_shapes(args, tensor_args);

auto output_layout = Layout::TILE;
if (args.output_memory_config.is_sharded()) {
output_layout = tensor_args.input.get_layout();
}

const auto output_shape = tensor_args.input.shape();
return create_device_tensor(
output_shape, args.output_dtype, output_layout, tensor_args.input.device(), args.output_memory_config);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ struct tensor_args_t {

using tensor_return_value_t = Tensor;

using shape_return_value_t = ttnn::Shape;
using shape_return_value_t = ttnn::SimpleShape;

} // namespace ttnn::operations::unary
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::vector <ttnn::Tensor> all_gather_matmul(


/* Matmul setup */
bool user_run_batched = ttnn::operations::matmul::detail::is_input_batched(weight_tensor.get_shape());
bool user_run_batched = ttnn::operations::matmul::detail::is_input_batched(weight_tensor.get_logical_shape());
std::optional<CoreCoord> user_core_coord;
if (core_grid.has_value()) {
user_core_coord = CoreCoord(core_grid->x, core_grid->y);
Expand Down
3 changes: 0 additions & 3 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ namespace operations {
namespace matmul {

using ttnn::operations::unary::UnaryWithParam;
using tt::tt_metal::LegacyShape;

/*
* GENERAL MATMUL AND BMM
Expand Down Expand Up @@ -170,8 +169,6 @@ struct Matmul {
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes_dram_sharded(
const std::vector<Tensor> &input_tensors, uint32_t N_unpadded) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors,
Expand Down
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/operations/matmul/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace matmul {

namespace detail {

bool is_input_batched(const ttnn::Shape& shape) {
bool is_input_batched(const ttnn::SimpleShape& shape) {
auto is_batched = false;
for (auto i = 0; i < shape.rank() - 2; ++i) {
if (shape[i] > 1) {
Expand Down Expand Up @@ -109,7 +109,7 @@ Tensor MatmulOperation::invoke(
if (core_grid.has_value()) {
user_core_coord = CoreCoord(core_grid->x, core_grid->y);
}
bool user_run_batched = detail::is_input_batched(input_tensor_b.get_shape());
bool user_run_batched = detail::is_input_batched(input_tensor_b.get_logical_shape());
return bound_matmul(
input_tensor_a,
input_tensor_b,
Expand Down Expand Up @@ -147,7 +147,7 @@ Tensor LinearOperation::invoke(
if (core_grid.has_value()) {
user_core_coord = CoreCoord(core_grid->x, core_grid->y);
}
bool b_is_batched = detail::is_input_batched(input_tensor_b.get_shape());
bool b_is_batched = detail::is_input_batched(input_tensor_b.get_logical_shape());
TT_FATAL(!(b_is_batched && bias.has_value()), "Batched input not supported when bias exists (linear operation).");

return bound_matmul(
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/matmul/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace matmul {

namespace detail {

bool is_input_batched(const ttnn::Shape& shape);
bool is_input_batched(const ttnn::SimpleShape& logical_shape);

} // namespace detail

Expand Down

0 comments on commit 63a72ad

Please sign in to comment.