Skip to content

Commit

Permalink
[TVM] Decompose repeat_interleave pytorch op and add sanity test
Browse files Browse the repository at this point in the history
(cherry picked from commit a4ff19a91355338f49d068d9e64302c52c258825)
  • Loading branch information
chandrasekaranpradeep authored and vmilosevic committed Aug 13, 2024
1 parent 93793b2 commit 0f43c18
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions pybuda/test/tvm/sanity/tests_B/test_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,57 @@ def forward(self,input):
verify_tvm_compile=True,
)
)

@pytest.mark.parametrize("input_shape", ((1, 4), (1, 4, 3), (1, 2, 7, 6)))
@pytest.mark.parametrize("repeat_dims", (1, 2, 3, -1, -2, -3))
@pytest.mark.parametrize("num_repeats", (2, 3))
def test_repeat_interleave_pytorch(test_device, input_shape, repeat_dims, num_repeats):

dims = repeat_dims
if dims < 0:
dims = len(input_shape) + dims
if dims < 0:
pytest.skip()

if dims > int(len(input_shape) - 1) or input_shape[dims] == 1:
pytest.skip()

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.balancer_policy = "Ribbon"
compiler_cfg.default_df_override = pybuda.DataFormat.Float16_b

class Repeat_interleave_model(torch.nn.Module):
def __init__(self, repeats, dims):
super().__init__()
self.repeats = repeats
self.dims = dims

def forward(self, input_tensor):
return torch.repeat_interleave(input_tensor, repeats = self.repeats, dim = self.dims)

model = Repeat_interleave_model(repeats=num_repeats, dims=repeat_dims)
model.eval()

# Create PyBuda module from PyTorch model
tt_model = pybuda.PyTorchModule(
"pt_repeat_interleave", model
)

input_sample = torch.rand(input_shape)

# Run inference on Tenstorrent device
verify_module(
tt_model,
input_shapes=[(input_sample.shape,)],
inputs=[(input_sample,)],
verify_cfg=VerifyConfig(
arch=test_device.arch,
devtype=test_device.devtype,
devmode=test_device.devmode,
test_kind=TestKind.INFERENCE,
verify_pybuda_codegen_vs_framework=True,
verify_tvm_compile=True,
),
)

0 comments on commit 0f43c18

Please sign in to comment.