#4686: add fp32 test for allgather
yugaoTT committed Mar 26, 2024
1 parent a10525f commit 82efc61
Showing 4 changed files with 92 additions and 25 deletions.
@@ -667,3 +667,47 @@ def test_all_gather_post_commit_sharded(
print(f"")
print(f"")
assert all_eq, f"{i} FAILED: {output}"


+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "input_shape, dim, layout",
+    [([4, 1, 33, 256], 0, ttl.tensor.Layout.ROW_MAJOR), ([4, 1, 256, 32], 0, ttl.tensor.Layout.TILE)],
+)
+@pytest.mark.parametrize(
+    "mem_config",
+    [
+        ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.DRAM),
+        ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.L1),
+    ],
+)
+@pytest.mark.parametrize("num_links", [1, 2])
+def test_all_gather_fp32(
+    pcie_devices, input_shape, dim, num_links, layout, mem_config, use_program_cache, function_level_defaults
+):
+    if (
+        layout == ttl.tensor.Layout.ROW_MAJOR or num_links == 2
+    ) and mem_config.buffer_type == ttl.tensor.BufferType.DRAM:
+        pytest.skip("All gather tests are hanging for RM in DRAM")
+    devices = pcie_devices
+    input_tensor = torch.rand(input_shape).bfloat16()
+    num_devices = len(devices)
+    if num_devices < 2:
+        pytest.skip("Requires multiple devices to run")
+    elif num_devices == 2 and num_links == 2:
+        pytest.skip("Not enough links to run")
+
+    if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
+        pytest.skip("Unsupported test case")
+
+    input_tensors = torch.chunk(input_tensor, num_devices, dim)
+    tt_input_tensors = []
+    for i, t in enumerate(input_tensors):
+        tt_input_tensors.append(ttl.tensor.Tensor(t, ttl.tensor.DataType.FLOAT32).to(layout).to(devices[i], mem_config))
+
+    tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)
+
+    for i, t in enumerate(tt_out_tensors):
+        tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
+        eq, output = comp_equal(tt_output_tensor, input_tensor)
+        assert eq, f"{i} FAILED: {output}"
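
For readers new to the collective being tested: the property checked above is that, after all_gather along dim, every device holds the full, unchunked input tensor. The following is a minimal host-side sketch of that semantics in plain PyTorch, with no ttl device calls; the helper name reference_all_gather is invented here purely for illustration.

import torch

def reference_all_gather(chunks, dim):
    # Host-side model of all_gather: every participant receives the
    # concatenation of all per-participant chunks along dim.
    gathered = torch.cat(chunks, dim=dim)
    return [gathered.clone() for _ in chunks]

# Mirrors the tiled case above: a [4, 1, 256, 32] bfloat16 input split across 4 devices on dim 0.
full = torch.rand([4, 1, 256, 32]).bfloat16()
chunks = torch.chunk(full, 4, dim=0)
outputs = reference_all_gather(chunks, dim=0)
assert all(torch.equal(out, full) for out in outputs)
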
@@ -45,15 +45,20 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
    gamma_t = pad_by_zero(gamma, device, in0_mem_config, gamma_dtype)[0]
    beta_t = pad_by_zero(beta, device, in0_mem_config, gamma_dtype)[0]

-    compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
-        math_fidelity=ttl.tensor.MathFidelity.HiFi4,
-        math_approx_mode=True,
-        fp32_dest_acc_en=True if in_dtype == ttl.tensor.DataType.FLOAT32 else False,
-    )
+    if not is_grayskull():
+        compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
+            math_fidelity=ttl.tensor.MathFidelity.HiFi4,
+            math_approx_mode=True,
+            fp32_dest_acc_en=True if in_dtype == ttl.tensor.DataType.FLOAT32 else False,
+        )

    if test_id == 0:
        ttz = ttl.tensor.add_layernorm(
-            in0_t, in1_t, epsf, output_mem_config=out_mem_config, compute_kernel_config=compute_kernel_config
+            in0_t,
+            in1_t,
+            epsf,
+            output_mem_config=out_mem_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 1:
        ttz = ttl.tensor.add_layernorm(
@@ -62,7 +67,7 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
            epsf,
            gamma_t,
            output_mem_config=out_mem_config,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 2:
        ttz = ttl.tensor.add_layernorm(
@@ -72,11 +77,15 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
            gamma_t,
            beta_t,
            output_mem_config=out_mem_config,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 3:
        ttz = ttl.tensor.add_rmsnorm(
-            in0_t, in1_t, epsf, output_mem_config=out_mem_config, compute_kernel_config=compute_kernel_config
+            in0_t,
+            in1_t,
+            epsf,
+            output_mem_config=out_mem_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 4:
        ttz = ttl.tensor.add_rmsnorm(
@@ -85,7 +94,7 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
            epsf,
            gamma_t,
            output_mem_config=out_mem_config,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 5:
        ttz = ttl.tensor.add_rmsnorm(
@@ -95,15 +104,22 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
            gamma_t,
            beta_t,
            output_mem_config=out_mem_config,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 6:
        ttz = ttl.tensor.layernorm(
-            in0_t, epsf, output_mem_config=out_mem_config, compute_kernel_config=compute_kernel_config
+            in0_t,
+            epsf,
+            output_mem_config=out_mem_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 7:
        ttz = ttl.tensor.layernorm(
-            in0_t, epsf, gamma_t, output_mem_config=out_mem_config, compute_kernel_config=compute_kernel_config
+            in0_t,
+            epsf,
+            gamma_t,
+            output_mem_config=out_mem_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 8:
        ttz = ttl.tensor.layernorm(
@@ -112,15 +128,22 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
            gamma_t,
            beta_t,
            output_mem_config=out_mem_config,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 9:
        ttz = ttl.tensor.rmsnorm(
-            in0_t, epsf, output_mem_config=out_mem_config, compute_kernel_config=compute_kernel_config
+            in0_t,
+            epsf,
+            output_mem_config=out_mem_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 10:
        ttz = ttl.tensor.rmsnorm(
-            in0_t, epsf, gamma_t, output_mem_config=out_mem_config, compute_kernel_config=compute_kernel_config
+            in0_t,
+            epsf,
+            gamma_t,
+            output_mem_config=out_mem_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )
    if test_id == 11:
        ttz = ttl.tensor.rmsnorm(
@@ -129,7 +152,7 @@ def run_layernorm_mix_precision_tests(test_id, in_dtype, gamma_dtype, in0_mem_co
            gamma_t,
            beta_t,
            output_mem_config=out_mem_config,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=compute_kernel_config if not is_grayskull() else None,
        )

    tt_got_back = ttz.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
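
Every call-site change in this file follows the same pattern: the WormholeComputeKernelConfig is only constructed when the device is not Grayskull, and each layernorm/rmsnorm call falls back to None on Grayskull. Below is a minimal sketch of how that guard could be expressed once, assuming the usual import tt_lib as ttl used by these tests; the helper name pick_compute_kernel_config is invented here and is not part of the diff.

import tt_lib as ttl

def pick_compute_kernel_config(in_dtype, on_grayskull):
    # Grayskull has no Wormhole-style compute kernel config; returning None
    # lets the op fall back to its default kernel configuration.
    if on_grayskull:
        return None
    return ttl.tensor.WormholeComputeKernelConfig(
        math_fidelity=ttl.tensor.MathFidelity.HiFi4,
        math_approx_mode=True,
        fp32_dest_acc_en=(in_dtype == ttl.tensor.DataType.FLOAT32),
    )

A call site would then pass compute_kernel_config=pick_compute_kernel_config(in_dtype, is_grayskull()), which is equivalent to the repeated "compute_kernel_config=compute_kernel_config if not is_grayskull() else None" expressions above.
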
@@ -202,15 +202,15 @@ def test_rotary_embedding_prefill_fp32(
    out_mem_config = ttl.tensor.MemoryConfig()

    xt = ttl.tensor.Tensor(x, input_dtype)
-    if xt.shape()[-2] % 32 == 0 and xt.shape()[-1] % 32 == 0:
+    if xt.get_legacy_shape()[-2] % 32 == 0 and xt.get_legacy_shape()[-1] % 32 == 0:
        xt = xt.to(ttl.tensor.Layout.TILE)
    elif input_dtype == ttl.tensor.DataType.BFLOAT8_B:
        pytest.skip()

    if in_sharded or out_sharded:
-        if xt.layout() != ttl.tensor.Layout.TILE:
+        if xt.get_layout() != ttl.tensor.Layout.TILE:
            pytest.skip("Sharding support required tile size")
-        num_blocks = xt.volume() // xt.shape()[-1] // 32
+        num_blocks = xt.volume() // xt.get_legacy_shape()[-1] // 32
        compute_grid_size = device.compute_with_storage_grid_size()
        for i in range(compute_grid_size.x * compute_grid_size.y, 0, -1):
            if num_blocks % i == 0:
@@ -226,7 +226,7 @@ def test_rotary_embedding_prefill_fp32(
                shard_grid,
                [
                    Ht * 32,
-                    xt.shape()[-1],
+                    xt.get_legacy_shape()[-1],
                ],
                ttl.tensor.ShardOrientation.ROW_MAJOR,
                False,
@@ -274,15 +274,15 @@ def test_rotary_embedding_decode_fp32(
    out_mem_config = ttl.tensor.MemoryConfig()

    xt = ttl.tensor.Tensor(x, input_dtype)
-    if xt.shape()[-2] % 32 == 0 and xt.shape()[-1] % 32 == 0:
+    if xt.get_legacy_shape()[-2] % 32 == 0 and xt.get_legacy_shape()[-1] % 32 == 0:
        xt = xt.to(ttl.tensor.Layout.TILE)
    elif input_dtype == ttl.tensor.DataType.BFLOAT8_B:
        pytest.skip()

    if in_sharded or out_sharded:
-        if xt.layout() != ttl.tensor.Layout.TILE:
+        if xt.get_layout() != ttl.tensor.Layout.TILE:
            pytest.skip("Sharding support required tile size")
-        num_blocks = xt.volume() // xt.shape()[-1] // 32
+        num_blocks = xt.volume() // xt.get_legacy_shape()[-1] // 32
        compute_grid_size = device.compute_with_storage_grid_size()
        for i in range(compute_grid_size.x * compute_grid_size.y, 0, -1):
            if num_blocks % i == 0:
@@ -298,7 +298,7 @@ def test_rotary_embedding_decode_fp32(
                shard_grid,
                [
                    Ht * 32,
-                    xt.shape()[-1],
+                    xt.get_legacy_shape()[-1],
                ],
                ttl.tensor.ShardOrientation.ROW_MAJOR,
                False,
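
The rotary-embedding changes above are a mechanical rename of the tensor accessors (shape() becomes get_legacy_shape(), layout() becomes get_layout()); the sharding logic around them is unchanged. For context, that logic simply searches for the largest core count that evenly divides the number of 32-row tile blocks. A standalone sketch of that search in plain Python follows; the function name pick_num_cores and the example numbers are invented for illustration.

def pick_num_cores(volume, last_dim, grid_x, grid_y):
    # num_blocks counts 32-high tile rows across the flattened tensor,
    # matching xt.volume() // xt.get_legacy_shape()[-1] // 32 in the tests.
    num_blocks = volume // last_dim // 32
    # Walk down from the full grid so the blocks divide evenly across cores.
    for i in range(grid_x * grid_y, 0, -1):
        if num_blocks % i == 0:
            return i
    return 1

# Example: a 1 x 8 x 128 x 64 tensor on an 8 x 8 compute grid.
volume = 1 * 8 * 128 * 64
print(pick_num_cores(volume, last_dim=64, grid_x=8, grid_y=8))  # prints 32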
