Further push of embedding op #104. Hitting L1 issues on Metal. (#197)
nvukobratTT authored Sep 2, 2024
1 parent 831b1b5 commit 3dec39a
Showing 3 changed files with 36 additions and 5 deletions.
3 changes: 1 addition & 2 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -497,11 +497,10 @@ class MLIRGenerator
void init_lowering_handler_map()
{
    lowering_handler_map["add"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::AddOp>;
+   lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
    lowering_handler_map["matmul"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MatmulOp>;
    lowering_handler_map["multiply"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MultiplyOp>;
    lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
-   lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
    lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
-   lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
    lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReluOp>;
    lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
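For context, this table maps TTForge op names to MLIRGenerator member functions, so lowering an op becomes a single name lookup plus dispatch. A minimal Python sketch of the pattern (illustrative only; handlers and emit_op are hypothetical stand-ins, not the actual C++ API):

    # Hypothetical sketch of the name -> emitter dispatch used above.
    handlers = {
        "add": lambda op: f"ttir.add({op})",
        "embedding": lambda op: f"ttir.embedding({op})",
    }

    def emit_op(name: str, op: str) -> str:
        if name not in handlers:
            raise KeyError(f"no lowering handler for op: {name}")
        return handlers[name](op)  # dispatch to the registered emitter

A side note on the cleanup above: assigning the same key twice in a map simply overwrites the first entry, so the duplicate reduce_avg/reduce_sum registrations were harmless but dead code.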
5 changes: 2 additions & 3 deletions pybuda/test/mlir/llama/tests/test_llama_embedding.py
@@ -9,7 +9,7 @@
from pybuda.op.eval.common import compare_with_golden_pcc


-@pytest.mark.xfail(reason="Embedding op is not supported on MLIR.")
+@pytest.mark.xfail(reason="L1 allocation issue on Metal")
def test_llama_embedding():
# Load Llama 3B model and tokenizer
framework_model = load_model()
@@ -20,7 +20,7 @@ def test_llama_embedding():
inputs = [
torch.randint(0, vocab_size, (1, 12)), # Input token IDs
]

# Sanity run
golden_output = framework_model(*inputs)

@@ -33,4 +33,3 @@

# Validate results
assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
-
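Both this test and the new one below gate on compare_with_golden_pcc, which passes when the Pearson correlation coefficient (PCC) between the golden and device outputs clears the given threshold. A minimal sketch of an equivalent check (illustrative; not pybuda's actual implementation):

    import torch

    def pcc(golden: torch.Tensor, calculated: torch.Tensor) -> float:
        # Pearson correlation of the flattened, mean-centered tensors.
        g = golden.flatten().double() - golden.double().mean()
        c = calculated.flatten().double() - calculated.double().mean()
        return float((g @ c) / (g.norm() * c.norm()))

    # A 0.99 threshold tolerates small numeric drift between framework
    # and device execution while still catching real mismatches.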
33 changes: 33 additions & 0 deletions pybuda/test/mlir/test_ops.py
@@ -295,3 +295,36 @@ def forward(self, x):

co_out = [co.to("cpu") for co in co_out]
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)


# @pytest.mark.parametrize("vocab_size", [2048, 16384, 32000])
# @pytest.mark.parametrize("token_num", [1, 7, 32])
# @pytest.mark.parametrize("embedding_dim", [128, 512, 3200])
@pytest.mark.xfail(reason="L1 allocation issue on Metal")
@pytest.mark.parametrize("vocab_size", [32000])
@pytest.mark.parametrize("token_num", [12])
@pytest.mark.parametrize("embedding_dim", [3200])
def test_embedding(vocab_size, token_num, embedding_dim):
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.enable_tvm_cpu_fallback = False

class Embedding(nn.Module):
def __init__(self):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)

def forward(self, x):
return self.embedding(x)

inputs = [
torch.randint(0, vocab_size, (1, token_num)),
]

framework_model = Embedding()
fw_out = framework_model(*inputs)

compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
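For reference on why this configuration is demanding: nn.Embedding is a learned lookup table whose forward pass is a row-gather on the weight matrix, so the 32000 x 3200 table here (roughly 102M parameters) must be resident on the device, far more than fits in a single core's L1 SRAM, which lines up with the xfail reason above. A minimal sketch of the equivalent lookup:

    import torch

    vocab_size, embedding_dim, token_num = 32000, 3200, 12
    weight = torch.randn(vocab_size, embedding_dim)     # the embedding table
    ids = torch.randint(0, vocab_size, (1, token_num))  # input token IDs

    # An embedding forward pass is a plain row-gather on the weight table:
    out = weight[ids]
    assert out.shape == (1, token_num, embedding_dim)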
