Skip to content

Commit

Permalink
Fix embedding tests that were previously failing (#1128)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT authored Nov 4, 2024
1 parent 51f6356 commit 6988418
Show file tree
Hide file tree
Showing 14 changed files with 68 additions and 44 deletions.
7 changes: 6 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -439,17 +439,22 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
}];
}

def TTNN_EmbeddingOp : TTNN_Op<"embedding"> {
def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Embedding operation.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
AnyRankedTensor:$weight);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ table ReductionOp {
table EmbeddingOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
output: tt.target.TensorRef;
out: tt.target.TensorRef;
}

table SoftmaxOp {
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class ToLayoutOpConversionPattern
bool shouldForceRowMajor(ttir::ToLayoutOp op) const {
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::MaxPool2dOp>(user) ||
isa<ttir::SliceOp>(user)) {
isa<ttir::SliceOp>(user) || isa<ttir::EmbeddingOp>(user)) {
return true;
}
}
Expand Down Expand Up @@ -317,7 +317,7 @@ class EmbeddingOpConversionPattern
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getWeight());
adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight());

return success();
}
Expand Down
10 changes: 6 additions & 4 deletions runtime/lib/ttnn/operations/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) {

// default params for embedding op
std::optional<int> padToken = std::nullopt;
::tt::tt_metal::Layout layout = ::ttnn::ROW_MAJOR_LAYOUT;
::tt::tt_metal::Layout layout = utils::isTilized(op->out())
? ::ttnn::TILE_LAYOUT
: ::ttnn::ROW_MAJOR_LAYOUT;
auto embeddingsType = ::ttnn::operations::embedding::EmbeddingsType::GENERIC;
::ttnn::DataType outputDataType = utils::getDataType(op->output());
::ttnn::DataType outputDataType = utils::getDataType(op->out());
::ttnn::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->output());
utils::createMemoryConfig(op->out());
::ttnn::Tensor out =
::ttnn::embedding(input, weight, padToken, layout, embeddingsType,
outputDataType, outputMemoryConfig);
tensorPool.insert_or_assign(op->output()->global_id(), out);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::embedding
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ bool isOnDevice(const ::ttnn::Tensor &tensor) {
return tensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE;
}

// Reports whether the tensor described by `tensorRef` uses 32x32 tiles,
// judged solely by the tile shape recorded in its layout's memory descriptor.
// The descriptor chain is dereferenced without null checks.
bool isTilized(const ::tt::target::TensorRef *tensorRef) {
  const auto *memoryDesc = tensorRef->desc()->layout()->memory_desc();
  const ::tt::target::Dim2d *tile = memoryDesc->tile_shape();
  // Both dimensions must be exactly 32 for the tensor to count as tilized.
  if (tile->x() != 32) {
    return false;
  }
  return tile->y() == 32;
}

::tt::target::MemorySpace
getMemorySpace(const ::tt::target::TensorRef *tensorRef) {
return tensorRef->desc()->layout()->memory_desc()->memory_space();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ bool isOnHost(const ::ttnn::Tensor &tensor);

bool isOnDevice(const ::ttnn::Tensor &tensor);

bool isTilized(const ::tt::target::TensorRef *tensorRef);

bool inSystemMemory(const ::tt::target::TensorRef *tensorRef);

::tt::target::MemorySpace
Expand Down
8 changes: 4 additions & 4 deletions test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
%0 = tensor.empty() : tensor<32x128xf32>
func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> {
%0 = tensor.empty() : tensor<32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
return %1 : tensor<32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
return %1 : tensor<32x128xbf16>
}
}
11 changes: 5 additions & 6 deletions test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn %s | FileCheck %s
// UNSUPPORTED: true
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<1x32x128xf32> {
func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x32x128xf32>
%0 = tensor.empty() : tensor<1x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
return %1 : tensor<1x32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
return %1 : tensor<1x32x128xbf16>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
%0 = tensor.empty() : tensor<32x32x128xf32>
func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
%0 = tensor.empty() : tensor<32x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
return %1 : tensor<32x32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
return %1 : tensor<32x32x128xbf16>
}
}
12 changes: 6 additions & 6 deletions test/ttmlir/Dialect/TTNN/remove_empty_op.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> {
// CHECK-NOT: "ttnn.empty"
%0 = tensor.empty() : tensor<32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
return %1 : tensor<32x128xf32>
%0 = tensor.empty() : tensor<2x4x32x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
%1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16>
return %1 : tensor<2x4x32x32xbf16>
}
}
9 changes: 4 additions & 5 deletions test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// UNSUPPORTED: true
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x128xf32>
%0 = tensor.empty() : tensor<32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
return %1 : tensor<32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
return %1 : tensor<32x128xbf16>
}
}
11 changes: 5 additions & 6 deletions test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" --ttir-layout --convert-ttir-to-ttnn %s > %t.mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// UNSUPPORTED: true
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<1x32x128xf32> {
func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x32x128xf32>
%0 = tensor.empty() : tensor<1x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
return %1 : tensor<1x32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
return %1 : tensor<1x32x128xbf16>
}
}
9 changes: 4 additions & 5 deletions test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// UNSUPPORTED: true
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x32x128xf32>
%0 = tensor.empty() : tensor<32x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
return %1 : tensor<32x32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
return %1 : tensor<32x32x128xbf16>
}
}
13 changes: 13 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Lit test: runs a single ttir.embedding op through the TTIR-to-TTNN backend
// pipeline (using the system descriptor supplied by the test harness),
// verifies the lowered IR with FileCheck, and then translates that IR to a
// TTNN flatbuffer.
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// Operand constraint attribute shared by every operand of the op below.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
// All tensors are bf16: two function arguments (32x32 and 512x128) and a
// 32x32x128 result.
func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
// The destination tensor created by tensor.empty is expected to appear as a
// ttnn.empty op in the lowered output.
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x32x128xbf16>
// The embedding op takes two inputs and one destination operand
// (operandSegmentSizes = 2, 1) and is expected to lower to ttnn.embedding.
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
return %1 : tensor<32x32x128xbf16>
}
}

0 comments on commit 6988418

Please sign in to comment.