Commit 824c167

#0: fix mesh device fixture selection for test_distributed_layernorm (#13433)
SeanNijjar authored Oct 3, 2024
1 parent 73e0dd8 commit 824c167
Showing 1 changed file with 58 additions and 47 deletions.
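The gist of the fix: the runner and the program-cache test now receive a mesh-device fixture (`t3k_mesh_device` for 8 chips, `pcie_mesh_device` for 4) instead of the flat `all_devices` list, and per-device tensor placement indexes `mesh_device.get_devices()`. Below is a minimal sketch of that placement pattern, using the `ttnn.as_tensor` arguments exactly as they appear in the hunks that follow; the helper name and the `torch.chunk` split axis are assumptions, since the chunking code sits outside the visible diff context.

import torch
import ttnn


def shard_input_to_mesh(inp, n_devices, mesh_device, dtype):
    # Split the input along the last dim, one chunk per device (assumed split axis;
    # the actual chunking in the test file is not shown in this diff).
    inp_chunked = torch.chunk(inp, n_devices, dim=-1)
    # Place chunk d on device d of the mesh, as the updated test does.
    return [
        ttnn.as_tensor(
            inp_chunked[d],
            dtype=dtype,
            device=mesh_device.get_devices()[d],
            layout=ttnn.TILE_LAYOUT,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
        )
        for d in range(n_devices)
    ]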
105 changes: 58 additions & 47 deletions tests/ttnn/unit_tests/operations/test_distributed_layernorm.py
@@ -66,7 +66,7 @@ def tt_distributed_layernorm(inp, gamma, beta, epsilon, is_rmsnorm, compute_kern
 
 
 def run_distributed_layernorm(
-    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, devices, fp32_enabled=False, iterations=1
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, mesh_device, fp32_enabled=False, iterations=1
 ):
     compute_kernel_config = ttnn.WormholeComputeKernelConfig(
         math_fidelity=ttnn.MathFidelity.HiFi4,  # Highest fidelity
@@ -94,7 +94,7 @@ def run_distributed_layernorm(
         ttnn.as_tensor(
             inp_chunked[d],
             dtype=dtype,
-            device=devices[d],
+            device=mesh_device.get_devices()[d],
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
@@ -106,7 +106,7 @@ def run_distributed_layernorm(
         ttnn.as_tensor(
             gamma_chunked[d].reshape(1, 1, -1, 32),
             dtype=ttnn.bfloat16,
-            device=devices[d],
+            device=mesh_device.get_devices()[d],
             layout=ttnn.ROW_MAJOR_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
@@ -118,7 +118,7 @@ def run_distributed_layernorm(
         ttnn.as_tensor(
             beta_chunked[d].reshape(1, 1, -1, 32),
             dtype=ttnn.bfloat16,
-            device=devices[d],
+            device=mesh_device.get_devices()[d],
             layout=ttnn.ROW_MAJOR_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
@@ -135,51 +135,62 @@ def run_distributed_layernorm(
     assert passing
 
 
-@skip_for_grayskull("Requires eth connected devices to run")
-@pytest.mark.parametrize(
-    "iterations",
-    [2],
-    ids=["loops2"],
-)
-@pytest.mark.parametrize(
-    "dtype",
-    (ttnn.bfloat16, ttnn.bfloat8_b),
-    ids=["BFLOAT16_in", "BFLOAT8_B_in"],
-)
-@pytest.mark.parametrize(
-    "stats_dtype",
-    (ttnn.bfloat16, ttnn.bfloat8_b),
-    ids=["BFLOAT16_stats", "BFLOAT8_B_stats"],
-)
-@pytest.mark.parametrize(
-    "inp_shape",
-    [
-        (1, 1, 2048, 8192),
-        (1, 1, 128, 8192),
-        (2, 1, 128, 8192),
-    ],
-    ids=["inp_shape0", "inp_shape1", "inp_shape2"],
-)
-@pytest.mark.parametrize(
-    "n_devices",
-    [4, 8],
-)
-@pytest.mark.parametrize(
-    "is_rmsnorm",
-    [True, False],
-    ids=["rmsnorm", "layernorm"],
-)
-def test_distributed_layernorm_with_program_cache(
-    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, all_devices, use_program_cache
+inp_shapes = [
+    (1, 1, 2048, 8192),
+    (1, 1, 128, 8192),
+    (2, 1, 128, 8192),
+]
+inp_shape_ids = ["inp_shape0", "inp_shape1", "inp_shape2"]
+
+stats_dtypes = [ttnn.bfloat16, ttnn.bfloat8_b]
+stats_dtypes_ids = ["BFLOAT16_stats", "BFLOAT8_B_stats"]
+
+dtypes = [ttnn.bfloat16, ttnn.bfloat8_b]
+dtype_ids = ["BFLOAT16_in", "BFLOAT8_B_in"]
+
+rms_norm_parametrizations = [True, False]
+rms_norm_parametrization_ids = ["rmsnorm", "layernorm"]
+
+
+def run_test_distributed_layernorm_with_program_cache_and_checks(
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, mesh_device, use_program_cache, iterations
 ):
-    if len(all_devices) != 8:
+    if mesh_device.get_num_devices() < n_devices:
         pytest.skip("Not T3000!")
 
-    devices = get_devices_for_t3000(all_devices, n_devices)
-
-    run_distributed_layernorm(inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, devices, iterations=iterations)
+    run_distributed_layernorm(inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, mesh_device, iterations=iterations)
 
-    for d in range(len(devices)):
-        assert devices[d].num_program_cache_entries() == 3, "Program cache should have only 3 entries, but has " + str(
-            devices[d].num_program_cache_entries()
+    for d in mesh_device.get_devices():
+        assert d.num_program_cache_entries() == 3, "Program cache should have only 3 entries, but has " + str(
+            d.num_program_cache_entries()
         )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize("iterations", [2], ids=["loops2"])
+@pytest.mark.parametrize("dtype", dtypes, ids=dtype_ids)
+@pytest.mark.parametrize("stats_dtype", stats_dtypes, ids=stats_dtypes_ids)
+@pytest.mark.parametrize("inp_shape", inp_shapes, ids=inp_shape_ids)
+@pytest.mark.parametrize("n_devices", [8])
+@pytest.mark.parametrize("is_rmsnorm", rms_norm_parametrizations, ids=rms_norm_parametrization_ids)
+def test_distributed_layernorm_with_program_cache(
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, t3k_mesh_device, use_program_cache
+):
+    run_test_distributed_layernorm_with_program_cache_and_checks(
+        inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, t3k_mesh_device, use_program_cache, iterations=iterations
+    )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize("iterations", [2], ids=["loops2"])
+@pytest.mark.parametrize("dtype", dtypes, ids=dtype_ids)
+@pytest.mark.parametrize("stats_dtype", stats_dtypes, ids=stats_dtypes_ids)
+@pytest.mark.parametrize("inp_shape", inp_shapes, ids=inp_shape_ids)
+@pytest.mark.parametrize("n_devices", [4])
+@pytest.mark.parametrize("is_rmsnorm", rms_norm_parametrizations, ids=rms_norm_parametrization_ids)
+def test_distributed_layernorm_with_program_cache_4chip(
+    inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, pcie_mesh_device, use_program_cache
+):
+    run_test_distributed_layernorm_with_program_cache_and_checks(
+        inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, pcie_mesh_device, use_program_cache, iterations=iterations
+    )
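For context on what `run_distributed_layernorm` is checking (the reference implementation itself is not part of this diff): the sharded layernorm/rmsnorm output should match a single-device reference. A plausible torch reference is sketched below; it is illustrative only, and the file's actual reference may differ in details such as affine handling or accumulation dtype.

import torch


def reference_norm(inp, gamma, beta, epsilon, is_rmsnorm):
    # Normalize over the last (hidden) dimension; illustrative reference only.
    if is_rmsnorm:
        # RMSNorm: scale by the reciprocal RMS; no mean subtraction, no bias term.
        rms = torch.sqrt(torch.mean(inp * inp, dim=-1, keepdim=True) + epsilon)
        return inp / rms * gamma
    # LayerNorm: subtract the mean, divide by the standard deviation, then apply affine.
    mean = inp.mean(dim=-1, keepdim=True)
    var = ((inp - mean) ** 2).mean(dim=-1, keepdim=True)
    return (inp - mean) / torch.sqrt(var + epsilon) * gamma + beta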
