diff --git a/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py b/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py
index fc27488c11d..f675bd82a30 100644
--- a/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py
+++ b/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py
@@ -66,7 +66,7 @@ def tt_distributed_layernorm(inp, gamma, beta, epsilon, is_rmsnorm, compute_kern
 
 
 def run_distributed_layernorm(
-    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, devices, fp32_enabled=False, iterations=1
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, mesh_device, fp32_enabled=False, iterations=1
 ):
     compute_kernel_config = ttnn.WormholeComputeKernelConfig(
         math_fidelity=ttnn.MathFidelity.HiFi4,  # Highest fidelity
@@ -94,7 +94,7 @@ def run_distributed_layernorm(
         ttnn.as_tensor(
             inp_chunked[d],
             dtype=dtype,
-            device=devices[d],
+            device=mesh_device.get_devices()[d],
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
@@ -106,7 +106,7 @@ def run_distributed_layernorm(
         ttnn.as_tensor(
             gamma_chunked[d].reshape(1, 1, -1, 32),
             dtype=ttnn.bfloat16,
-            device=devices[d],
+            device=mesh_device.get_devices()[d],
             layout=ttnn.ROW_MAJOR_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
@@ -118,7 +118,7 @@ def run_distributed_layernorm(
         ttnn.as_tensor(
             beta_chunked[d].reshape(1, 1, -1, 32),
             dtype=ttnn.bfloat16,
-            device=devices[d],
+            device=mesh_device.get_devices()[d],
             layout=ttnn.ROW_MAJOR_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
@@ -135,51 +135,62 @@ def run_distributed_layernorm(
     assert passing
 
 
-@skip_for_grayskull("Requires eth connected devices to run")
-@pytest.mark.parametrize(
-    "iterations",
-    [2],
-    ids=["loops2"],
-)
-@pytest.mark.parametrize(
-    "dtype",
-    (ttnn.bfloat16, ttnn.bfloat8_b),
-    ids=["BFLOAT16_in", "BFLOAT8_B_in"],
-)
-@pytest.mark.parametrize(
-    "stats_dtype",
-    (ttnn.bfloat16, ttnn.bfloat8_b),
-    ids=["BFLOAT16_stats", "BFLOAT8_B_stats"],
-)
-@pytest.mark.parametrize(
-    "inp_shape",
-    [
-        (1, 1, 2048, 8192),
-        (1, 1, 128, 8192),
-        (2, 1, 128, 8192),
-    ],
-    ids=["inp_shape0", "inp_shape1", "inp_shape2"],
-)
-@pytest.mark.parametrize(
-    "n_devices",
-    [4, 8],
-)
-@pytest.mark.parametrize(
-    "is_rmsnorm",
-    [True, False],
-    ids=["rmsnorm", "layernorm"],
-)
-def test_distributed_layernorm_with_program_cache(
-    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, all_devices, use_program_cache
+inp_shapes = [
+    (1, 1, 2048, 8192),
+    (1, 1, 128, 8192),
+    (2, 1, 128, 8192),
+]
+inp_shape_ids = ["inp_shape0", "inp_shape1", "inp_shape2"]
+
+stats_dtypes = [ttnn.bfloat16, ttnn.bfloat8_b]
+stats_dtypes_ids = ["BFLOAT16_stats", "BFLOAT8_B_stats"]
+
+dtypes = [ttnn.bfloat16, ttnn.bfloat8_b]
+dtype_ids = ["BFLOAT16_in", "BFLOAT8_B_in"]
+
+rms_norm_parametrizations = [True, False]
+rms_norm_parametrization_ids = ["rmsnorm", "layernorm"]
+
+
+def run_test_distributed_layernorm_with_program_cache_and_checks(
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, mesh_device, use_program_cache, iterations
 ):
-    if len(all_devices) != 8:
+    if mesh_device.get_num_devices() < n_devices:
         pytest.skip("Not T3000!")
 
-    devices = get_devices_for_t3000(all_devices, n_devices)
-
-    run_distributed_layernorm(inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, devices, iterations=iterations)
+    run_distributed_layernorm(inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, mesh_device, iterations=iterations)
 
-    for d in range(len(devices)):
-        assert devices[d].num_program_cache_entries() == 3, "Program cache should have only 3 entries, but has " + str(
-            devices[d].num_program_cache_entries()
+    for d in mesh_device.get_devices():
+        assert d.num_program_cache_entries() == 3, "Program cache should have only 3 entries, but has " + str(
+            d.num_program_cache_entries()
         )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize("iterations", [2], ids=["loops2"])
+@pytest.mark.parametrize("dtype", dtypes, ids=dtype_ids)
+@pytest.mark.parametrize("stats_dtype", stats_dtypes, ids=stats_dtypes_ids)
+@pytest.mark.parametrize("inp_shape", inp_shapes, ids=inp_shape_ids)
+@pytest.mark.parametrize("n_devices", [8])
+@pytest.mark.parametrize("is_rmsnorm", rms_norm_parametrizations, ids=rms_norm_parametrization_ids)
+def test_distributed_layernorm_with_program_cache(
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, t3k_mesh_device, use_program_cache
+):
+    run_test_distributed_layernorm_with_program_cache_and_checks(
+        inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, t3k_mesh_device, use_program_cache, iterations=iterations
+    )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize("iterations", [2], ids=["loops2"])
+@pytest.mark.parametrize("dtype", dtypes, ids=dtype_ids)
+@pytest.mark.parametrize("stats_dtype", stats_dtypes, ids=stats_dtypes_ids)
+@pytest.mark.parametrize("inp_shape", inp_shapes, ids=inp_shape_ids)
+@pytest.mark.parametrize("n_devices", [4])
+@pytest.mark.parametrize("is_rmsnorm", rms_norm_parametrizations, ids=rms_norm_parametrization_ids)
+def test_distributed_layernorm_with_program_cache_4chip(
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, pcie_mesh_device, use_program_cache
+):
+    run_test_distributed_layernorm_with_program_cache_and_checks(
+        inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, pcie_mesh_device, use_program_cache, iterations=iterations
+    )