diff --git a/models/demos/resnet/tests/test_metal_resnet50.py b/models/demos/resnet/tests/test_metal_resnet50.py index b1791bbf0ee..e5ade5e7802 100644 --- a/models/demos/resnet/tests/test_metal_resnet50.py +++ b/models/demos/resnet/tests/test_metal_resnet50.py @@ -239,12 +239,20 @@ def test_run_resnet50_inference( [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], ids=["HiFi4", "HiFi2", "LoFi"], ) +@pytest.mark.parametrize("enable_async", [True, False]) def test_run_resnet50_trace_inference( - device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + enable_async, ): if is_e75(device): pytest.skip("Resnet50 is not supported on E75") - + device.enable_async(enable_async) if batch_size > 8 and ( activations_dtype != tt_lib.tensor.DataType.BFLOAT8_B or weights_dtype != tt_lib.tensor.DataType.BFLOAT8_B ): @@ -339,3 +347,4 @@ def test_run_resnet50_trace_inference( # assert passing # fails because of torch.allclose # Done with the trace, can deallocate the buffers now. tt_lib.device.ReleaseTrace(device, tid) + device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index ac3f54cc9cb..f8811cd01cc 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ b/models/demos/resnet/tests/test_perf_resnet.py @@ -283,6 +283,7 @@ def run_perf_resnet_trace( (20, 0.04, 25), ), ) +@pytest.mark.parametrize("enable_async", [True, False]) def test_perf_trace_bare_metal( device, use_program_cache, @@ -290,10 +291,11 @@ def test_perf_trace_bare_metal( expected_inference_time, expected_compile_time, hf_cat_image_sample_input, + enable_async, ): if is_e75(device): pytest.skip("Resnet is not supported on E75") - + device.enable_async(enable_async) run_perf_resnet_trace( batch_size, expected_inference_time, @@ -301,3 +303,4 @@ def test_perf_trace_bare_metal( hf_cat_image_sample_input, device, ) + device.enable_async(False) diff --git a/tests/scripts/multi_chip/run_pre_post_commit_regressions_multi_device.sh b/tests/scripts/multi_chip/run_pre_post_commit_regressions_multi_device.sh index a2081e36d58..64419ec6dc2 100755 --- a/tests/scripts/multi_chip/run_pre_post_commit_regressions_multi_device.sh +++ b/tests/scripts/multi_chip/run_pre_post_commit_regressions_multi_device.sh @@ -23,8 +23,12 @@ TT_METAL_ENABLE_REMOTE_CHIP=1 ./build/test/tt_metal/unit_tests_fast_dispatch --g ./build/test/tt_metal/unit_tests_fast_dispatch --gtest_filter="DPrintFixture.*:WatcherFixture.*" pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py -k post_commit +# ttnn multi-device trace tests +pytest tests/ttnn/unit_tests/test_multi_device_trace.py + # ttnn multi-chip apis unit tests pytest tests/ttnn/unit_tests/test_multi_device.py +pytest tests/ttnn/unit_tests/test_multi_device_async.py # Falcon40b unit tests; prefill required 8x8 grids WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/t3000/falcon40b/tests/test_falcon_mlp.py diff --git a/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py b/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py index a8ec7b13742..0ea2a8c5e00 100644 --- a/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py +++ b/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py @@ -34,7 
+34,10 @@ def shape_padded(shape): "BFLOAT16", ], ) -def test_run_average_pool(act_shape, dtype, device, use_program_cache): +@pytest.mark.parametrize("enable_async", [True, False]) +def test_run_average_pool(act_shape, dtype, device, use_program_cache, enable_async): + device.enable_async(enable_async) + batch_size, _, _, channels = act_shape torch.manual_seed(0) @@ -103,3 +106,4 @@ def run_ops(ttact_res): # Done with the trace, can deallocate the buffers now. ttl.device.ReleaseTrace(device, tid) + device.enable_async(False) diff --git a/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py b/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py index 4893e71dfa9..820d6782083 100644 --- a/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py +++ b/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py @@ -35,6 +35,7 @@ (False, False, False, 4608, 1024, 3072, None), # out interleaved, in0 interleaved ], ) +@pytest.mark.parametrize("enable_async", [True, False]) def test_bert_linear( device, fidelity, @@ -47,7 +48,9 @@ def test_bert_linear( activation, use_program_cache, function_level_defaults, + enable_async, ): + device.enable_async(enable_async) has_bias = False in0_shape = [1, 1, M, K] in1_shape = [1, 1, K, N] @@ -96,7 +99,6 @@ def test_bert_linear( in0 = torch.randn(in0_shape).bfloat16().float() in1 = torch.randn(in1_shape).bfloat16().float() bias = torch.randn(bias_shape).bfloat16().float() - in0_t_res = torch2tt_tensor( in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B ) @@ -195,7 +197,7 @@ def run_ops(in0_t_res): passing, output = comp_pcc(pt_out, tt_out) logger.info(output) assert passing - ttl.device.ReleaseLastTrace(device) # Done with the trace, can deallocate the buffers now. ttl.device.ReleaseTrace(device, tid) + device.enable_async(False) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_tensor_prealloc_and_write.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_tensor_prealloc_and_write.py index a1929ab3439..9c7333cbc80 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_tensor_prealloc_and_write.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_tensor_prealloc_and_write.py @@ -7,6 +7,7 @@ import torch import tt_lib as ttl +import ttnn from models.utility_functions import comp_pcc from models.utility_functions import is_grayskull @@ -31,7 +32,7 @@ def test_tensor_preallocation_and_write_apis( for tensor_shape in shapes: # Preallocate tensor on device preallocated_tensor = ttl.tensor.allocate_tensor_on_device( - tensor_shape, + ttnn.Shape(tensor_shape), in_dtype, tensor_layout, device, diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py new file mode 100644 index 00000000000..02568b8fa51 --- /dev/null +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import typing +import pytest +import ttnn +import tempfile +from loguru import logger +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor, ListMeshToTensor + + +@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 32, 32)]) +@pytest.mark.parametrize("use_all_gather", [True, False]) +@pytest.mark.parametrize("enable_async", [True, False]) +def test_multi_device_single_trace(pcie_device_mesh, shape, use_all_gather, enable_async): + # Trace requires program cache to be enabled + for device_id in pcie_device_mesh.get_device_ids(): + pcie_device_mesh.get_device(device_id).enable_async(enable_async) + pcie_device_mesh.get_device(device_id).enable_program_cache() + + # Preallocate activation tensors. These will be used when capturing and executing the trace + input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh) + input_1_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh) + + # Op chains to be traced + def run_op_chain(input_0, input_1): + single_dev_output = ttnn.neg(ttnn.add(ttnn.mul(input_1, ttnn.neg(ttnn.gelu(input_0))), ttnn.relu(input_1))) + if use_all_gather: + return ttnn.all_gather(single_dev_output, dim=0, num_links=1) + return single_dev_output + + # Compile program binaries + run_op_chain(input_0_dev, input_1_dev) + + # Capture Trace + logger.info("Capture Trace") + tid = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 106496, 0) + output_tensor = run_op_chain(input_0_dev, input_1_dev) + ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid, 0) + + for i in range(50): + # Create torch inputs + torch_input_tensor_0 = torch.rand( + (pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16 + ) + torch_input_tensor_1 = torch.rand( + (pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16 + ) + # Compute PT Golden + torch_output_golden = torch.neg( + torch.add( + torch.mul(torch_input_tensor_1, torch.neg(torch.nn.functional.gelu(torch_input_tensor_0))), + torch.relu(torch_input_tensor_1), + ) + ) + # Convert torch tensors to TTNN Multi-Device Host Tensors + ttnn_input_tensor_0 = ttnn.from_torch( + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0) + ) + ttnn_input_tensor_1 = ttnn.from_torch( + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0) + ) + + # Copy TTNN host tensors into preallocated Mult-Device tensors + logger.info("Send Inputs to Device") + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_0, input_0_dev) + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_1, input_1_dev) + logger.info("Execute Trace") + # Execute trace + ttnn.execute_multi_device_trace(pcie_device_mesh, tid, 0, False) + + if use_all_gather: + # Device All-Gather: Iterate through tensors on all devices. 
Ensure they match the full tensor + logger.info("Read Back Trace Outputs") + device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(output_tensor) + for device_tensor in device_tensors: + device_tensor_torch = ttnn.to_torch(device_tensor) + assert_with_pcc(device_tensor_torch, torch_output_golden, pcc=0.99) + + else: + # Perform host All-Gather + ttnn_torch_output_tensor = ttnn.to_torch( + output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0) + ) + assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.99) + + # Release trace buffer once workload is complete + ttnn.release_multi_device_trace(pcie_device_mesh, tid) + + for device_id in pcie_device_mesh.get_device_ids(): + pcie_device_mesh.get_device(device_id).enable_async(False) + + +@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 32, 32)]) +@pytest.mark.parametrize("use_all_gather", [True, False]) +@pytest.mark.parametrize("enable_async", [True, False]) +def test_multi_device_multi_trace(pcie_device_mesh, shape, use_all_gather, enable_async): + if use_all_gather: + # Currently all-gather tests pass only if blocking == False + if shape == (1, 1, 32, 32) or shape == (1, 3, 512, 512) or shape == (1, 3, 32, 32): + pytest.skip("This configuration is not working with all-gather") + + # Trace requires program cache to be enabled + for device_id in pcie_device_mesh.get_device_ids(): + pcie_device_mesh.get_device(device_id).enable_async(enable_async) + pcie_device_mesh.get_device(device_id).enable_program_cache() + + # Preallocate activation tensors. These will be used when capturing and executing the trace + input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh) + input_1_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh) + weight_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, pcie_device_mesh) + + # Op chains to be traced + def run_op_chain(input_0, input_1, weight): + single_dev_output = ttnn.neg( + ttnn.add(ttnn.mul(input_1, ttnn.neg(ttnn.gelu(input_0))), ttnn.relu(input_1)) + ) @ ttnn.silu(weight) + if use_all_gather: + return ttnn.all_gather(single_dev_output, dim=0, num_links=1) + return single_dev_output + + def run_op_chain_1(input_0, input_1, weight): + single_dev_output = ttnn.tanh(ttnn.mul(ttnn.sub(input_0, input_1), weight)) + if use_all_gather: + return ttnn.all_gather(single_dev_output, dim=0, num_links=1) + return single_dev_output + + # Compile program binaries + run_op_chain(input_0_dev, input_1_dev, weight_dev) + run_op_chain_1(input_0_dev, input_1_dev, weight_dev) + + # Capture Trace 0 + logger.info("Capture Trace 0") + tid = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 106496, 0) + output_tensor = run_op_chain(input_0_dev, input_1_dev, weight_dev) + ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid, 0) + + # Capture Trace 1 + logger.info("Capture Trace 1") + tid_1 = ttnn.begin_multi_device_trace_capture(pcie_device_mesh, 26624, 0) + output_tensor_1 = run_op_chain_1(input_0_dev, input_1_dev, weight_dev) + ttnn.end_multi_device_trace_capture(pcie_device_mesh, tid_1, 0) + + # Execute and verify trace against pytorch + torch_silu = torch.nn.SiLU() + for i in range(50): + # Create torch inputs + torch_input_tensor_0 = torch.rand( + (pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16 + ) + torch_input_tensor_1 = torch.rand( + 
(pcie_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16 + ) + torch_weight = torch.rand(shape, dtype=torch.bfloat16) + # Compute PT Golden + torch_output_golden = torch.neg( + torch.add( + torch.mul(torch_input_tensor_1, torch.neg(torch.nn.functional.gelu(torch_input_tensor_0))), + torch.relu(torch_input_tensor_1), + ) + ) @ torch_silu(torch_weight) + + torch_output_golden_1 = torch.tanh( + torch.mul(torch.sub(torch_input_tensor_0, torch_input_tensor_1), torch_weight) + ) + + # Convert torch tensors to TTNN Multi-Device Host Tensors + ttnn_input_tensor_0 = ttnn.from_torch( + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0) + ) + ttnn_input_tensor_1 = ttnn.from_torch( + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=0) + ) + ttnn_weight = ttnn.from_torch( + torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(pcie_device_mesh) + ) + + # Copy TTNN host tensors into preallocated Mult-Device tensors + logger.info("Send Inputs to Device") + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_0, input_0_dev) + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_1, input_1_dev) + ttnn.copy_host_to_device_tensor(ttnn_weight, weight_dev) + + logger.info("Execute Trace 0") + # Execute trace + ttnn.execute_multi_device_trace(pcie_device_mesh, tid, 0, False) + logger.info("Execute Trace 1") + ttnn.execute_multi_device_trace(pcie_device_mesh, tid_1, 0, False) + if use_all_gather: + # Device All-Gather: Iterate through tensors on all devices. Ensure they match the full tensor + logger.info("Read Back Trace 0 Outputs") + device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(output_tensor) + for device_tensor in device_tensors: + device_tensor_torch = ttnn.to_torch(device_tensor) + assert_with_pcc(device_tensor_torch, torch_output_golden, pcc=0.99) + + logger.info("Read Back Trace 1 Outputs") + device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(output_tensor_1) + for device_tensor in device_tensors: + device_tensor_torch = ttnn.to_torch(device_tensor) + assert_with_pcc(device_tensor_torch, torch_output_golden_1, pcc=0.99) + else: + # Perform host All-Gather + ttnn_torch_output_tensor = ttnn.to_torch( + output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0) + ) + assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.99) + + ttnn_torch_output_tensor = ttnn.to_torch( + output_tensor_1, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0) + ) + assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden_1, pcc=0.99) + + # Release trace buffer once workload is complete + ttnn.release_multi_device_trace(pcie_device_mesh, tid) + ttnn.release_multi_device_trace(pcie_device_mesh, tid_1) + + for device_id in pcie_device_mesh.get_device_ids(): + pcie_device_mesh.get_device(device_id).enable_async(False) diff --git a/tests/ttnn/unit_tests/test_single_device_trace.py b/tests/ttnn/unit_tests/test_single_device_trace.py new file mode 100644 index 00000000000..f55ca54fed0 --- /dev/null +++ b/tests/ttnn/unit_tests/test_single_device_trace.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import typing +import pytest +import ttnn +import tempfile +from loguru import logger +from tests.ttnn.utils_for_testing import assert_with_pcc + + +@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 32, 32)]) +@pytest.mark.parametrize("enable_async", [True, False]) +@pytest.mark.parametrize("blocking", [True, False]) +def test_single_device_single_trace(device, shape, enable_async, blocking): + device.enable_async(enable_async) + device.enable_program_cache() + + # Preallocate activation tensors. These will be used when capturing and executing the trace + input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, device) + input_1_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, device) + + # Op chain to be traced + def run_op_chain(input_0, input_1): + return ttnn.neg(ttnn.add(ttnn.mul(input_1, ttnn.neg(ttnn.gelu(input_0))), ttnn.relu(input_1))) + + # Compile program binaries + run_op_chain(input_0_dev, input_1_dev) + + # Capture Trace + logger.info("Capture Trace") + tid = ttnn.begin_trace_capture(device, 106496, 0) + output_tensor = run_op_chain(input_0_dev, input_1_dev) + ttnn.end_trace_capture(device, tid, 0) + + for i in range(50): + # Create torch inputs + torch_input_tensor_0 = torch.rand(shape, dtype=torch.bfloat16) + torch_input_tensor_1 = torch.rand(shape, dtype=torch.bfloat16) + # Compute PT Golden + torch_output_golden = torch.neg( + torch.add( + torch.mul(torch_input_tensor_1, torch.neg(torch.nn.functional.gelu(torch_input_tensor_0))), + torch.relu(torch_input_tensor_1), + ) + ) + + # Convert torch tensors to TTNN Multi-Device Host Tensors + ttnn_input_tensor_0 = ttnn.from_torch(torch_input_tensor_0, layout=ttnn.TILE_LAYOUT) + ttnn_input_tensor_1 = ttnn.from_torch(torch_input_tensor_1, layout=ttnn.TILE_LAYOUT) + + # Copy TTNN host tensors into preallocated Mult-Device tensors + logger.info("Send Inputs to Device") + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_0, input_0_dev) + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_1, input_1_dev) + + if blocking: + ttnn.synchronize_device(device) + logger.info("Execute Trace") + # Execute trace + ttnn.execute_trace(device, tid, 0, blocking) + # Readback data + logger.info("Read Back Trace Outputs") + ttnn_torch_output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.99) + + ttnn.release_trace(device, tid) + device.enable_async(False) + + +@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 32, 32)]) +@pytest.mark.parametrize("enable_async", [True, False]) +@pytest.mark.parametrize("blocking", [True, False]) +def test_single_device_multi_trace(device, shape, enable_async, blocking): + device.enable_async(enable_async) + device.enable_program_cache() + + # Preallocate activation tensors. 
These will be used when capturing and executing the trace + input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, device) + input_1_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, device) + weight_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, device) + + # Op chains to be traced + def run_op_chain(input_0, input_1, weight): + return ttnn.neg(ttnn.add(ttnn.mul(input_1, ttnn.neg(ttnn.gelu(input_0))), ttnn.relu(input_1))) @ ttnn.silu( + weight + ) + + def run_op_chain_1(input_0, input_1, weight): + return ttnn.tanh(ttnn.mul(ttnn.sub(input_0, input_1), weight)) + + # Compile program binaries + run_op_chain(input_0_dev, input_1_dev, weight_dev) + run_op_chain_1(input_0_dev, input_1_dev, weight_dev) + + # Capture Trace 0 + logger.info("Capture Trace 0") + tid = ttnn.begin_trace_capture(device, 106496, 0) + output_tensor = run_op_chain(input_0_dev, input_1_dev, weight_dev) + ttnn.end_trace_capture(device, tid, 0) + + # Capture Trace 1 + logger.info("Capture Trace 1") + tid_1 = ttnn.begin_trace_capture(device, 26624, 0) + output_tensor_1 = run_op_chain_1(input_0_dev, input_1_dev, weight_dev) + ttnn.end_trace_capture(device, tid_1, 0) + + # Execute and verify trace against pytorch + torch_silu = torch.nn.SiLU() + for i in range(50): + # Create torch inputs + torch_input_tensor_0 = torch.rand(shape, dtype=torch.bfloat16) + torch_input_tensor_1 = torch.rand(shape, dtype=torch.bfloat16) + torch_weight = torch.rand(shape, dtype=torch.bfloat16) + # Compute PT Golden + torch_output_golden = torch.neg( + torch.add( + torch.mul(torch_input_tensor_1, torch.neg(torch.nn.functional.gelu(torch_input_tensor_0))), + torch.relu(torch_input_tensor_1), + ) + ) @ torch_silu(torch_weight) + + torch_output_golden_1 = torch.tanh( + torch.mul(torch.sub(torch_input_tensor_0, torch_input_tensor_1), torch_weight) + ) + + # Convert torch tensors to TTNN Multi-Device Host Tensors + ttnn_input_tensor_0 = ttnn.from_torch(torch_input_tensor_0, layout=ttnn.TILE_LAYOUT) + ttnn_input_tensor_1 = ttnn.from_torch(torch_input_tensor_1, layout=ttnn.TILE_LAYOUT) + ttnn_weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT) + + # Copy TTNN host tensors into preallocated Mult-Device tensors + logger.info("Send Inputs to Device") + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_0, input_0_dev) + ttnn.copy_host_to_device_tensor(ttnn_input_tensor_1, input_1_dev) + ttnn.copy_host_to_device_tensor(ttnn_weight, weight_dev) + + if blocking: + ttnn.synchronize_device(device) + logger.info("Execute Trace 0") + # Execute trace + ttnn.execute_trace(device, tid, 0, blocking) + logger.info("Execute Trace 1") + ttnn.execute_trace(device, tid_1, 0, blocking) + + logger.info("Read Back Trace 0 Outputs") + ttnn_torch_output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.99) + logger.info("Read Back Trace 1 Outputs") + ttnn_torch_output_tensor = ttnn.to_torch(output_tensor_1) + assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden_1, pcc=0.99) + + # Release trace buffer once workload is complete + ttnn.release_trace(device, tid) + ttnn.release_trace(device, tid_1) + + device.enable_async(False) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 555aa91fc21..ce56ef79a68 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -901,13 +901,13 @@ void memcpy(Tensor& dst, const Tensor& src, const std::optional 
tra } } -Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config) { +Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config) { // Top level wrapper to asynchronously create a device tensor (single device) Tensor device_tensor = Tensor({device}); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); device->push_work( [shape, data_type, layout, device, memory_config, device_tensor] () mutable { - auto local_tensor = create_device_tensor(shape, data_type, layout, device, memory_config); + auto local_tensor = create_device_tensor(shape.value(), data_type, layout, device, memory_config); device_tensor.populate_buffers_and_metadata(local_tensor); } ); @@ -915,7 +915,7 @@ Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout return device_tensor; } -Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout layout, DeviceMesh *device_mesh, const MemoryConfig& memory_config) { +Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, Layout layout, DeviceMesh *device_mesh, const MemoryConfig& memory_config) { // Top level wrapper to asynchronously create a device tensor (multi-device) Tensor device_tensor = Tensor(device_mesh->get_devices()); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); @@ -926,7 +926,7 @@ Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout auto& worker = workers[worker_index]; worker->push_work( [shape, data_type, layout, worker, memory_config, device_tensor, worker_index] () mutable { - auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config); + auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); if (not worker->id()) { device_tensor.set_shape(ttnn::Shape(shape)); @@ -971,10 +971,10 @@ void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id) { std::visit([&host_data] (auto&& b) { host_data = b.begin(); }, host_storage.get_buffer()); } EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer(), host_data, false); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { auto host_storage = std::get(async_safe_tensor.get_storage()); std::visit([worker_index, &host_data] (auto&& b) { host_data = b.begin(); }, host_storage.get_buffer(worker_index)); - EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer(worker), host_data, false); + EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer_for_device(worker), host_data, false); } }, device_tensor.get_storage()); } diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index bef6f2e2b62..afad4bab1d9 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -354,8 +354,8 @@ void memcpy(void *dst, const Tensor &src, const std::optional trans void memcpy(Tensor &dst, const void *src, const std::optional transfer_size = std::nullopt); void memcpy(Tensor &dst, const Tensor &src, const std::optional transfer_size = std::nullopt); -Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config = 
{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); -Tensor allocate_tensor_on_device(const Shape& shape, DataType data_type, Layout layout, DeviceMesh *device_mesh, const MemoryConfig& memory_config = {.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); +Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config = {.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); +Tensor allocate_tensor_on_device(const ttnn::Shape& shape, DataType data_type, Layout layout, DeviceMesh *device_mesh, const MemoryConfig& memory_config = {.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id = 0); } // namespace tt_metal diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index 0332081201a..2a6ff9c64c3 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -426,7 +426,7 @@ struct MultiDeviceHostStorage { return shapes[shape_index]; } - uint32_t num_buffers() { + uint32_t num_buffers() const { std::lock_guard lock(mtx); return buffers.size(); } @@ -524,7 +524,7 @@ struct MultiDeviceHostStorage { return shapes.at(device->id()); } - uint32_t num_buffers() { + uint32_t num_buffers() const { std::lock_guard lock(mtx); return buffers.size(); } diff --git a/tt_eager/tt_dnn/op_library/conv/optimized_conv_op.cpp b/tt_eager/tt_dnn/op_library/conv/optimized_conv_op.cpp index 4c1de9d03be..c14aee74ad6 100644 --- a/tt_eager/tt_dnn/op_library/conv/optimized_conv_op.cpp +++ b/tt_eager/tt_dnn/op_library/conv/optimized_conv_op.cpp @@ -72,28 +72,37 @@ Tensor optimized_conv(const Tensor& a, bool transpose_mcast, std::optional compute_kernel_config ) { - //TT_ASSERT(!untilize_out, "Optimized conv only supports tiled out"); - TT_ASSERT(b.get_layout() == Layout::TILE); // Weights should already be formatted - const auto& ashape = input_tensor_shape.has_value() ? Shape(input_tensor_shape.value()) : a.get_legacy_shape(); - auto padded_a_shape = Shape({ashape[0], ashape[1], ashape[2], round_up(ashape[3], 16)}); - FormatParams input_a_format_params = {.pad_shape=padded_a_shape, .pad_value=0.0, .target_layout=Layout::ROW_MAJOR}; - FormatParams input_b_format_params = {.pad_shape=b.get_legacy_shape(), .pad_value=0.0, .target_layout=Layout::TILE}; - FormatParams input_bias_format_params = {}; - if (has_bias) { - input_bias_format_params = {.pad_shape=bias.value().get_legacy_shape(), .pad_value=0, .target_layout=Layout::TILE}; - } - auto output_layout = untilize_out ? Layout::ROW_MAJOR : Layout::TILE; - if (output_mem_config.has_value()) { - TT_ASSERT((output_mem_config.value().is_sharded() || output_mem_config.value().memory_layout == TensorMemoryLayout::INTERLEAVED)); - } - auto arch = a.storage_type() == StorageType::DEVICE ? a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); - bool fp32_accum = a.device()->arch() == ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? 
compute_kernel_config.value().fp32_dest_acc_en : false; - auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false); - return operation::run_without_autoformat( - OptimizedConv(conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, output_mem_config.value_or(a.memory_config()), output_dtype.value_or(a.get_dtype()), ashape, use_shallow_conv_variant, transpose_mcast, kernel_config_val - ), - {a, b}, - {bias, conv_reader_indices}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))}; + operation::launch_op( + [conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, output_mem_config, output_dtype, input_tensor_shape, use_shallow_conv_variant, transpose_mcast, compute_kernel_config] + (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& a = input_tensors.at(0); + auto& b = input_tensors.at(1); + auto& bias = optional_input_tensors.at(0); + //TT_ASSERT(!untilize_out, "Optimized conv only supports tiled out"); + TT_ASSERT(b.get_layout() == Layout::TILE); // Weights should already be formatted + const auto& ashape = input_tensor_shape.has_value() ? Shape(input_tensor_shape.value()) : a.get_legacy_shape(); + auto padded_a_shape = Shape({ashape[0], ashape[1], ashape[2], round_up(ashape[3], 16)}); + FormatParams input_a_format_params = {.pad_shape=padded_a_shape, .pad_value=0.0, .target_layout=Layout::ROW_MAJOR}; + FormatParams input_b_format_params = {.pad_shape=b.get_legacy_shape(), .pad_value=0.0, .target_layout=Layout::TILE}; + FormatParams input_bias_format_params = {}; + if (has_bias) { + input_bias_format_params = {.pad_shape=bias.value().get_legacy_shape(), .pad_value=0, .target_layout=Layout::TILE}; + } + auto output_layout = untilize_out ? Layout::ROW_MAJOR : Layout::TILE; + if (output_mem_config.has_value()) { + TT_ASSERT((output_mem_config.value().is_sharded() || output_mem_config.value().memory_layout == TensorMemoryLayout::INTERLEAVED)); + } + auto arch = a.storage_type() == StorageType::DEVICE ? a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); + bool fp32_accum = a.device()->arch() == ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? 
compute_kernel_config.value().fp32_dest_acc_en : false; + auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false); + return operation::run_without_autoformat( + OptimizedConv(conv_params, output_channels, untilize_out, has_bias, fuse_relu, math_fidelity, parallelization_config, block_config, extra_padding_for_32B_alignment, output_mem_config.value_or(a.memory_config()), output_dtype.value_or(a.get_dtype()), ashape, use_shallow_conv_variant, transpose_mcast, kernel_config_val + ), + input_tensors, + optional_input_tensors); + }, {a, b}, output_tensors, {bias, conv_reader_indices}); + return output_tensors.at(0); } void OptimizedConv::validate(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { diff --git a/tt_eager/tt_dnn/op_library/copy/copy_op.cpp b/tt_eager/tt_dnn/op_library/copy/copy_op.cpp index 03816e5bb8a..afcb50acf2c 100644 --- a/tt_eager/tt_dnn/op_library/copy/copy_op.cpp +++ b/tt_eager/tt_dnn/op_library/copy/copy_op.cpp @@ -82,14 +82,14 @@ tt::stl::reflection::Attributes Copy::attributes() const { } Tensor copy(const Tensor& src_tensor, const Tensor& dst_tensor) { - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({src_tensor}))}; + std::vector dummy_outputs = {Tensor(operation::get_workers_for_op_output({src_tensor}))}; operation::launch_op( - [] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - const auto& src_tensor = input_tensors.at(0); - const auto& dst_tensor = input_tensors.at(1); - operation::run(Copy{dst_tensor.memory_config(), dst_tensor.get_dtype()}, {src_tensor, dst_tensor}); - return {dst_tensor}; - }, {src_tensor, dst_tensor}, output_tensors); + [] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& src_tensor = input_tensors.at(0); + auto& dst_tensor = optional_output_tensors.at(0).value(); + operation::run(Copy{dst_tensor.memory_config(), dst_tensor.get_dtype()}, {src_tensor, dst_tensor}); + return {}; + }, {src_tensor}, dummy_outputs, {}, {dst_tensor}); return dst_tensor; } diff --git a/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp b/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp index 8a8988782ed..01279831d4e 100644 --- a/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp +++ b/tt_eager/tt_dnn/op_library/downsample/downsample_op.cpp @@ -80,7 +80,12 @@ operation::ProgramWithCallbacks Downsample::create_program(const std::vector downsample_params, std::optional output_dtype) { - return operation::run_without_autoformat(Downsample{downsample_params, output_dtype.value_or(input_tensor_a.get_dtype())}, {input_tensor_a}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; + operation::launch_op( + [downsample_params, output_dtype] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + return operation::run_without_autoformat(Downsample{downsample_params, output_dtype.value_or(input_tensors.at(0).get_dtype())}, input_tensors); + }, {input_tensor_a}, output_tensors); + return output_tensors.at(0); } struct DownsampleReadPatternParams { diff --git a/tt_eager/tt_dnn/op_library/move/move_op.hpp b/tt_eager/tt_dnn/op_library/move/move_op.hpp index 
25c53d3617e..22d98435376 100644 --- a/tt_eager/tt_dnn/op_library/move/move_op.hpp +++ b/tt_eager/tt_dnn/op_library/move/move_op.hpp @@ -114,36 +114,41 @@ inline Tensor move(const Tensor& input_tensor, const std::optional } inline Tensor move_sharded(const Tensor& input_tensor, const std::optional& mem_config) { - TT_ASSERT(input_tensor.is_allocated(), "Expected input tensor to be allocated"); - auto input_mem_config = input_tensor.memory_config(); - TT_FATAL(input_mem_config.is_sharded(), "Expected input tensor to be sharded"); - auto input_address = input_tensor.buffer()->address(); - auto output_mem_config = mem_config.value_or(input_mem_config); - TT_FATAL(output_mem_config.is_sharded(), "Expected output tensor memory config to be sharded"); - if (not move_op_utils::can_deallocate(input_tensor)) { - TT_FATAL(false, "Expect input tensor to be deallocated after move op. Cannot deallocate before there is probably another consumer."); - // TODO: Should this throw error? - return input_tensor; - } - auto shard_spec = input_tensor.shard_spec().value(); - auto shard_shape = shard_spec.shape; - auto shard_grid = shard_spec.grid; - auto input_shape = input_tensor.get_legacy_shape(); - auto input_dtype = input_tensor.get_dtype(); - auto input_layout = input_tensor.get_layout(); - - DeallocateBuffer(*input_tensor.buffer()); - // log_debug(LogOp, "OUTPUT SHARD SPEC: {}", out_shard_spec); - auto shard_mem_config = output_mem_config; - shard_mem_config.shard_spec = shard_spec; - auto output_tensor = create_device_tensor(input_shape, input_dtype, input_layout, input_tensor.device(), shard_mem_config); - if (input_tensor.buffer()->address() == output_tensor.buffer()->address()) { - tt::log_debug(tt::LogOp, "WARNING: No space to move the tensor. Move op's input address and output address are equal: {}", input_address); - return output_tensor; - } - MoveOpParallelizationStrategy move_op_parallelization_strategy = MoveOpParallelizationStrategy::MULTI_CORE_SHARDED; - auto output = operation::run(Move{output_mem_config, move_op_parallelization_strategy}, {input_tensor, output_tensor}).at(0); - return output; + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [mem_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + TT_ASSERT(input_tensor.is_allocated(), "Expected input tensor to be allocated"); + auto input_mem_config = input_tensor.memory_config(); + TT_FATAL(input_mem_config.is_sharded(), "Expected input tensor to be sharded"); + auto input_address = input_tensor.buffer()->address(); + auto output_mem_config = mem_config.value_or(input_mem_config); + TT_FATAL(output_mem_config.is_sharded(), "Expected output tensor memory config to be sharded"); + if (not move_op_utils::can_deallocate(input_tensor)) { + TT_FATAL(false, "Expect input tensor to be deallocated after move op. Cannot deallocate before there is probably another consumer."); + // TODO: Should this throw error? 
+ return {input_tensor}; + } + auto shard_spec = input_tensor.shard_spec().value(); + auto shard_shape = shard_spec.shape; + auto shard_grid = shard_spec.grid; + auto input_shape = input_tensor.get_legacy_shape(); + auto input_dtype = input_tensor.get_dtype(); + auto input_layout = input_tensor.get_layout(); + + DeallocateBuffer(*input_tensor.buffer()); + // log_debug(LogOp, "OUTPUT SHARD SPEC: {}", out_shard_spec); + auto shard_mem_config = output_mem_config; + shard_mem_config.shard_spec = shard_spec; + auto output_tensor = create_device_tensor(input_shape, input_dtype, input_layout, input_tensor.device(), shard_mem_config); + if (input_tensor.buffer()->address() == output_tensor.buffer()->address()) { + tt::log_debug(tt::LogOp, "WARNING: No space to move the tensor. Move op's input address and output address are equal: {}", input_address); + return {output_tensor}; + } + MoveOpParallelizationStrategy move_op_parallelization_strategy = MoveOpParallelizationStrategy::MULTI_CORE_SHARDED; + return operation::run(Move{output_mem_config, move_op_parallelization_strategy}, {input_tensor, output_tensor}); + }, {input_tensor}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.cpp b/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.cpp index cfa35d0ccf3..0be4d0d73df 100644 --- a/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.cpp +++ b/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.cpp @@ -485,7 +485,7 @@ std::vector NlpKVCacheLoadSlice::create_output_tensors(const std::vector auto mem_config = tt::tt_metal::MemoryConfig{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1}; mem_config.shard_spec = shard_spec; - return {create_sharded_device_tensor( + return {create_device_tensor( this->compute_output_shapes(input_tensors).at(0), input_tensor_a.get_dtype(), input_tensor_a.get_layout(), diff --git a/tt_eager/tt_dnn/op_library/pool/max_pool.cpp b/tt_eager/tt_dnn/op_library/pool/max_pool.cpp index cb8ed3b1a20..241cfe13fe4 100644 --- a/tt_eager/tt_dnn/op_library/pool/max_pool.cpp +++ b/tt_eager/tt_dnn/op_library/pool/max_pool.cpp @@ -199,22 +199,28 @@ Tensor max_pool2d_v2(const Tensor &input, const MemoryConfig& out_mem_config, uint32_t nblocks, bool use_multicore) { - TT_ASSERT(dilation_h == 1 && dilation_w == 1 && "Dilation not yet supported in max_pool2d."); - TT_ASSERT(pad_h < 2 && pad_w < 2 && "Padding > 1 not yet supported."); - TT_ASSERT(stride_h == stride_w && "Stride should be equal for both H and W for now."); - // calculate the H and W dims for output - uint32_t out_h = ((in_h + 2 * pad_h - (dilation_h * kernel_size_h - 1) - 1) / stride_h) + 1; // floor - uint32_t out_w = ((in_w + 2 * pad_w - (dilation_w * kernel_size_w - 1) - 1) / stride_w) + 1; // floor - return operation::run_without_autoformat(MaxPool{in_n, in_h, in_w, - out_h, out_w, - kernel_size_h, kernel_size_w, - stride_h, stride_w, - pad_h, pad_w, - dilation_h, dilation_w, - out_mem_config, - nblocks, - use_multicore}, - {input, reader_indices}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input, reader_indices}))}; + operation::launch_op( + [in_n, in_h, in_w, kernel_size_h, kernel_size_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, out_mem_config, nblocks, use_multicore] + (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + TT_ASSERT(dilation_h == 1 && dilation_w == 1 && "Dilation not yet supported in max_pool2d."); + 
TT_ASSERT(pad_h < 2 && pad_w < 2 && "Padding > 1 not yet supported."); + TT_ASSERT(stride_h == stride_w && "Stride should be equal for both H and W for now."); + // calculate the H and W dims for output + uint32_t out_h = ((in_h + 2 * pad_h - (dilation_h * kernel_size_h - 1) - 1) / stride_h) + 1; // floor + uint32_t out_w = ((in_w + 2 * pad_w - (dilation_w * kernel_size_w - 1) - 1) / stride_w) + 1; // floor + return operation::run_without_autoformat(MaxPool{in_n, in_h, in_w, + out_h, out_w, + kernel_size_h, kernel_size_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + out_mem_config, + nblocks, + use_multicore}, + input_tensors); + }, {input, reader_indices}, output_tensors); + return output_tensors.at(0); } operation::OpPerformanceModel MaxPool::create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors, const std::vector &output_tensors) const { diff --git a/tt_eager/tt_dnn/op_library/tilize/tilize_op.cpp b/tt_eager/tt_dnn/op_library/tilize/tilize_op.cpp index 72f2e039ae2..6b2280136f6 100644 --- a/tt_eager/tt_dnn/op_library/tilize/tilize_op.cpp +++ b/tt_eager/tt_dnn/op_library/tilize/tilize_op.cpp @@ -197,30 +197,28 @@ Tensor tilize_with_val_padding( bool use_multicore) { // No-op (Will do a tensor copy) // TODO: We need to run asserts before this - if (input_tensor_a.get_layout() == Layout::TILE) { - if (output_tensor_shape == input_tensor_a.get_legacy_shape()) { - log_warning("Perf warning: tilize with padding called on already tilized tensor of target shape."); - return input_tensor_a; - } else { - TT_FATAL(false, "Cannot tilize and pad tensor that is already tilized"); - } - } - if (is_multi_device_tensor(input_tensor_a)) { - return transform(input_tensor_a, [&](const Tensor& tensor) { - return tilize_with_val_padding( - tensor, output_tensor_shape, pad_value, output_mem_config, output_dtype, use_multicore); - }); - } - - return operation::run_without_autoformat( + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; + operation::launch_op( + [output_tensor_shape, pad_value, output_mem_config, output_dtype, use_multicore] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor_a = input_tensors.at(0); + if (input_tensor_a.get_layout() == Layout::TILE) { + if (output_tensor_shape == input_tensor_a.get_legacy_shape()) { + log_warning("Perf warning: tilize with padding called on already tilized tensor of target shape."); + return {input_tensor_a}; + } else { + TT_FATAL(false, "Cannot tilize and pad tensor that is already tilized"); + } + } + return operation::run_without_autoformat( TilizeWithValPadding{ output_tensor_shape, pad_value, output_mem_config, output_dtype.value_or(input_tensor_a.get_dtype()), use_multicore}, - {input_tensor_a}) - .at(0); + {input_tensor_a}); + }, {input_tensor_a}, output_tensors); + return output_tensors.at(0); } Tensor tilize_with_zero_padding( diff --git a/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp b/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp index 54ff4528065..7ca06831cc5 100644 --- a/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp @@ -230,22 +230,28 @@ UntilizeWithUnpaddingOpParallelizationStrategy UntilizeWithUnpadding::get_parall Tensor untilize_with_unpadding(const Tensor &input_tensor_a, const Shape 
&output_tensor_start, const Shape &output_tensor_end, const MemoryConfig& output_mem_config, bool use_pack_untilize) { // No-op (Will do a tensor copy) // TODO: We need to run asserts before this - const Shape output_tensor_shape = { - output_tensor_end[0] - output_tensor_start[0] + 1, - output_tensor_end[1] - output_tensor_start[1] + 1, - output_tensor_end[2] - output_tensor_start[2] + 1, - output_tensor_end[3] - output_tensor_start[3] + 1, - }; - if (input_tensor_a.get_layout() != Layout::TILE) { - if (input_tensor_a.get_legacy_shape() == output_tensor_shape) { - log_warning("Perf warning: Untilize with unpadding called on already untilized tensor of target shape"); - return AutoFormat::move_tensor_to_mem_config(input_tensor_a, output_mem_config); - } else { - TT_FATAL(false, "Cannot untilize and unpad input which is not tilized"); - } - } - bool fp32_dest_acc_en = input_tensor_a.get_dtype() == DataType::UINT32; // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b - return operation::run_without_autoformat(UntilizeWithUnpadding{output_tensor_start, output_tensor_end, output_mem_config, use_pack_untilize, fp32_dest_acc_en}, {input_tensor_a}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; + operation::launch_op( + [output_tensor_start, output_tensor_end, output_mem_config, use_pack_untilize] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor_a = input_tensors.at(0); + const Shape output_tensor_shape = { + output_tensor_end[0] - output_tensor_start[0] + 1, + output_tensor_end[1] - output_tensor_start[1] + 1, + output_tensor_end[2] - output_tensor_start[2] + 1, + output_tensor_end[3] - output_tensor_start[3] + 1, + }; + if (input_tensor_a.get_layout() != Layout::TILE) { + if (input_tensor_a.get_legacy_shape() == output_tensor_shape) { + log_warning("Perf warning: Untilize with unpadding called on already untilized tensor of target shape"); + return {AutoFormat::move_tensor_to_mem_config(input_tensor_a, output_mem_config)}; + } else { + TT_FATAL(false, "Cannot untilize and unpad input which is not tilized"); + } + } + bool fp32_dest_acc_en = input_tensor_a.get_dtype() == DataType::UINT32; // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b + return operation::run_without_autoformat(UntilizeWithUnpadding{output_tensor_start, output_tensor_end, output_mem_config, use_pack_untilize, fp32_dest_acc_en}, {input_tensor_a}); + }, {input_tensor_a}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp b/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp index b151ef32261..a81d5ec19c7 100644 --- a/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/untilize_with_halo_op_v2.cpp @@ -348,20 +348,28 @@ Tensor untilize_with_halo_v2( const MemoryConfig& mem_config, const bool remote_read, const bool transpose_mcast) { - TT_ASSERT(input_tensor.memory_config().is_sharded()); - TT_ASSERT(input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); - // NOTE: for HEIGHT_SHARDED, ncores_nhw == ncores - // for BLOCK_SHARDED, ncores_nhw is just the ncores along height dim (last tensor dim is split along width) - - 
return operation::run_without_autoformat( - UntilizeWithHaloV2{pad_val, ncores_nhw, max_out_nsticks_per_core, mem_config, remote_read, transpose_mcast}, - { - input_tensor, - padding_config, - local_config, - remote_config, - }) - .at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor, padding_config, local_config, remote_config}))}; + operation::launch_op( + [pad_val, ncores_nhw, max_out_nsticks_per_core, mem_config, remote_read, transpose_mcast] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + auto& padding_config = input_tensors.at(1); + auto& local_config = input_tensors.at(2); + auto& remote_config = input_tensors.at(3); + TT_ASSERT(input_tensor.memory_config().is_sharded()); + TT_ASSERT(input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED); + // NOTE: for HEIGHT_SHARDED, ncores_nhw == ncores + // for BLOCK_SHARDED, ncores_nhw is just the ncores along height dim (last tensor dim is split along width) + + return operation::run_without_autoformat( + UntilizeWithHaloV2{pad_val, ncores_nhw, max_out_nsticks_per_core, mem_config, remote_read, transpose_mcast}, + { + input_tensor, + padding_config, + local_config, + remote_config, + }); + }, {input_tensor, padding_config, local_config, remote_config}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp index e555c1adaf6..778f9b1bbae 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp @@ -12,6 +12,7 @@ #include "tt_metal/detail/reports/compilation_reporter.hpp" #include "tt_metal/detail/reports/memory_reporter.hpp" #include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/impl/trace/trace.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "type_caster.hpp" @@ -204,17 +205,44 @@ void DeviceModule(py::module &m_device) { m_device.def("DeallocateBuffers", &detail::DeallocateBuffers, R"doc( Deallocate all buffers associated with Device handle )doc"); - m_device.def("BeginTraceCapture", &BeginTraceCapture, R"doc( + m_device.def("BeginTraceCapture", + [] (Device* device, const uint8_t cq_id, const uint32_t trace_buff_size) { + uint32_t tid = Trace::next_id(); + device->push_work([device, cq_id, tid, trace_buff_size] () mutable { + device->begin_trace(cq_id, tid, trace_buff_size); + }); + return tid; + }, R"doc( Begin trace capture on Device handle )doc"); - m_device.def("EndTraceCapture", &EndTraceCapture, R"doc( + m_device.def("EndTraceCapture", + [] (Device* device, const uint8_t cq_id, const uint32_t tid) { + device->push_work([device, cq_id, tid] () mutable { + device->end_trace(cq_id, tid); + }); + }, R"doc( End trace capture on Device handle )doc"); - m_device.def("ReplayTrace", &ReplayTrace, R"doc( - Replay last captured trace on Device handle + m_device.def("ReplayTrace", + [] (Device* device, const uint8_t cq_id, const uint32_t tid, bool blocking) { + // If blocking, ensure that worker thread blocks until trace is completed + device->push_work([device, cq_id, tid, blocking] { + device->replay_trace(cq_id, tid, blocking); + }); + // If blocking, wait until worker threads have completed + if (blocking) { + device->synchronize(); + } + }, R"doc( + Replay captured 
trace on Device handle )doc"); - m_device.def("ReleaseTrace", &ReleaseTrace, R"doc( - Release last captured Trace on Device handle + m_device.def("ReleaseTrace", + [] (Device* device, const uint32_t tid) { + device->push_work([device, tid] { + device->release_trace(tid); + }); + }, R"doc( + Release captured Trace on Device handle )doc"); m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE); diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp index a7b2ec5b263..c764a866220 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp @@ -772,7 +772,7 @@ void TensorModule(py::module &m_tensor) { m_tensor.def( "allocate_tensor_on_device", - py::overload_cast(&allocate_tensor_on_device), + py::overload_cast(&allocate_tensor_on_device), py::arg("shape"), py::arg("dtype"), py::arg("layout"), py::arg("device"), py::arg("memory_config") = MemoryConfig{.memory_layout=TensorMemoryLayout::INTERLEAVED}, R"doc( Allocate a tensor with specified attributes on a device. @@ -781,7 +781,7 @@ void TensorModule(py::module &m_tensor) { m_tensor.def( "allocate_tensor_on_device", - py::overload_cast(&allocate_tensor_on_device), + py::overload_cast(&allocate_tensor_on_device), py::arg("shape"), py::arg("dtype"), py::arg("layout"), py::arg("device"), py::arg("memory_config") = MemoryConfig{.memory_layout=TensorMemoryLayout::INTERLEAVED}, R"doc( Allocate a tensor with specified attributes on a device. diff --git a/ttnn/cpp/pybind11/device.hpp b/ttnn/cpp/pybind11/device.hpp index 99e1db36730..c7dd3b00ddb 100644 --- a/ttnn/cpp/pybind11/device.hpp +++ b/ttnn/cpp/pybind11/device.hpp @@ -27,6 +27,15 @@ void py_module(py::module& module) { module.def("enable_program_cache", &ttnn::enable_program_cache, py::arg("device"), py::kw_only()); module.def("disable_and_clear_program_cache", &ttnn::disable_and_clear_program_cache, py::arg("device"), py::kw_only()); + + module.def("begin_trace_capture", &ttnn::begin_trace_capture, py::arg("device"), py::arg("trace_buffer_size"), py::arg("cq_id") = 0); + + module.def("end_trace_capture", &ttnn::end_trace_capture, py::arg("device"), py::arg("trace_id"), py::arg("cq_id") = 0); + + module.def("execute_trace", &ttnn::execute_trace, py::arg("device"), py::arg("trace_id"), py::arg("cq_id") = 0, py::arg("blocking") = true); + + module.def("release_trace", &ttnn::release_trace, py::arg("device"), py::arg("trace_id")); + } } // namespace device diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index adc224c7f8c..0f48caacc03 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -36,6 +36,10 @@ void py_module(py::module& module) { py::arg("l1_small_size")); module.def("close_device_mesh", &close_device_mesh, py::arg("device_mesh"), py::kw_only()); + module.def("begin_trace_capture", &begin_trace_capture, py::arg("device_mesh"), py::arg("trace_buffer_size"), py::arg("cq_id") = 0); + module.def("end_trace_capture", &end_trace_capture, py::arg("device_mesh"), py::arg("trace_id"), py::arg("cq_id") = 0); + module.def("execute_trace", &execute_trace, py::arg("device_mesh"), py::arg("trace_id"), py::arg("cq_id") = 0, py::arg("blocking") = true); + module.def("release_trace", &release_trace, py::arg("device_mesh"), py::arg("trace_id")); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), 
py::kw_only()); } diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index d045816bdc7..a8ca044da20 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -130,6 +130,28 @@ Deallocates device tensor and returns a reallocated tensor * :attr:`input_tensor`: Input Tensor )doc"); + module.def( + "allocate_tensor_on_device", + py::overload_cast&>( + &ttnn::operations::core::allocate_tensor_on_device), + py::arg("shape"), + py::arg("dtype"), + py::arg("layout"), + py::arg("device"), + py::arg("memory_config") = std::nullopt); + + module.def( + "allocate_tensor_on_device", + py::overload_cast&>( + &ttnn::operations::core::allocate_tensor_on_device), + py::arg("shape"), + py::arg("dtype"), + py::arg("layout"), + py::arg("device_mesh"), + py::arg("memory_config") = std::nullopt); + + module.def("copy_host_to_device_tensor", &ttnn::operations::core::copy_host_to_device_tensor, py::arg("host_tensor"), py::arg("device_tensor"), py::arg("cq_id") = 0); + bind_registered_operation( module, ttnn::to_layout, diff --git a/ttnn/cpp/ttnn/device.cpp b/ttnn/cpp/ttnn/device.cpp index 9f9a2434860..0e2f0dcd8a0 100644 --- a/ttnn/cpp/ttnn/device.cpp +++ b/ttnn/cpp/ttnn/device.cpp @@ -58,6 +58,45 @@ void close_device(Device &device) { } } +uint32_t begin_trace_capture(Device* device, const uint32_t trace_buff_size, const uint8_t cq_id) { + uint32_t tid = Trace::next_id(); + device->push_work( + [device, trace_buff_size, cq_id, tid] () mutable { + device->begin_trace(cq_id, tid, trace_buff_size); + }); + return tid; +} + +void end_trace_capture(Device* device, const uint32_t tid, const uint8_t cq_id) { + device->push_work( + [device, cq_id, tid] () mutable { + device->end_trace(cq_id, tid); + } + ); +} + +void execute_trace(Device* device, const uint32_t tid, const uint8_t cq_id, bool blocking) { + // If blocking, ensure that worker thread blocks until trace is completed + device->push_work( + [device, cq_id, tid, blocking] () mutable { + device->replay_trace(cq_id, tid, blocking); + } + ); + // If blocking, wait until worker threads have completed + if (blocking) { + device->synchronize(); + } +} + +void release_trace(Device* device, const uint32_t tid) { + device->push_work( + [device, tid] () mutable { + device->release_trace(tid); + } + ); +} + + } // namespace device using namespace device; diff --git a/ttnn/cpp/ttnn/device.hpp b/ttnn/cpp/ttnn/device.hpp index a9d95081aad..80433291e04 100644 --- a/ttnn/cpp/ttnn/device.hpp +++ b/ttnn/cpp/ttnn/device.hpp @@ -5,6 +5,7 @@ #pragma once #include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/impl/trace/trace.hpp" #include "ttnn/types.hpp" #include "ttnn/device_pool.hpp" namespace ttnn { @@ -15,6 +16,10 @@ Device &open_device(int device_id, size_t l1_small_size = DEFAULT_L1_SMALL_SIZE) void close_device(Device &device); void enable_program_cache(Device &device); void disable_and_clear_program_cache(Device &device); +uint32_t begin_trace_capture(Device* device, const uint32_t trace_buff_size, const uint8_t cq_id = 0); +void end_trace_capture(Device *device, const uint32_t tid, const uint8_t cq_id = 0); +void execute_trace(Device *device, const uint32_t tid, const uint8_t cq_id = 0, bool blocking = true); +void release_trace(Device* device, const uint32_t tid); } // namespace device diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index 9222bdce6a4..12366afb65c 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -8,7 +8,7 @@ 
#include "tt_eager/tensor/tensor.hpp" #include "tt_metal/impl/device/multi_device.hpp" - +#include "tt_metal/impl/trace/trace.hpp" using Device = ttnn::Device; @@ -28,6 +28,55 @@ inline void close_device_mesh(DeviceMesh &multi_device) { multi_device.close_devices(); } +inline uint32_t begin_trace_capture(DeviceMesh* device, const uint32_t trace_buff_size, const uint8_t cq_id = 0) { + auto workers = device->get_devices(); + uint32_t tid = Trace::next_id(); + for (auto& worker : workers) { + worker->push_work( + [worker, trace_buff_size, cq_id, tid] () mutable { + worker->begin_trace(cq_id, tid, trace_buff_size); + }); + } + return tid; +} + +void end_trace_capture(DeviceMesh* device, const uint32_t tid, const uint8_t cq_id = 0) { + auto workers = device->get_devices(); + for (auto& worker : workers) { + worker->push_work( + [worker, cq_id, tid] () mutable { + worker->end_trace(cq_id, tid); + }); + } +} + +inline void execute_trace(DeviceMesh* device, const uint32_t tid, const uint8_t cq_id = 0, bool blocking = true) { + auto workers = device->get_devices(); + // If blocking, ensure that each worker thread blocks until device-local trace is completed + for (auto& worker : workers) { + worker->push_work( + [worker, cq_id, tid, blocking] () mutable { + worker->replay_trace(cq_id, tid, blocking); + }); + } + // If blocking, wait until worker threads have completed + if (blocking) { + for (auto& worker : workers) { + worker->synchronize(); + } + } +} + +inline void release_trace(DeviceMesh* device, const uint32_t tid) { + auto workers = device->get_devices(); + for (auto& worker : workers) { + worker->push_work( + [worker, tid] () mutable { + worker->release_trace(tid); + }); + } +} + std::vector get_device_tensors(const ttnn::Tensor& tensor) { if (std::holds_alternative(tensor.get_storage())) { std::vector tensors; diff --git a/ttnn/cpp/ttnn/operations/core.hpp b/ttnn/cpp/ttnn/operations/core.hpp index 911dd40b3b0..6e780524fc4 100644 --- a/ttnn/cpp/ttnn/operations/core.hpp +++ b/ttnn/cpp/ttnn/operations/core.hpp @@ -165,6 +165,20 @@ inline ttnn::Tensor to_device( return tensor.to(device_mesh, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } +inline ttnn::Tensor allocate_tensor_on_device( + const Shape& shape, DataType data_type, Layout layout, Device *device, const std::optional& memory_config) { + return tt::tt_metal::allocate_tensor_on_device(shape, data_type, layout, device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); +} + +inline ttnn::Tensor allocate_tensor_on_device( + const Shape& shape, DataType data_type, Layout layout, DeviceMesh *device_mesh, const std::optional& memory_config) { + return tt::tt_metal::allocate_tensor_on_device(shape, data_type, layout, device_mesh, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); +} + +inline void copy_host_to_device_tensor(ttnn::Tensor host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id = 0) { + tt::tt_metal::write_tensor(host_tensor, device_tensor, cq_id); +} + inline ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking = true) { return tensor.cpu(blocking); } inline void deallocate(Tensor& tensor, bool force = true) { tensor.deallocate(force); } diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 47ec5fc67cf..34abee07f98 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -197,6 +197,10 @@ def manage_config(name, value): manage_device, synchronize_device, dump_device_memory_state, + begin_trace_capture, + end_trace_capture, + execute_trace, + release_trace, ) from ttnn.multi_device import ( 
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 47ec5fc67cf..34abee07f98 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -197,6 +197,10 @@ def manage_config(name, value):
     manage_device,
     synchronize_device,
     dump_device_memory_state,
+    begin_trace_capture,
+    end_trace_capture,
+    execute_trace,
+    release_trace,
 )
 
 from ttnn.multi_device import (
@@ -214,6 +218,10 @@ def manage_config(name, value):
     MeshToTensor,
     ConcatMeshToTensor,
     ListMeshToTensor,
+    begin_multi_device_trace_capture,
+    end_multi_device_trace_capture,
+    execute_multi_device_trace,
+    release_multi_device_trace,
 )
 
 from ttnn.core import (
@@ -260,6 +268,8 @@ def manage_config(name, value):
     squeeze,
     clone,
     as_tensor,
+    allocate_tensor_on_device,
+    copy_host_to_device_tensor,
 )
 
 from ttnn.operations.matmul import (
diff --git a/ttnn/ttnn/device.py b/ttnn/ttnn/device.py
index c926936dc75..b84f7b3e068 100644
--- a/ttnn/ttnn/device.py
+++ b/ttnn/ttnn/device.py
@@ -54,6 +54,22 @@ def synchronize_device(device):
     ttl.device.Synchronize(device)
 
 
+def begin_trace_capture(device, trace_buff_size, cq_id=0):
+    return ttnn._ttnn.device.begin_trace_capture(device, trace_buff_size, cq_id)
+
+
+def end_trace_capture(device, trace_id, cq_id=0):
+    ttnn._ttnn.device.end_trace_capture(device, trace_id, cq_id)
+
+
+def execute_trace(device, trace_id, cq_id=0, blocking=True):
+    ttnn._ttnn.device.execute_trace(device, trace_id, cq_id, blocking)
+
+
+def release_trace(device, trace_id):
+    ttnn._ttnn.device.release_trace(device, trace_id)
+
+
 @contextlib.contextmanager
 def manage_device(device_id: int):
     """
diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py
index aba7f33e8ad..cc2934eeb13 100644
--- a/ttnn/ttnn/multi_device.py
+++ b/ttnn/ttnn/multi_device.py
@@ -57,6 +57,22 @@ def close_device_mesh(device_mesh):
     return ttnn._ttnn.multi_device.close_device_mesh(device_mesh)
 
 
+def begin_multi_device_trace_capture(device_mesh, trace_buffer_size, cq_id=0):
+    return ttnn._ttnn.multi_device.begin_trace_capture(device_mesh, trace_buffer_size, cq_id)
+
+
+def end_multi_device_trace_capture(device_mesh, trace_id, cq_id=0):
+    ttnn._ttnn.multi_device.end_trace_capture(device_mesh, trace_id, cq_id)
+
+
+def execute_multi_device_trace(device_mesh, trace_id, cq_id=0, blocking=True):
+    ttnn._ttnn.multi_device.execute_trace(device_mesh, trace_id, cq_id, blocking)
+
+
+def release_multi_device_trace(device_mesh, trace_id):
+    ttnn._ttnn.multi_device.release_trace(device_mesh, trace_id)
+
+
 @contextlib.contextmanager
 def create_device_mesh(
     device_grid: ttnn.DeviceGrid, device_ids: List[int], l1_small_size: int = ttl.device.DEFAULT_L1_SMALL_SIZE
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index aea2c429fff..8ef9fc5ca6f 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -454,6 +454,13 @@ def _golden_function(tensor, *args, **kwargs):
     doc=doc,
 )(ttnn._ttnn.operations.core.from_device)
 
+allocate_tensor_on_device = ttnn.register_operation(
+    name="ttnn.allocate_tensor_on_device",
+)(ttnn._ttnn.operations.core.allocate_tensor_on_device)
+
+copy_host_to_device_tensor = ttnn.register_operation(
+    name="ttnn.copy_host_to_device_tensor",
+)(ttnn._ttnn.operations.core.copy_host_to_device_tensor)
 
 doc = """
 deallocate(tensor: ttnn.Tensor, force: bool = True) -> None