Test cases for matmul, mean, and sqrt ops #106 #114 #117.
Some ops are supported e2e, while others are blocked for various reasons. Follow-up issues will address the failing cases.
nvukobratTT committed Aug 22, 2024
1 parent 6a0f808 commit 03b4c99
Showing 2 changed files with 85 additions and 1 deletion.
1 change: 1 addition & 0 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -504,6 +504,7 @@ class MLIRGenerator
lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
// lowering_handler_map["sqrt"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SqrtOp>;
}
};
}
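
For context, lowering_handler_map is a plain dispatch table from TTForge op names to emitter member functions; the commented-out entry above is the blocked sqrt lowering. A minimal Python sketch of the same pattern (Node, emit_op, and OP_HANDLERS are illustrative names, not pybuda APIs):

from dataclasses import dataclass

@dataclass
class Node:
    op_name: str

def emit_op(mlir_op: str, node: Node) -> str:
    # Stand-in for the real MLIR op builder (emit_mlir_ttforge_op in the C++ code).
    return f"{mlir_op}({node.op_name})"

# Name -> emitter table, mirroring lowering_handler_map.
OP_HANDLERS = {
    "softmax":    lambda n: emit_op("ttir.SoftmaxOp", n),
    "reduce_sum": lambda n: emit_op("ttir.SumOp", n),
    "reduce_avg": lambda n: emit_op("ttir.MeanOp", n),
    # "sqrt" -> "ttir.SqrtOp" once the MLIR uplift lands.
}

def lower(node: Node) -> str:
    # Look up the handler by op name and emit the corresponding TTIR op.
    handler = OP_HANDLERS.get(node.op_name)
    if handler is None:
        raise NotImplementedError(f"no lowering handler for '{node.op_name}'")
    return handler(node)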
85 changes: 84 additions & 1 deletion pybuda/test/mlir/test_ops.py
@@ -176,4 +176,87 @@ def forward(self, a):
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
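
compare_with_golden_pcc presumably gates on the Pearson correlation coefficient (PCC) between the framework output and the compiled output. A minimal sketch of what such a check might look like, assuming both arguments are torch tensors (not the actual pybuda implementation):

import torch

def pcc(golden: torch.Tensor, calculated: torch.Tensor) -> float:
    # Pearson correlation over the flattened, mean-centered tensors.
    g = golden.flatten().float() - golden.float().mean()
    c = calculated.flatten().float() - calculated.float().mean()
    return float((g @ c) / (g.norm() * c.norm() + 1e-12))

# A 0.99 threshold then reads: assert pcc(fw_out, co_out[0]) >= 0.99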


@pytest.mark.parametrize("batch_size", [1, 7, 32])
@pytest.mark.parametrize("outer_dim_x", [7, 32, 41, 64])
@pytest.mark.parametrize("outer_dim_y", [7, 32, 41, 64])
@pytest.mark.parametrize("inner_dim", [1, 7, 32, 41, 64])
def test_matmul(batch_size, outer_dim_x, outer_dim_y, inner_dim):
class Matmul(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.matmul(x, y)

inputs = [
torch.rand(batch_size, outer_dim_x, inner_dim),
torch.rand(batch_size, inner_dim, outer_dim_y),
]

framework_model = Matmul()
fw_out = framework_model(*inputs)

compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
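
As a sanity check on the parametrization: torch.matmul on batched inputs of shape (batch_size, outer_dim_x, inner_dim) and (batch_size, inner_dim, outer_dim_y) produces (batch_size, outer_dim_x, outer_dim_y), e.g.:

out = torch.matmul(torch.rand(7, 32, 41), torch.rand(7, 41, 64))
assert out.shape == (7, 32, 64)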


@pytest.mark.parametrize("x_shape", [7, 32, 41])
@pytest.mark.parametrize("y_shape", [7, 32, 41])
@pytest.mark.parametrize("dim", [1, 2])
def test_mean(x_shape, y_shape, dim):
if dim == 1:
pytest.skip("FFE: Unsupported squeeze operation")
if dim == 2:
# Note: Some tests pass when run as a group but fail when run individually
pytest.skip("TTNN: Tensor layout bugs")

class Mean(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.mean(x, dim=dim)

inputs = [
torch.rand(1, x_shape, y_shape),
]

framework_model = Mean()
fw_out = framework_model(*inputs)

compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
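
The dim=1 skip reason ("FFE: Unsupported squeeze operation") likely stems from torch.mean dropping the reduced dimension by default (keepdim=False), so the lowered graph needs a squeeze. For the (1, x_shape, y_shape) inputs used here:

x = torch.rand(1, 7, 32)
assert torch.mean(x, dim=1).shape == (1, 32)                   # reduced dim dropped -> needs squeeze
assert torch.mean(x, dim=1, keepdim=True).shape == (1, 1, 32)  # reduced dim kept as size 1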


@pytest.mark.parametrize("x_shape", [7, 32, 41])
@pytest.mark.parametrize("y_shape", [7, 32, 41])
def test_sqrt(x_shape, y_shape):
pytest.skip("FFE: Requires MLIR uplift")
class Sqrt(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sqrt(x)

inputs = [
torch.rand(1, x_shape, y_shape),
]

framework_model = Sqrt()
fw_out = framework_model(*inputs)

compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
