#0: TTNN Async and Multi Device Trace Support
  - Add async-safe ttnn and tt_lib trace APIs (capture/replay
    lifecycle sketched below the file summary)
  - Add single-chip and multi-chip trace tests to ttnn post-commit
  - Add Resnet50 async trace tests (after porting the model over to async)
  - Certain multi-chip tests with all-gather are currently disabled,
    since they hang with trace
tt-asaigal authored and tt-aho committed May 16, 2024
1 parent 8599dba commit cb34e1b
Showing 33 changed files with 811 additions and 151 deletions.
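
The new trace APIs follow a capture/replay lifecycle, exercised end-to-end by the new ttnn multi-device tests below. A minimal sketch of that lifecycle, assuming a pcie_device_mesh handle like the test fixture (the op, shape, and trace-buffer size are illustrative, and host_tensor stands in for any ttnn host tensor):

import ttnn

# Trace capture requires the program cache; async mode is what the new
# tests exercise on top of it.
for device_id in pcie_device_mesh.get_device_ids():
    pcie_device_mesh.get_device(device_id).enable_async(True)
    pcie_device_mesh.get_device(device_id).enable_program_cache()

# Preallocate device inputs so the captured trace can be replayed
# against fresh data later.
input_dev = ttnn.allocate_tensor_on_device(
    ttnn.Shape((1, 1, 32, 32)), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh
)

# Run the op chain once outside capture to compile program binaries.
ttnn.relu(input_dev)

# Capture: ops enqueued between begin/end are recorded under trace id tid
# (args as used here: device mesh, trace buffer size, command queue id).
tid = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 106496, 0)
output = ttnn.relu(input_dev)
ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid, 0)

# Replay: write new host data into the preallocated inputs, then execute
# the trace with blocking=False on command queue 0.
ttnn.copy_host_to_device_tensor(host_tensor, input_dev)
ttnn.execute_multi_device_trace(pcie_device_mesh, tid, 0, False)

# Release the trace buffer once the workload is complete.
ttnn.release_multi_device_trace(pcie_device_mesh, tid)

The single-device flow in the resnet and eager tests below is analogous, with tt_lib.device.ReleaseTrace(device, tid) releasing the trace buffer.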
13 changes: 11 additions & 2 deletions models/demos/resnet/tests/test_metal_resnet50.py
@@ -239,12 +239,20 @@ def test_run_resnet50_inference(
[tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi],
ids=["HiFi4", "HiFi2", "LoFi"],
)
@pytest.mark.parametrize("enable_async", [True, False])
def test_run_resnet50_trace_inference(
device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input
device,
use_program_cache,
batch_size,
weights_dtype,
activations_dtype,
math_fidelity,
imagenet_sample_input,
enable_async,
):
if is_e75(device):
pytest.skip("Resnet50 is not supported on E75")

device.enable_async(enable_async)
if batch_size > 8 and (
activations_dtype != tt_lib.tensor.DataType.BFLOAT8_B or weights_dtype != tt_lib.tensor.DataType.BFLOAT8_B
):
@@ -339,3 +347,4 @@ def test_run_resnet50_trace_inference(
# assert passing # fails because of torch.allclose
# Done with the trace, can deallocate the buffers now.
tt_lib.device.ReleaseTrace(device, tid)
device.enable_async(False)
5 changes: 4 additions & 1 deletion models/demos/resnet/tests/test_perf_resnet.py
@@ -283,21 +283,24 @@ def run_perf_resnet_trace(
(20, 0.04, 25),
),
)
@pytest.mark.parametrize("enable_async", [True, False])
def test_perf_trace_bare_metal(
device,
use_program_cache,
batch_size,
expected_inference_time,
expected_compile_time,
hf_cat_image_sample_input,
enable_async,
):
if is_e75(device):
pytest.skip("Resnet is not supported on E75")

device.enable_async(enable_async)
run_perf_resnet_trace(
batch_size,
expected_inference_time,
expected_compile_time,
hf_cat_image_sample_input,
device,
)
device.enable_async(False)
@@ -23,8 +23,12 @@ TT_METAL_ENABLE_REMOTE_CHIP=1 ./build/test/tt_metal/unit_tests_fast_dispatch --g
./build/test/tt_metal/unit_tests_fast_dispatch --gtest_filter="DPrintFixture.*:WatcherFixture.*"
pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py -k post_commit

# ttnn multi-device trace tests
pytest tests/ttnn/unit_tests/test_multi_device_trace.py

# ttnn multi-chip apis unit tests
pytest tests/ttnn/unit_tests/test_multi_device.py
pytest tests/ttnn/unit_tests/test_multi_device_async.py

# Falcon40b unit tests; prefill required 8x8 grids
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/t3000/falcon40b/tests/test_falcon_mlp.py
@@ -34,7 +34,10 @@ def shape_padded(shape):
"BFLOAT16",
],
)
def test_run_average_pool(act_shape, dtype, device, use_program_cache):
@pytest.mark.parametrize("enable_async", [True, False])
def test_run_average_pool(act_shape, dtype, device, use_program_cache, enable_async):
device.enable_async(enable_async)

batch_size, _, _, channels = act_shape

torch.manual_seed(0)
@@ -103,3 +106,4 @@ def run_ops(ttact_res):

# Done with the trace, can deallocate the buffers now.
ttl.device.ReleaseTrace(device, tid)
device.enable_async(False)
@@ -35,6 +35,7 @@
(False, False, False, 4608, 1024, 3072, None), # out interleaved, in0 interleaved
],
)
@pytest.mark.parametrize("enable_async", [True, False])
def test_bert_linear(
device,
fidelity,
@@ -47,7 +48,9 @@ def test_bert_linear(
activation,
use_program_cache,
function_level_defaults,
enable_async,
):
device.enable_async(enable_async)
has_bias = False
in0_shape = [1, 1, M, K]
in1_shape = [1, 1, K, N]
@@ -96,7 +99,6 @@ def test_bert_linear(
in0 = torch.randn(in0_shape).bfloat16().float()
in1 = torch.randn(in1_shape).bfloat16().float()
bias = torch.randn(bias_shape).bfloat16().float()

in0_t_res = torch2tt_tensor(
in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B
)
@@ -195,7 +197,7 @@ def run_ops(in0_t_res):
passing, output = comp_pcc(pt_out, tt_out)
logger.info(output)
assert passing
ttl.device.ReleaseLastTrace(device)

# Done with the trace, can deallocate the buffers now.
ttl.device.ReleaseTrace(device, tid)
device.enable_async(False)
@@ -7,6 +7,7 @@
import torch

import tt_lib as ttl
import ttnn
from models.utility_functions import comp_pcc
from models.utility_functions import is_grayskull

@@ -31,7 +32,7 @@ def test_tensor_preallocation_and_write_apis(
for tensor_shape in shapes:
# Preallocate tensor on device
preallocated_tensor = ttl.tensor.allocate_tensor_on_device(
tensor_shape,
ttnn.Shape(tensor_shape),
in_dtype,
tensor_layout,
device,
222 changes: 222 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device_trace.py
@@ -0,0 +1,222 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import typing
import pytest
import ttnn
import tempfile
from loguru import logger
from tests.ttnn.utils_for_testing import assert_with_pcc
from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor, ListMeshToTensor


@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 32, 32)])
@pytest.mark.parametrize("use_all_gather", [True, False])
@pytest.mark.parametrize("enable_async", [True, False])
def test_multi_device_single_trace(pcie_device_mesh, shape, use_all_gather, enable_async):
# Trace requires program cache to be enabled
for device_id in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device_id).enable_async(enable_async)
pcie_device_mesh.get_device(device_id).enable_program_cache()

# Preallocate activation tensors. These will be used when capturing and executing the trace
input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh)
input_1_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh)

# Op chains to be traced
def run_op_chain(input_0, input_1):
single_dev_output = ttnn.neg(ttnn.add(ttnn.mul(input_1, ttnn.neg(ttnn.gelu(input_0))), ttnn.relu(input_1)))
if use_all_gather:
return ttnn.all_gather(single_dev_output, dim=0, num_links=1)
return single_dev_output

# Compile program binaries
run_op_chain(input_0_dev, input_1_dev)

# Capture Trace
logger.info("Capture Trace")
tid = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 106496, 0)
output_tensor = run_op_chain(input_0_dev, input_1_dev)
ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid, 0)

for i in range(50):
# Create torch inputs
torch_input_tensor_0 = torch.rand(
(pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16
)
torch_input_tensor_1 = torch.rand(
(pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16
)
# Compute PT Golden
torch_output_golden = torch.neg(
torch.add(
torch.mul(torch_input_tensor_1, torch.neg(torch.nn.functional.gelu(torch_input_tensor_0))),
torch.relu(torch_input_tensor_1),
)
)
# Convert torch tensors to TTNN Multi-Device Host Tensors
ttnn_input_tensor_0 = ttnn.from_torch(
torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0)
)
ttnn_input_tensor_1 = ttnn.from_torch(
torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0)
)

        # Copy TTNN host tensors into preallocated Multi-Device tensors
logger.info("Send Inputs to Device")
ttnn.copy_host_to_device_tensor(ttnn_input_tensor_0, input_0_dev)
ttnn.copy_host_to_device_tensor(ttnn_input_tensor_1, input_1_dev)
logger.info("Execute Trace")
# Execute trace
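        # Args as used here: device mesh, trace id, command queue id, blocking flag.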
ttnn.execute_multi_device_trace(pcie_device_mesh, tid, 0, False)

if use_all_gather:
# Device All-Gather: Iterate through tensors on all devices. Ensure they match the full tensor
logger.info("Read Back Trace Outputs")
device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(output_tensor)
for device_tensor in device_tensors:
device_tensor_torch = ttnn.to_torch(device_tensor)
assert_with_pcc(device_tensor_torch, torch_output_golden, pcc=0.99)

else:
# Perform host All-Gather
ttnn_torch_output_tensor = ttnn.to_torch(
output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0)
)
assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.99)

# Release trace buffer once workload is complete
ttnn.release_multi_device_trace(pcie_device_mesh, tid)

for device_id in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device_id).enable_async(False)


@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 32, 32)])
@pytest.mark.parametrize("use_all_gather", [True, False])
@pytest.mark.parametrize("enable_async", [True, False])
def test_multi_device_multi_trace(pcie_device_mesh, shape, use_all_gather, enable_async):
if use_all_gather:
        # Currently all-gather trace tests pass only with blocking == False;
        # the shapes below still hang with all-gather + trace, so skip them
        if shape == (1, 1, 32, 32) or shape == (1, 3, 512, 512) or shape == (1, 3, 32, 32):
            pytest.skip("This configuration does not currently work with all-gather")

# Trace requires program cache to be enabled
for device_id in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device_id).enable_async(enable_async)
pcie_device_mesh.get_device(device_id).enable_program_cache()

# Preallocate activation tensors. These will be used when capturing and executing the trace
input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh)
input_1_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh)
weight_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh)

# Op chains to be traced
def run_op_chain(input_0, input_1, weight):
single_dev_output = ttnn.neg(
ttnn.add(ttnn.mul(input_1, ttnn.neg(ttnn.gelu(input_0))), ttnn.relu(input_1))
) @ ttnn.silu(weight)
if use_all_gather:
return ttnn.all_gather(single_dev_output, dim=0, num_links=1)
return single_dev_output

def run_op_chain_1(input_0, input_1, weight):
single_dev_output = ttnn.tanh(ttnn.mul(ttnn.sub(input_0, input_1), weight))
if use_all_gather:
return ttnn.all_gather(single_dev_output, dim=0, num_links=1)
return single_dev_output

# Compile program binaries
run_op_chain(input_0_dev, input_1_dev, weight_dev)
run_op_chain_1(input_0_dev, input_1_dev, weight_dev)

# Capture Trace 0
logger.info("Capture Trace 0")
tid = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 106496, 0)
output_tensor = run_op_chain(input_0_dev, input_1_dev, weight_dev)
ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid, 0)

# Capture Trace 1
logger.info("Capture Trace 1")
tid_1 = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 26624, 0)
output_tensor_1 = run_op_chain_1(input_0_dev, input_1_dev, weight_dev)
ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid_1, 0)

# Execute and verify trace against pytorch
torch_silu = torch.nn.SiLU()
for i in range(50):
# Create torch inputs
torch_input_tensor_0 = torch.rand(
(pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16
)
torch_input_tensor_1 = torch.rand(
(pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16
)
torch_weight = torch.rand(shape, dtype=torch.bfloat16)
# Compute PT Golden
torch_output_golden = torch.neg(
torch.add(
torch.mul(torch_input_tensor_1, torch.neg(torch.nn.functional.gelu(torch_input_tensor_0))),
torch.relu(torch_input_tensor_1),
)
) @ torch_silu(torch_weight)

torch_output_golden_1 = torch.tanh(
torch.mul(torch.sub(torch_input_tensor_0, torch_input_tensor_1), torch_weight)
)

# Convert torch tensors to TTNN Multi-Device Host Tensors
ttnn_input_tensor_0 = ttnn.from_torch(
torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0)
)
ttnn_input_tensor_1 = ttnn.from_torch(
torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0)
)
ttnn_weight = ttnn.from_torch(
torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(pcie_device_mesh)
)

        # Copy TTNN host tensors into preallocated Multi-Device tensors
logger.info("Send Inputs to Device")
ttnn.copy_host_to_device_tensor(ttnn_input_tensor_0, input_0_dev)
ttnn.copy_host_to_device_tensor(ttnn_input_tensor_1, input_1_dev)
ttnn.copy_host_to_device_tensor(ttnn_weight, weight_dev)

logger.info("Execute Trace 0")
# Execute trace
ttnn.execute_multi_device_trace(pcie_device_mesh, tid, 0, False)
logger.info("Execute Trace 1")
ttnn.execute_multi_device_trace(pcie_device_mesh, tid_1, 0, False)
if use_all_gather:
# Device All-Gather: Iterate through tensors on all devices. Ensure they match the full tensor
logger.info("Read Back Trace 0 Outputs")
device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(output_tensor)
for device_tensor in device_tensors:
device_tensor_torch = ttnn.to_torch(device_tensor)
assert_with_pcc(device_tensor_torch, torch_output_golden, pcc=0.99)

logger.info("Read Back Trace 1 Outputs")
device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(output_tensor_1)
for device_tensor in device_tensors:
device_tensor_torch = ttnn.to_torch(device_tensor)
assert_with_pcc(device_tensor_torch, torch_output_golden_1, pcc=0.99)
else:
# Perform host All-Gather
ttnn_torch_output_tensor = ttnn.to_torch(
output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0)
)
assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.99)

ttnn_torch_output_tensor = ttnn.to_torch(
output_tensor_1, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0)
)
assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden_1, pcc=0.99)

# Release trace buffer once workload is complete
ttnn.release_multi_device_trace(pcie_device_mesh, tid)
ttnn.release_multi_device_trace(pcie_device_mesh, tid_1)

for device_id in pcie_device_mesh.get_device_ids():
pcie_device_mesh.get_device(device_id).enable_async(False)