diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index c9ccbfa2b..b65f58a4d 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -439,17 +439,22 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
   }];
 }
 
-def TTNN_EmbeddingOp : TTNN_Op<"embedding"> {
+def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
   let summary = "Embedding op.";
   let description = [{
     Embedding operation.
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
                        AnyRankedTensor:$weight);
   let results = (outs AnyRankedTensor:$result);
 
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
   let hasVerifier = 1;
 }
 
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 598d97b79..217371c58 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -121,7 +121,7 @@ table ReductionOp {
 table EmbeddingOp {
   input: tt.target.TensorRef;
   weight: tt.target.TensorRef;
-  output: tt.target.TensorRef;
+  out: tt.target.TensorRef;
 }
 
 table SoftmaxOp {
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 02159e1b0..c0864c8d8 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -212,7 +212,7 @@ class ToLayoutOpConversionPattern
   bool shouldForceRowMajor(ttir::ToLayoutOp op) const {
     for (mlir::Operation *user : op.getResult().getUsers()) {
       if (isa<ttir::Conv2dOp>(user) || isa<ttir::MaxPool2dOp>(user) ||
-          isa<ttir::SliceOp>(user)) {
+          isa<ttir::SliceOp>(user) || isa<ttir::EmbeddingOp>(user)) {
         return true;
       }
     }
@@ -317,7 +317,7 @@ class EmbeddingOpConversionPattern
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getWeight());
+        adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight());
 
     return success();
   }
diff --git a/runtime/lib/ttnn/operations/embedding/embedding.cpp b/runtime/lib/ttnn/operations/embedding/embedding.cpp
index 21742e58d..47b27ca9a 100644
--- a/runtime/lib/ttnn/operations/embedding/embedding.cpp
+++ b/runtime/lib/ttnn/operations/embedding/embedding.cpp
@@ -18,14 +18,16 @@ void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) {
 
   // default params for embedding op
   std::optional<int> padToken = std::nullopt;
-  ::tt::tt_metal::Layout layout = ::ttnn::ROW_MAJOR_LAYOUT;
+  ::tt::tt_metal::Layout layout = utils::isTilized(op->out())
+                                      ? ::ttnn::TILE_LAYOUT
+                                      : ::ttnn::ROW_MAJOR_LAYOUT;
   auto embeddingsType = ::ttnn::operations::embedding::EmbeddingsType::GENERIC;
-  ::ttnn::DataType outputDataType = utils::getDataType(op->output());
+  ::ttnn::DataType outputDataType = utils::getDataType(op->out());
   ::ttnn::MemoryConfig outputMemoryConfig =
-      utils::createMemoryConfig(op->output());
+      utils::createMemoryConfig(op->out());
   ::ttnn::Tensor out =
       ::ttnn::embedding(input, weight, padToken, layout, embeddingsType,
                         outputDataType, outputMemoryConfig);
-  tensorPool.insert_or_assign(op->output()->global_id(), out);
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
 } // namespace tt::runtime::ttnn::operations::embedding
diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp
index 1f353a9bc..9cf2ee2b7 100644
--- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp
+++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp
@@ -23,6 +23,12 @@ bool isOnDevice(const ::ttnn::Tensor &tensor) {
   return tensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE;
 }
 
+bool isTilized(const ::tt::target::TensorRef *tensorRef) {
+  const ::tt::target::Dim2d *tileShape =
+      tensorRef->desc()->layout()->memory_desc()->tile_shape();
+  return tileShape->x() == 32 and tileShape->y() == 32;
+}
+
 ::tt::target::MemorySpace
 getMemorySpace(const ::tt::target::TensorRef *tensorRef) {
   return tensorRef->desc()->layout()->memory_desc()->memory_space();
diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h
index e368b2d72..b0aac074d 100644
--- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h
+++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h
@@ -17,6 +17,8 @@ bool isOnHost(const ::ttnn::Tensor &tensor);
 
 bool isOnDevice(const ::ttnn::Tensor &tensor);
 
+bool isTilized(const ::tt::target::TensorRef *tensorRef);
+
 bool inSystemMemory(const ::tt::target::TensorRef *tensorRef);
 
 ::tt::target::MemorySpace
diff --git a/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir b/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
index 16db98b82..45318423b 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
@@ -1,10 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
-  func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
-    %0 = tensor.empty() : tensor<32x128xf32>
+  func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> {
+    %0 = tensor.empty() : tensor<32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
-    return %1 : tensor<32x128xf32>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
+    return %1 : tensor<32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir b/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
index e5a99b322..1d2813668 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
@@ -1,12 +1,11 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn %s | FileCheck %s
-// UNSUPPORTED: true
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
-  func.func @forward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<1x32x128xf32> {
+  func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<1x32x128xf32>
+    %0 = tensor.empty() : tensor<1x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
-    return %1 : tensor<1x32x128xf32>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
+    return %1 : tensor<1x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
index de82b14ea..e5fb1421c 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
@@ -1,10 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
-  func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
-    %0 = tensor.empty() : tensor<32x32x128xf32>
+  func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
+    %0 = tensor.empty() : tensor<32x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
-    return %1 : tensor<32x32x128xf32>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
+    return %1 : tensor<32x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir b/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir
index b6b3501c9..9640d91e2 100644
--- a/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir
+++ b/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir
@@ -1,11 +1,11 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint<any_device>
+#any_device_tile = #tt.operand_constraint<any_device_tile>
 module attributes {} {
-  func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
+  func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> {
     // CHECK-NOT: "ttnn.empty"
-    %0 = tensor.empty() : tensor<32x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
-    return %1 : tensor<32x128xf32>
+    %0 = tensor.empty() : tensor<2x4x32x32xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
+    %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16>
+    return %1 : tensor<2x4x32x32xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir b/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir
index ea1439de8..f4850e4f8 100644
--- a/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir
+++ b/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir
@@ -1,14 +1,13 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-// UNSUPPORTED: true
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
-  func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
+  func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<32x128xf32>
+    %0 = tensor.empty() : tensor<32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
-    return %1 : tensor<32x128xf32>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
+    return %1 : tensor<32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir b/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir
index 89ee4bc1c..c26634771 100644
--- a/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir
+++ b/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir
@@ -1,14 +1,13 @@
-// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" --ttir-layout --convert-ttir-to-ttnn %s > %t.mlir
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-// UNSUPPORTED: true
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
-  func.func @forward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<1x32x128xf32> {
+  func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<1x32x128xf32>
+    %0 = tensor.empty() : tensor<1x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
-    return %1 : tensor<1x32x128xf32>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
+    return %1 : tensor<1x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir b/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir
index e0f43be8a..343bb5e76 100644
--- a/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir
+++ b/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir
@@ -1,14 +1,13 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-// UNSUPPORTED: true
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
-  func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
+  func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<32x32x128xf32>
+    %0 = tensor.empty() : tensor<32x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
-    return %1 : tensor<32x32x128xf32>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
+    return %1 : tensor<32x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir
new file mode 100644
index 000000000..343bb5e76
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir
@@ -0,0 +1,13 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    %0 = tensor.empty() : tensor<32x32x128xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
+    return %1 : tensor<32x32x128xbf16>
+  }
+}