diff --git a/models/demos/resnet/tests/test_metal_resnet50.py b/models/demos/resnet/tests/test_metal_resnet50.py index b24297caab8..ad332a641c2 100644 --- a/models/demos/resnet/tests/test_metal_resnet50.py +++ b/models/demos/resnet/tests/test_metal_resnet50.py @@ -8,7 +8,7 @@ import pytest import tt_lib -from models.utility_functions import is_e75, skip_for_wormhole_b0 +from models.utility_functions import is_e75, skip_for_wormhole_b0, divup from models.demos.resnet.tt.metalResnetBlock50 import ResNet, Bottleneck from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( @@ -117,26 +117,107 @@ } -@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) -@pytest.mark.parametrize("batch_size", [1, 2, 16, 20], ids=["batch_1", "batch_2", "batch_16", "batch_20"]) -@pytest.mark.parametrize( - "weights_dtype", - [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B], - ids=["weights_BFLOAT16", "weights_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "activations_dtype", - [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B], - ids=["activations_BFLOAT16", "activations_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "math_fidelity", - [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], - ids=["HiFi4", "HiFi2", "LoFi"], -) -def test_run_resnet50_inference( - device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +def run_model(device, tt_image, tt_resnet50): + tt_output = tt_resnet50(tt_image) + return tt_output.cpu(blocking=True) + + +def run_2cq_model(device, tt_image, tt_resnet50): + input_shape = tt_image.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_image.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_image.shape, tt_image.dtype, tt_image.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) + + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_image, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + + # Test overlapping write + outputs = [] + for iter in range(0, 2): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_image, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + return outputs[1] + + +def run_trace_model(device, tt_image, tt_resnet50): + input_shape = tt_image.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_image.volume() // input_shape[3], 8), + 
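# shard height: rows (volume / last dim) split across the 8 cores of the range above; divup rounds up so every row lands in a shard
+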
input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_image.shape, tt_image.dtype, tt_image.layout, device, sharded_mem_config_DRAM + ) + tt_lib.tensor.write_tensor(tt_image, tt_image_res) + + # Compile + tt_resnet50(tt_image_res) + # Trace + tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000) + tt_output_res = tt_resnet50(tt_image_res) + tt_lib.device.EndTraceCapture(device, 0, tid) + + tt_lib.tensor.write_tensor(tt_image, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, True) + + # Done with the trace, can deallocate the buffers now. + tt_lib.device.ReleaseTrace(device, tid) + + return tt_output_res.cpu(blocking=True) + + +def run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_fn, ): if is_e75(device): pytest.skip("Resnet50 is not supported on E75") @@ -159,8 +240,6 @@ def test_run_resnet50_inference( with torch.no_grad(): torch.manual_seed(1234) - tt_lib.device.EnableMemoryReports() - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) torch_resnet50.eval() @@ -185,17 +264,8 @@ def test_run_resnet50_inference( torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1) tt_image = tt_resnet50.preprocessing(image) - tt_output = tt_resnet50(tt_image) - tt_output = tt_output.cpu().to_torch().to(torch.float) - - # # run again to measure end to end perf - # start_time = datetime.now() - # tt_output = tt_resnet50(image) - # end_time = datetime.now() - # diff = end_time - start_time - # logger.info("End to end time (microseconds))", diff.microseconds) - # throughput_fps = (float) (1000000 / diff.microseconds) - # logger.info("Throughput (fps)", throughput_fps) + tt_output = run_fn(device, tt_image, tt_resnet50) + tt_output = tt_output.to_torch().to(torch.float) _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output) logger.info(info) @@ -239,6 +309,72 @@ def test_run_resnet50_inference( [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], ids=["HiFi4", "HiFi2", "LoFi"], ) +def test_run_resnet50_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + 
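# run_2cq_model below writes inputs on CQ 1 while CQ 0 runs ops, synchronized via events
+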
imagenet_sample_input, + run_2cq_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) @pytest.mark.parametrize("enable_async", [True, False]) def test_run_resnet50_trace_inference( device, @@ -250,101 +386,17 @@ def test_run_resnet50_trace_inference( imagenet_sample_input, enable_async, ): - if is_e75(device): - pytest.skip("Resnet50 is not supported on E75") device.enable_async(enable_async) - if batch_size > 8 and ( - activations_dtype != tt_lib.tensor.DataType.BFLOAT8_B or weights_dtype != tt_lib.tensor.DataType.BFLOAT8_B - ): - pytest.skip("Batch > 8 must be run fully bfp8") - if batch_size <= 2: - pytest.skip("batch 1 and 2 are not supported with sharded data") - image1 = imagenet_sample_input - image = image1 - model_config = { - "MATH_FIDELITY": math_fidelity, - "WEIGHTS_DTYPE": weights_dtype, - "ACTIVATIONS_DTYPE": activations_dtype, - } - for i in range(batch_size - 1): - image = torch.cat((image, image1), dim=0) - with torch.no_grad(): - torch.manual_seed(1234) - - tt_lib.device.EnableMemoryReports() - - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) - torch_resnet50.eval() - - state_dict = torch_resnet50.state_dict() - storage_in_dram = False - sharded = False - if batch_size >= 8: - sharded = True - # run once to compile ops - tt_resnet50 = ResNet( - Bottleneck, - [3, 4, 6, 3], - device=device, - state_dict=state_dict, - base_address="", - fold_batchnorm=True, - storage_in_dram=storage_in_dram, - batch_size=batch_size, - model_config=model_config, - sharded=sharded, - ) - - torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1) - interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=tt_lib.tensor.BufferType.DRAM, - ) - - tt_image_res = tt_resnet50.preprocessing(image).to(device, interleaved_mem_config_DRAM) - # Compile - tt_resnet50(tt_image_res) - # Trace - tid = tt_lib.device.BeginTraceCapture(device, 0, 1334880) - tt_output_res = tt_resnet50(tt_image_res) - tt_lib.device.EndTraceCapture(device, 0, tid) + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_model, + ) - tt_lib.device.ReplayTrace(device, 0, tid, True) - - tt_output = tt_output_res.cpu().to_torch().to(torch.float) - - # # run again to measure end to end perf - # start_time = datetime.now() - # tt_output = tt_resnet50(image) - # end_time = datetime.now() - # diff = end_time - start_time - # logger.info("End to end time (microseconds))", diff.microseconds) - # throughput_fps = (float) (1000000 / diff.microseconds) - # logger.info("Throughput (fps)", throughput_fps) - - _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output) - logger.info(info) - - valid_pcc = 1.0 - if batch_size >= 8: - valid_pcc = golden_pcc[batch_size][ - (model_config["MATH_FIDELITY"], model_config["WEIGHTS_DTYPE"], model_config["ACTIVATIONS_DTYPE"]) - ] - else: - 
if model_config["ACTIVATIONS_DTYPE"] == tt_lib.tensor.DataType.BFLOAT8_B: - if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi: - valid_pcc = 0.87 - else: - valid_pcc = 0.94 - else: - if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi: - valid_pcc = 0.93 - else: - valid_pcc = 0.982 - passing_pcc, _ = comp_pcc(torch_output, tt_output, pcc=valid_pcc) - assert passing_pcc - # assert passing # fails because of torch.allclose - # Done with the trace, can deallocate the buffers now. - tt_lib.device.ReleaseTrace(device, tid) device.enable_async(False) diff --git a/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py new file mode 100644 index 00000000000..6bb3147c6d3 --- /dev/null +++ b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tt_lib + +from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_2cq_model +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_2cq_model, + ) diff --git a/models/demos/resnet/tests/test_metal_resnet50_performant.py b/models/demos/resnet/tests/test_metal_resnet50_performant.py new file mode 100644 index 00000000000..cbd266c568c --- /dev/null +++ b/models/demos/resnet/tests/test_metal_resnet50_performant.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
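+
+# Performant correctness variants pinned to batch 20 / BFLOAT8_B / LoFi.
+# run_trace_model (imported below) captures the compiled model once and replays it; the
+# call sequence, roughly:
+#   tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000)
+#   tt_output_res = tt_resnet50(tt_image_res)
+#   tt_lib.device.EndTraceCapture(device, 0, tid)
+#   tt_lib.tensor.write_tensor(tt_image, tt_image_res)
+#   tt_lib.device.ReplayTrace(device, 0, tid, True)
+#   tt_lib.device.ReleaseTrace(device, tid)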
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tt_lib + +from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_model, run_trace_model +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +@pytest.mark.parametrize("enable_async", [True, False]) +def test_run_resnet50_trace_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + enable_async, +): + device.enable_async(enable_async) + + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_model, + ) + + device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_accuracy_resnet.py b/models/demos/resnet/tests/test_perf_accuracy_resnet.py index 722000caea5..6c719ebbf5b 100644 --- a/models/demos/resnet/tests/test_perf_accuracy_resnet.py +++ b/models/demos/resnet/tests/test_perf_accuracy_resnet.py @@ -84,6 +84,7 @@ def run_perf_resnet( tt_output = tt_output.cpu().to_torch().to(torch.float) profiler.end(first_key) del tt_output + return enable_persistent_kernel_cache() diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index f7bc7368ed2..a93c82876c9 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ b/models/demos/resnet/tests/test_perf_resnet.py @@ -9,9 +9,7 @@ import pytest import tt_lib -from models.utility_functions import is_e75 -from models.utility_functions import profiler -from models.utility_functions import disable_persistent_kernel_cache, skip_for_wormhole_b0 +from models.utility_functions import is_e75, profiler, divup, disable_persistent_kernel_cache, skip_for_wormhole_b0 from models.perf.perf_utils import prep_perf_report from loguru import logger @@ -24,13 +22,145 @@ } +def run_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + profiler.start("compile") + _ = 
tt_resnet50(tt_inputs).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + _ = tt_resnet50(tt_inputs).cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + +def run_2cq_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) + + profiler.start("compile") + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + +def run_trace_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + # Compile + profiler.start("compile") + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_resnet50(tt_image_res).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + # Capture + tid = 
tt_lib.device.BeginTraceCapture(device, 0, 1500000) + tt_output_res = tt_resnet50(tt_image_res) + tt_lib.device.EndTraceCapture(device, 0, tid) + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, False) + _ = tt_output_res.cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, False) + outputs.append(tt_output_res.cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + def run_perf_resnet( batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, + model_version, ): + if is_e75(device): + pytest.skip("Resnet is not supported on E75") disable_persistent_kernel_cache() if batch_size <= 2: pytest.skip("Batch size 1 and 2 are not supported with sharded data") @@ -69,6 +199,10 @@ def run_perf_resnet( model_config=model_config, sharded=sharded, ) + tt_lib.device.Synchronize(device) + + num_warmup_iterations = 5 + num_measurement_iterations = 15 with torch.no_grad(): profiler.start(cpu_key) @@ -76,36 +210,24 @@ def run_perf_resnet( profiler.end(cpu_key) tt_inputs = tt_resnet50.preprocessing(inputs) - warmup_end = 5 - for iter in range(0, warmup_end): - profiler.start(f"{iter}_key") - _ = tt_resnet50(tt_inputs).cpu(blocking=True) - profiler.end(f"{iter}_key") - tt_lib.device.DumpDeviceProfiler(device) - - num_warm_iterations = 15 - warm_start = warmup_end - warm_end = warm_start + num_warm_iterations - - outputs = [] - profiler.start(f"run") - for iter in range(warm_start, warm_end): - outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False)) - tt_lib.device.Synchronize(device) - profiler.end(f"run") - tt_lib.device.DumpDeviceProfiler(device) + if "resnet50_2cqs" in model_version: + run_2cq_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + elif "resnet50_trace" in model_version: + run_trace_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + elif "resnet50" in model_version: + run_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + else: + assert False, f"Model version to run {model_version} not found" - # enable_persistent_kernel_cache() - - first_iter_time = profiler.get(f"{0}_key") + first_iter_time = profiler.get(f"compile") # ensuring inference time fluctuations is not noise - inference_time_avg = profiler.get("run") / num_warm_iterations + inference_time_avg = profiler.get("run") / num_measurement_iterations cpu_time = profiler.get(cpu_key) compile_time = first_iter_time - inference_time_avg prep_perf_report( - model_name=f"resnet50_batch_size{batch_size}", + model_name=f"{model_version}_batch_size{batch_size}", batch_size=batch_size, inference_and_compile_time=first_iter_time, inference_time=inference_time_avg, @@ -115,8 +237,8 @@ def run_perf_resnet( inference_time_cpu=cpu_time, ) - logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}") - logger.info(f"resnet50 compile time: {compile_time}") + logger.info(f"{model_name} {comments} inference time (avg): {inference_time_avg}") + logger.info(f"{model_name} compile time: {compile_time}") @skip_for_wormhole_b0(reason_str="Not tested on single WH") @@ -125,10 +247,8 
@@ def run_perf_resnet( @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", ( - (1, 0.001, 1), - (2, 0.001, 1), - (16, 0.007, 7), - (20, 0.007, 7), + (16, 0.007, 16), + (20, 0.007, 16), ), ) def test_perf_bare_metal( @@ -143,145 +263,16 @@ def test_perf_bare_metal( pytest.skip("Resnet is not supported on E75") run_perf_resnet( - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, - device, - ) - - -def run_perf_resnet_trace( - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, - device, -): - disable_persistent_kernel_cache() - if batch_size <= 2: - pytest.skip("Batch size 1 and 2 are not supported with sharded data") - first_key = f"first_iter_batchsize{batch_size}" - second_key = f"second_iter_batchsize{batch_size}" - cpu_key = f"ref_key_batchsize{batch_size}" - model_name = "microsoft/resnet-50" - - image = hf_cat_image_sample_input - image_processor = AutoImageProcessor.from_pretrained(model_name) - inputs = image_processor(image, return_tensors="pt") - - inputs = inputs["pixel_values"] - comments = f"{list(inputs.shape)[-2]}x{list(inputs.shape)[-1]}_batchsize{batch_size}" - - inputs1 = inputs - for i in range(batch_size - 1): - inputs = torch.cat((inputs, inputs1), dim=0) - - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) - torch_resnet50.eval() - - state_dict = torch_resnet50.state_dict() - sharded = False - if batch_size >= 8: - sharded = True - tt_resnet50 = ResNet( - Bottleneck, - [3, 4, 6, 3], - device=device, - state_dict=state_dict, - base_address="", - fold_batchnorm=True, - storage_in_dram=False, - batch_size=batch_size, - model_config=model_config, - sharded=sharded, + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50" ) - with torch.no_grad(): - profiler.start(cpu_key) - logits = torch_resnet50(inputs) - profiler.end(cpu_key) - - tt_inputs = tt_resnet50.preprocessing(inputs) - interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=tt_lib.tensor.BufferType.DRAM, - ) - tt_image_res = tt_inputs.to(device, interleaved_mem_config_DRAM) - # Compile - profiler.start(f"{0}_key") - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_resnet50(tt_image_res).cpu(blocking=True) - profiler.end(f"{0}_key") - tt_lib.device.DumpDeviceProfiler(device) - - # Capture - tid = tt_lib.device.BeginTraceCapture(device, 0, 1334880) - tt_output_res = tt_resnet50(tt_image_res) - tt_lib.device.EndTraceCapture(device, 0, tid) - tt_lib.device.DumpDeviceProfiler(device) - - warmup_end = 6 - for iter in range(1, warmup_end): - profiler.start(f"{iter}_key") - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_lib.device.ReplayTrace(device, 0, tid, False) - _ = tt_output_res.cpu(blocking=True) - profiler.end(f"{iter}_key") - tt_lib.device.DumpDeviceProfiler(device) - - num_warm_iterations = 15 - warm_start = warmup_end - warm_end = warm_start + num_warm_iterations - - outputs = [] - profiler.start(f"run") - for iter in range(warm_start, warm_end): - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_lib.device.ReplayTrace(device, 0, tid, False) - outputs.append(tt_output_res.cpu(blocking=False)) - tt_lib.device.Synchronize(device) - profiler.end(f"run") - tt_lib.device.DumpDeviceProfiler(device) - - # enable_persistent_kernel_cache() - - first_iter_time = profiler.get(f"{0}_key") - - # ensuring inference time 
fluctuations is not noise - inference_time_avg = profiler.get("run") / num_warm_iterations - - cpu_time = profiler.get(cpu_key) - compile_time = first_iter_time - inference_time_avg - prep_perf_report( - model_name=f"resnet50_trace_batch_size{batch_size}", - batch_size=batch_size, - inference_and_compile_time=first_iter_time, - inference_time=inference_time_avg, - expected_compile_time=expected_compile_time, - expected_inference_time=expected_inference_time, - comments=comments, - inference_time_cpu=cpu_time, - ) - - logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}") - logger.info(f"resnet50 compile time: {compile_time}") - - tt_lib.device.ReleaseTrace(device, tid) - - assert inference_time_avg < expected_inference_time, f"resnet50 {comments} inference is too slow" - assert compile_time < expected_compile_time, f"resnet50 {comments} compilation is too slow" - @skip_for_wormhole_b0(reason_str="Not tested on single WH") @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", - ( - (16, 0.04, 25), - (20, 0.04, 25), - ), + ((20, 0.008, 16),), ) @pytest.mark.parametrize("enable_async", [True, False]) def test_perf_trace_bare_metal( @@ -293,14 +284,14 @@ def test_perf_trace_bare_metal( hf_cat_image_sample_input, enable_async, ): - if is_e75(device): - pytest.skip("Resnet is not supported on E75") device.enable_async(enable_async) - run_perf_resnet_trace( + mode = "async" if enable_async else "sync" + run_perf_resnet( batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, + f"resnet50_trace_{mode}", ) device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_resnet_2cqs.py b/models/demos/resnet/tests/test_perf_resnet_2cqs.py new file mode 100644 index 00000000000..eddbc1bf4ed --- /dev/null +++ b/models/demos/resnet/tests/test_perf_resnet_2cqs.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
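+
+# 2-CQ perf variant: device_params pins num_hw_cqs=2, so run_perf_resnet dispatches to
+# run_2cq_model, which overlaps input writes (CQ 1) with op execution (CQ 0). Per iteration,
+# roughly:
+#   tt_lib.device.WaitForEvent(device, 1, op_event)          # CQ 1 waits until ops consumed the last input
+#   tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1)   # CQ 1 writes the next input
+#   tt_lib.device.RecordEvent(device, 1, write_event)        # CQ 1 signals the write finished
+#   tt_resnet50(tt_image_res, write_event, op_event)         # CQ 0 waits on write_event, records op_event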
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from models.demos.resnet.tests.test_perf_resnet import run_perf_resnet +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0(reason_str="Not tested on single WH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ((20, 0.0055, 16),), +) +def test_perf_2cqs_bare_metal( + device, + use_program_cache, + batch_size, + expected_inference_time, + expected_compile_time, + hf_cat_image_sample_input, +): + run_perf_resnet( + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" + ) diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py index 16f8fb01ffb..32e3f913c31 100644 --- a/models/demos/resnet/tt/metalResnetBlock50.py +++ b/models/demos/resnet/tt/metalResnetBlock50.py @@ -2101,7 +2101,7 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor: return x - def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: + def forward(self, x: tt_lib.tensor, write_event=None, op_event=None) -> tt_lib.tensor: if not self.sharded: original_A_cl_host_shape = x.get_legacy_shape() x = x.reshape( @@ -2116,7 +2116,7 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: original_A_cl_host_shape[2], original_A_cl_host_shape[3], ) - elif x.storage_type() != tt_lib.tensor.StorageType.DEVICE: + else: x_shape = x.get_legacy_shape() shard_spec = tt_lib.tensor.ShardSpec( self.shard_grid, @@ -2130,21 +2130,16 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: mem_config = tt_lib.tensor.MemoryConfig( tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec ) - x = x.to(self.device, mem_config) - else: - shard_spec = tt_lib.tensor.ShardSpec( - self.shard_grid, - [ - x.get_legacy_shape()[2] // self.first_conv_num_cores_nhw, - x.get_legacy_shape()[3], - ], - tt_lib.tensor.ShardOrientation.ROW_MAJOR, - False, - ) - mem_config = tt_lib.tensor.MemoryConfig( - tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec - ) - x = tt_lib.tensor.interleaved_to_sharded(x, mem_config) + if write_event is not None: + tt_lib.device.WaitForEvent(self.device, 0, write_event) + if x.storage_type() != tt_lib.tensor.StorageType.DEVICE: + x = x.to(self.device, mem_config) + elif x.memory_config().is_sharded(): + x = tt_lib.tensor.reshard(x, mem_config) + else: + x = tt_lib.tensor.interleaved_to_sharded(x, mem_config) + if op_event is not None: + tt_lib.device.RecordEvent(self.device, 0, op_event) x = self.conv1(x) # Relu is fused with conv1 diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 23cc2d0d0ba..e535e635d45 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -17,7 +17,9 @@ run_perf_models_other() { env pytest models/demos/ttnn_falcon7b/tests -m $test_marker - env pytest models/demos/resnet/tests -m $test_marker + # Separate calls since we can't mix switching between number of cqs + env pytest models/demos/resnet/tests/test_perf_resnet.py -m $test_marker + env pytest models/demos/resnet/tests/test_perf_resnet_2cqs.py -m $test_marker env pytest tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker diff --git a/tests/scripts/single_card/nightly/run_gs_only.sh 
b/tests/scripts/single_card/nightly/run_gs_only.sh index 9973f35b7bd..36ed969d4a0 100755 --- a/tests/scripts/single_card/nightly/run_gs_only.sh +++ b/tests/scripts/single_card/nightly/run_gs_only.sh @@ -11,6 +11,6 @@ echo "Running model nightly tests for GS only" env pytest models/demos/metal_BERT_large_11/tests/test_demo.py -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0] +env pytest models/demos/resnet/tests/test_metal_resnet50_performant.py -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_trace_inference -k "LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0" +env pytest models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp index 9ae0f0adffb..9f94b540aaf 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp @@ -17,7 +17,7 @@ void kernel_main() { volatile tt_l1_ptr uint32_t* done_address = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(L1_UNRESERVED_BASE); while (done_address[0] == 0) { - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); + uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX), pcie_read_ptr); noc_async_read(host_src_addr, L1_UNRESERVED_BASE, read_sizeB); pcie_read_ptr += read_sizeB; if (pcie_read_ptr > pcie_base + pcie_sizeB) { diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp index 05c4a338ff5..ac8945a4d6d 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp @@ -11,7 +11,7 @@ void kernel_main() { constexpr uint32_t base_pcie_dst_address = get_compile_time_arg_val(1); constexpr uint32_t num_16b_writes = get_compile_time_arg_val(2); - uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)) << 32; + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX)) << 32; uint32_t l1_src_address = base_l1_src_address; uint32_t pcie_dst_address = base_pcie_dst_address; diff --git a/tests/ttnn/integration_tests/bert/test_performance.py b/tests/ttnn/integration_tests/bert/test_performance.py index 034df32b53d..e29b0a44329 100644 --- a/tests/ttnn/integration_tests/bert/test_performance.py +++ b/tests/ttnn/integration_tests/bert/test_performance.py @@ -59,7 +59,7 @@ def get_expected_times(bert): return { ttnn_bert: (0.1, 0.1), ttnn_optimized_bert: (5.5, 0.07), - ttnn_optimized_sharded_bert: (5.2, 0.07), + ttnn_optimized_sharded_bert: (5.5, 0.07), }[bert] diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index b88669f43d9..41c559c5ef0 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -17,7 +17,7 @@ def get_expected_times(functional_whisper): return { - ttnn_functional_whisper: (10.5, 4.16), +
ttnn_functional_whisper: (11, 4.16), ttnn_optimized_functional_whisper: (1.2, 1.35), }[functional_whisper] diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp index fc371068921..2d15d354531 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp @@ -251,6 +251,30 @@ void DeviceModule(py::module &m_device) { Release captured Trace on Device handle )doc"); + auto pyEvent = py::class_<Event, std::shared_ptr<Event>>(m_device, "Event", "Event class"); + m_device.def("CreateEvent", + [] () { + return std::make_shared<Event>(); + }, R"doc( + Create new event + )doc"); + m_device.def("RecordEvent", + [] (Device* device, const uint8_t cq_id, std::shared_ptr<Event> event) { + device->push_work([device, cq_id, event] { + EnqueueRecordEvent(device->command_queue(cq_id), event); + }); + }, R"doc( + Record an event + )doc"); + m_device.def("WaitForEvent", + [] (Device* device, const uint8_t cq_id, std::shared_ptr<Event> event) { + device->push_work([device, cq_id, event] { + EnqueueWaitForEvent(device->command_queue(cq_id), event); + }); + }, R"doc( + Wait for an event + )doc"); + m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE); } diff --git a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h index 8b8e9ad1415..7f6529f9915 100644 --- a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h +++ b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h @@ -14,6 +14,9 @@ #define NOC_XY_ENCODING(x, y) \ ((((uint64_t)(y)) << (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) | (((uint64_t)(x)) << NOC_ADDR_LOCAL_BITS)) +#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ + NOC_XY_ENCODING(x, y) + #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ ((((uint64_t)(x_start)) << (NOC_ADDR_LOCAL_BITS + 2 * NOC_ADDR_NODE_ID_BITS)) | \ (((uint64_t)(y_start)) << (NOC_ADDR_LOCAL_BITS + 3 * NOC_ADDR_NODE_ID_BITS)) | \ diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index 91b1a26f8f3..12df89b03de 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -476,7 +476,7 @@ uint64_t get_l1_noc_addr(const uint32_t id, const uint32_t page_size, const uint } uint64_t get_system_memory_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t base_addr, const uint32_t offset = 0) { - constexpr static uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)) << 32; + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y), noc_index)) << 32; uint32_t addr = base_addr + page_size * id + offset; uint64_t noc_addr = pcie_core_noc_encoding | addr; return noc_addr; } diff --git a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h index 3fa07c45294..ed13f98ea8f 100644 --- a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h +++ b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h @@ -12,6 +12,8 @@ // Address formats #define NOC_XY_ENCODING(x, y) ((((uint32_t)(y)) << (NOC_ADDR_NODE_ID_BITS)) | (((uint32_t)(x)))) +#define NOC_XY_PCIE_ENCODING(x, y, noc_index) NOC_XY_ENCODING(x, y) + #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ ((x_start) << (2 * NOC_ADDR_NODE_ID_BITS)) | ((y_start) << (3 * NOC_ADDR_NODE_ID_BITS)) | (x_end) | \ ((y_end) << (NOC_ADDR_NODE_ID_BITS)) diff --git a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h index 0a2256ffeeb..f6b361d3ff3 100644 ---
a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h +++ b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h @@ -9,13 +9,21 @@ #define PCIE_NOC_X 0 #define PCIE_NOC_Y 3 +#define PCIE_NOC1_X 9 +#define PCIE_NOC1_Y 8 + // Address formats #define NOC_XY_ENCODING(x, y) \ (((uint32_t)(y)) << ((NOC_ADDR_LOCAL_BITS % 32)+NOC_ADDR_NODE_ID_BITS)) | \ - (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) | ((x == PCIE_NOC_X and y == PCIE_NOC_Y) * 0x8) \ + (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) \ + +// Address formats +#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ + NOC_XY_ENCODING(x, y) | \ + ((noc_index ? (x == PCIE_NOC1_X and y == PCIE_NOC1_Y) : (x == PCIE_NOC_X and y == PCIE_NOC_Y)) * 0x8) \ #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ - (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+2*NOC_ADDR_NODE_ID_BITS)) | \ + (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+2*NOC_ADDR_NODE_ID_BITS)) | \ (((uint32_t)(y_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+3*NOC_ADDR_NODE_ID_BITS)) | \ (((uint32_t)(x_end)) << (NOC_ADDR_LOCAL_BITS % 32)) | \ (((uint32_t)(y_end)) << ((NOC_ADDR_LOCAL_BITS % 32)+NOC_ADDR_NODE_ID_BITS)) \ diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index e73c9efbdbb..6e9892c130c 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -16,6 +16,7 @@ #include "common/utils.hpp" #include "llrt/llrt.hpp" #include "dev_msgs.h" +#include "noc/noc_parameters.h" namespace tt { @@ -344,16 +345,19 @@ void Device::configure_kernel_variant( CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map<string, string> defines_in, + NOC noc_index, bool is_active_eth_core) { + const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size; + std::map<string, string> defines = { {"DISPATCH_KERNEL", "1"}, - {"MY_NOC_X", std::to_string(kernel_physical_core.x)}, - {"MY_NOC_Y", std::to_string(kernel_physical_core.y)}, - {"UPSTREAM_NOC_X", std::to_string(upstream_physical_core.x)}, - {"UPSTREAM_NOC_Y", std::to_string(upstream_physical_core.y)}, - {"DOWNSTREAM_NOC_X", std::to_string(downstream_physical_core.x)}, - {"DOWNSTREAM_NOC_Y", std::to_string(downstream_physical_core.y)}, + {"MY_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, kernel_physical_core.x))}, + {"MY_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, kernel_physical_core.y))}, + {"UPSTREAM_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, upstream_physical_core.x))}, + {"UPSTREAM_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, upstream_physical_core.y))}, + {"DOWNSTREAM_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, downstream_physical_core.x))}, + {"DOWNSTREAM_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, downstream_physical_core.y))}, }; defines.insert(defines_in.begin(), defines_in.end()); @@ -364,7 +368,7 @@ void Device::configure_kernel_variant( kernel_core, tt::tt_metal::DataMovementConfig { .processor = tt::tt_metal::DataMovementProcessor::RISCV_1, - .noc = NOC::NOC_0, + .noc = noc_index, .compile_args = compile_args, .defines = defines } @@ -376,7 +380,7 @@ void Device::configure_kernel_variant( kernel_core, tt::tt_metal::EthernetConfig{ .eth_mode = is_active_eth_core ?
Eth::SENDER : Eth::IDLE, - .noc = NOC::NOC_0, + .noc = noc_index, .compile_args = compile_args, .defines = defines } @@ -420,6 +424,8 @@ void Device::compile_command_queue_programs() { CoreCoord prefetch_physical_core = get_physical_core_coordinate(prefetch_core, dispatch_core_type); CoreCoord dispatch_physical_core = get_physical_core_coordinate(dispatch_core, dispatch_core_type); + NOC noc_index = this->hw_command_queues_[cq_id]->noc_index; + log_debug(LogDevice, "Dispatching out of {} cores", magic_enum::enum_name(dispatch_core_type)); log_debug(LogDevice, "Prefetch HD logical location: {} physical core: {}", prefetch_core.str(), prefetch_physical_core.str()); log_debug(LogDevice, "Dispatch HD logical location: {} physical core {}", dispatch_core.str(), dispatch_physical_core.str()); @@ -465,7 +471,8 @@ dispatch_core_type, CoreCoord{0, 0}, dispatch_physical_core, - std::map<string, string> {} + std::map<string, string> {}, + noc_index ); tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, prefetch_core, 0, dispatch_core_type); // prefetch_sync_sem @@ -501,7 +508,8 @@ dispatch_core_type, prefetch_physical_core, CoreCoord{0, 0}, - std::map<string, string> {} + std::map<string, string> {}, + noc_index ); tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_core, 0, dispatch_core_type); // dispatch_sem @@ -517,7 +525,7 @@ Device *mmio_device = tt::tt_metal::detail::GetDeviceHandle(mmio_device_id); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id); uint32_t cq_size = mmio_device->sysmem_manager().get_cq_size(); - + NOC noc_index = this->hw_command_queues_[cq_id]->noc_index; CoreType dispatch_core_type = dispatch_core_manager::get(num_hw_cqs).get_dispatch_core_type(mmio_device_id); tt_cxy_pair prefetch_core = dispatch_core_manager::get(num_hw_cqs).prefetcher_core(device_id, channel, cq_id); @@ -610,7 +618,8 @@ dispatch_core_type, CoreCoord{0, 0}, mux_physical_core, - std::map<string, string> {} + std::map<string, string> {}, + noc_index ); log_debug(LogDevice, "run prefetch_h {}", prefetch_core.str()); @@ -671,7 +680,8 @@ dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}} + std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); std::vector<uint32_t> tunneler_l_compile_args = @@ -715,6 +725,7 @@ CoreCoord{0, 0}, CoreCoord{0, 0}, std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}, + noc_index, true ); @@ -782,7 +793,8 @@ dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}} + std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); log_debug(LogDevice, "run dispatch demux at {}", demux_core.str()); @@ -816,7 +828,8 @@ dispatch_core_type, demux_physical_core, CoreCoord{0xffffffff, 0xffffffff}, - std::map<string, string> {} + std::map<string, string> {}, + noc_index ); log_debug(LogDevice, "run dispatch_h at {}", dispatch_core.str()); @@ -895,6 +908,7 @@ CoreCoord{0, 0}, CoreCoord{0, 0}, std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}, + noc_index, true ); @@ -959,7 +973,8 @@ dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}} + std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); log_debug(LogDevice, "run demux at
{}", demux_d_core.str()); @@ -1007,7 +1022,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, demux_d_physical_core, dispatch_physical_core, - std::map {} + std::map {}, + noc_index ); log_debug(LogDevice, "run prefertch_d at {}", prefetch_d_core.str()); @@ -1041,7 +1057,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, prefetch_d_physical_core, mux_d_physical_core, - std::map {} + std::map {}, + noc_index ); log_debug(LogDevice, "run dispatch at {}", dispatch_core.str()); @@ -1100,7 +1117,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map {{"SKIP_NOC_LOGGING", "1"}} + std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); log_debug(LogDevice, "run mux at {}", mux_d_core.str()); @@ -1194,7 +1212,7 @@ void Device::initialize_command_queue() { this->sysmem_manager_ = std::make_unique(this->id_, this->num_hw_cqs()); hw_command_queues_.resize(num_hw_cqs()); for (size_t cq_id = 0; cq_id < num_hw_cqs(); cq_id++) { - hw_command_queues_[cq_id] = std::make_unique(this, cq_id); + hw_command_queues_[cq_id] = std::make_unique(this, cq_id, static_cast(cq_id)); // Need to do this since CommandQueue constructor is private sw_command_queues_.push_back(std::unique_ptr(new CommandQueue(this, cq_id))); } @@ -1530,6 +1548,24 @@ std::vector Device::ethernet_cores_from_logical_cores(const std::vect return ethernet_cores; } +uint32_t Device::get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const { + const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size; + return NOC_XY_ENCODING( + NOC_0_X(noc_index, grid_size.x, physical_core.x), + NOC_0_Y(noc_index, grid_size.y, physical_core.y) + ); +} + +uint32_t Device::get_noc_multicast_encoding(uint8_t noc_index, const CoreRange& physical_cores) const { + const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size; + return NOC_MULTICAST_ENCODING( + NOC_0_X(noc_index, grid_size.x, physical_cores.start.x), + NOC_0_Y(noc_index, grid_size.y, physical_cores.start.y), + NOC_0_X(noc_index, grid_size.x, physical_cores.end.x), + NOC_0_Y(noc_index, grid_size.y, physical_cores.end.y) + ); +} + void Device::check_allocator_is_initialized() const { if (this->allocator_ == nullptr) { TT_THROW("No memory allocator! 
Device has not been initialized, did you forget to call InitializeDevice?"); diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 7b054f03068..12df80a6bee 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -11,6 +11,7 @@ #include "impl/dispatch/work_executor.hpp" #include "tt_metal/impl/allocator/basic_allocator.hpp" #include "tt_metal/impl/allocator/l1_banking_allocator.hpp" +#include "tt_metal/impl/kernels/data_types.hpp" #include "tt_metal/impl/trace/trace_buffer.hpp" #include "tt_metal/jit_build/build.hpp" #include "llrt/tt_cluster.hpp" @@ -192,6 +193,9 @@ class Device { // core.y represents different channels along one const std::set<CoreCoord> &ethernet_cores() const { return this->ethernet_cores_; } + uint32_t get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const; + uint32_t get_noc_multicast_encoding(uint8_t noc_index, const CoreRange& physical_cores) const; + void deallocate_buffers(); // machine epsilon @@ -229,7 +233,7 @@ class Device { void initialize_command_queue(); void initialize_synchronous_sw_cmd_queue(); void configure_kernel_variant(Program& program, string path, std::vector<uint32_t> compile_args, CoreCoord kernel_core, CoreCoord Kernel_physical_core, - CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map<string, string> defines_in , bool is_active_eth_core = false); + CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map<string, string> defines_in, NOC noc_index, bool is_active_eth_core = false); void compile_command_queue_programs(); void configure_command_queue_programs(); void clear_l1_state(); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 5df863d7b3b..8b5ca124ab4 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -43,16 +43,12 @@ namespace tt::tt_metal { thread_local std::unordered_map EnqueueProgramCommand::cached_program_command_sequences = {}; -uint32_t get_noc_unicast_encoding(const CoreCoord& coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); } -uint32_t get_noc_multicast_encoding(const CoreCoord& start, const CoreCoord& end) { - return NOC_MULTICAST_ENCODING(start.x, start.y, end.x, end.y); -} - // EnqueueReadBufferCommandSection EnqueueReadBufferCommand::EnqueueReadBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Buffer& buffer, void* dst, SystemMemoryManager& manager, uint32_t expected_num_workers_completed, uint32_t src_page_index, std::optional<uint32_t> pages_to_read) : command_queue_id(command_queue_id), + noc_index(noc_index), dst(dst), manager(manager), buffer(buffer), @@ -89,7 +86,7 @@ void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand& const CoreCoord physical_core = this->buffer.device()->physical_core_from_logical_core(this->core, this->buffer.core_type()); command.add_prefetch_relay_linear( - get_noc_unicast_encoding(physical_core), padded_page_size * this->pages_to_read, this->bank_base_address); + this->device->get_noc_unicast_encoding(this->noc_index, physical_core), padded_page_size * this->pages_to_read, this->bank_base_address); } void EnqueueReadBufferCommand::process() { @@ -125,6 +122,7 @@ EnqueueWriteBufferCommand::EnqueueWriteBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, const Buffer& buffer, const void* src,
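    // noc_index (added above) selects the NOC whose coordinate system all unicast/multicast
    // encodings for this command are computed in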
SystemMemoryManager& manager, @@ -135,6 +133,7 @@ EnqueueWriteBufferCommand::EnqueueWriteBufferCommand( uint32_t dst_page_index, std::optional<uint32_t> pages_to_write) : command_queue_id(command_queue_id), + noc_index(noc_index), manager(manager), issue_wait(issue_wait), src(src), @@ -211,7 +210,7 @@ void EnqueueWriteShardedBufferCommand::add_dispatch_write(HugepageDeviceCommand& this->buffer.device()->physical_core_from_logical_core(this->core, this->buffer.core_type()); bool flush_prefetch = true; command_sequence.add_dispatch_write_linear( - flush_prefetch, 0, get_noc_unicast_encoding(physical_core), this->bank_base_address, data_size_bytes); + flush_prefetch, 0, this->device->get_noc_unicast_encoding(this->noc_index, physical_core), this->bank_base_address, data_size_bytes); } void EnqueueWriteShardedBufferCommand::add_buffer_data(HugepageDeviceCommand& command_sequence) { @@ -287,10 +286,12 @@ void EnqueueWriteBufferCommand::process() { EnqueueProgramCommand::EnqueueProgramCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Program& program, SystemMemoryManager& manager, uint32_t expected_num_workers_completed) : command_queue_id(command_queue_id), + noc_index(noc_index), manager(manager), expected_num_workers_completed(expected_num_workers_completed), program(program) { @@ -462,13 +463,12 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { // can make a vector of unicast encodings here CoreCoord physical_core = device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type()); - uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); const auto& runtime_args_data = kernel->runtime_args(core_coord); unique_rt_args_data[processor_idx].emplace_back(kernel->runtime_args_data(core_coord)); // 2, 17, could be differnet len here unique_sub_cmds[processor_idx].emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_core)}); unique_rt_data_and_sizes[processor_idx].emplace_back( runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t)); unique_max_runtime_args_len[processor_idx] = @@ -496,12 +496,11 @@ for (auto& core_coord : kernel->logical_cores()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord); - uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); unicast_sub_cmd.emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_core)}); } } else { - vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info = + vector<pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info = extract_dst_noc_multicast_info<std::vector<CoreRange>>( device, kernel->logical_coreranges(), kernel->get_kernel_core_type()); common_sub_cmds[kernel_id].emplace<std::vector<CQDispatchWritePackedMulticastSubCmd>>( @@ -511,7 +510,7 @@ multicast_sub_cmd.reserve(dst_noc_multicast_info.size()); for (const auto& mcast_dests : dst_noc_multicast_info) { multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, std::get<CoreRange>(mcast_dests.first)), .num_mcast_dests = mcast_dests.second}); } } } @@ -634,7 +633,6 @@ void
EnqueueProgramCommand::assemble_device_commands() { for (const CoreRange& core_range : circular_buffers_unique_coreranges) { const CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start); const CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end); - const uint32_t dst_noc_multicast_encoding = get_noc_multicast_encoding(physical_start, physical_end); const uint32_t num_receivers = core_range.size(); auto& cb_config_payload = cb_config_payloads[i]; @@ -659,7 +657,7 @@ } } multicast_cb_config_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = dst_noc_multicast_encoding, .num_mcast_dests = (uint32_t)core_range.size()}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, CoreRange(physical_start, physical_end)), .num_mcast_dests = (uint32_t)core_range.size()}); multicast_cb_config_data.emplace_back( cb_config_payload.data(), (max_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t)); @@ -683,7 +681,7 @@ for (int buffer_idx = 0; buffer_idx < program.program_transfer_info.kernel_bins.size(); buffer_idx++) { const auto& kg_transfer_info = program.program_transfer_info.kernel_bins[buffer_idx]; for (int kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) { - for (const pair<uint32_t, uint32_t>& dst_noc_info : kg_transfer_info.dst_noc_info) { + for (const pair<transfer_info_cores, uint32_t>& dst_noc_info : kg_transfer_info.dst_noc_info) { cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; } @@ -709,9 +707,8 @@ CoreCoord physical_end = device->physical_core_from_logical_core(core_range.end, kernel_group.get_core_type()); - uint32_t dst_noc_multicast_encoding = get_noc_multicast_encoding(physical_start, physical_end); multicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = dst_noc_multicast_encoding, .num_mcast_dests = (uint32_t)core_range.size()}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, CoreRange(physical_start, physical_end)), .num_mcast_dests = (uint32_t)core_range.size()}); multicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); } } @@ -733,9 +730,8 @@ for (auto y = core_range.start.y; y <= core_range.end.y; y++) { CoreCoord physical_coord = device->physical_core_from_logical_core(CoreCoord({x, y}), kernel_group.get_core_type()); - uint32_t dst_noc_unicast_encoding = get_noc_unicast_encoding(physical_coord); unicast_go_signal_sub_cmds.emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = dst_noc_unicast_encoding}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_coord)}); unicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); } } @@ -768,7 +764,7 @@ for (const auto& dst_noc_info : transfer_info.dst_noc_info) { num_packed_cmds += 1; multicast_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = dst_noc_info.first, .num_mcast_dests = dst_noc_info.second}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, std::get<CoreRange>(dst_noc_info.first)), .num_mcast_dests = dst_noc_info.second});
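                    // Encodings are resolved per-NOC at enqueue time: the same physical core has a
                    // different XY encoding on NOC 0 vs NOC 1 (coordinates are mirrored via NOC_0_X/NOC_0_Y).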
@@ -796,7 +792,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
                 for (const auto& dst_noc_info : transfer_info.dst_noc_info) {
                     num_packed_cmds += 1;
                     unicast_sub_cmds.emplace_back(
-                        CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = dst_noc_info.first});
+                        CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, std::get<CoreCoord>(dst_noc_info.first))});
                     sem_data.emplace_back(transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t));
                 }
             }
@@ -828,11 +824,22 @@ void EnqueueProgramCommand::assemble_device_commands() {
         for (int buffer_idx = 0; buffer_idx < program.program_transfer_info.kernel_bins.size(); buffer_idx++) {
             const auto& kg_transfer_info = program.program_transfer_info.kernel_bins[buffer_idx];
             for (int kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) {
-                for (const pair<uint32_t, uint32_t>& dst_noc_info : kg_transfer_info.dst_noc_info) {
+                for (const pair<transfer_info_cores, uint32_t>& dst_noc_info : kg_transfer_info.dst_noc_info) {
+                    uint32_t noc_encoding;
+                    std::visit(
+                        [&](auto&& cores) {
+                            using T = std::decay_t<decltype(cores)>;
+                            if constexpr (std::is_same_v<T, CoreRange>) {
+                                noc_encoding = this->device->get_noc_multicast_encoding(this->noc_index, cores);
+                            } else {
+                                noc_encoding = this->device->get_noc_unicast_encoding(this->noc_index, cores);
+                            }
+                        },
+                        dst_noc_info.first);
                     program_command_sequence.add_dispatch_write_linear(
                         false,                // flush_prefetch
                         dst_noc_info.second,  // num_mcast_dests
-                        dst_noc_info.first,   // noc_xy_addr
+                        noc_encoding,         // noc_xy_addr
                         kg_transfer_info.dst_base_addrs[kernel_idx],
                         align(kg_transfer_info.lengths[kernel_idx], NOC_DRAM_ALIGNMENT_BYTES));
                     // Difference between prefetch total relayed pages and dispatch write linear
@@ -1026,12 +1033,14 @@ void EnqueueProgramCommand::process() {
 EnqueueRecordEventCommand::EnqueueRecordEventCommand(
     uint32_t command_queue_id,
     Device* device,
+    NOC noc_index,
     SystemMemoryManager& manager,
     uint32_t event_id,
     uint32_t expected_num_workers_completed,
     bool clear_count) :
     command_queue_id(command_queue_id),
     device(device),
+    noc_index(noc_index),
     manager(manager),
     event_id(event_id),
     expected_num_workers_completed(expected_num_workers_completed),
@@ -1080,7 +1089,7 @@ void EnqueueRecordEventCommand::process() {
         CoreCoord dispatch_physical_core = get_physical_core_coordinate(dispatch_location, core_type);
         unicast_sub_cmds[cq_id] =
-            CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = get_noc_unicast_encoding(dispatch_physical_core)};
+            CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, dispatch_physical_core)};
         event_payloads[cq_id] = {event_payload.data(), event_payload.size() * sizeof(uint32_t)};
     }
@@ -1209,11 +1218,12 @@ void EnqueueTerminateCommand::process() {
 }

 // HWCommandQueue section
-HWCommandQueue::HWCommandQueue(Device* device, uint32_t id) :
+HWCommandQueue::HWCommandQueue(Device* device, uint32_t id, NOC noc_index) :
     manager(device->sysmem_manager()), completion_queue_thread{} {
     ZoneScopedN("CommandQueue_constructor");
     this->device = device;
     this->id = id;
+    this->noc_index = noc_index;
     this->num_entries_in_completion_q = 0;
     this->num_completed_completion_q_reads = 0;
@@ -1340,6 +1350,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
             auto command = EnqueueReadShardedBufferCommand(
                 this->id,
                 this->device,
+                this->noc_index,
                 buffer,
                 dst,
                 this->manager,
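The kernel-binary hunk above selects unicast or multicast encoding by visiting the transfer_info_cores variant. A self-contained miniature of that dispatch pattern, with placeholder encoders standing in for the Device methods:

#include <cstdint>
#include <iostream>
#include <type_traits>
#include <variant>

struct CoreCoord { uint32_t x, y; };
struct CoreRange { CoreCoord start, end; };
using transfer_info_cores = std::variant<CoreCoord, CoreRange>;  // mirrors program_device_map.hpp

uint32_t unicast_encoding(const CoreCoord& c) { return (c.y << 16) | c.x; }  // placeholder
uint32_t multicast_encoding(const CoreRange& r) {                            // placeholder
    return (r.end.y << 24) | (r.end.x << 16) | (r.start.y << 8) | r.start.x;
}

// One code path handles both destination kinds; the variant alternative picks the encoder.
uint32_t noc_encoding(const transfer_info_cores& cores) {
    return std::visit(
        [](auto&& c) -> uint32_t {
            using T = std::decay_t<decltype(c)>;
            if constexpr (std::is_same_v<T, CoreRange>) {
                return multicast_encoding(c);
            } else {
                return unicast_encoding(c);
            }
        },
        cores);
}

int main() {
    std::cout << noc_encoding(CoreCoord{1, 2}) << " " << noc_encoding(CoreRange{{0, 0}, {3, 3}}) << "\n";
    return 0;
}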
@@ -1376,6 +1387,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
             auto command = EnqueueReadInterleavedBufferCommand(
                 this->id,
                 this->device,
+                this->noc_index,
                 buffer,
                 dst,
                 this->manager,
@@ -1514,6 +1526,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
                 auto command = EnqueueWriteShardedBufferCommand(
                     this->id,
                     this->device,
+                    this->noc_index,
                     buffer,
                     src,
                     this->manager,
@@ -1605,6 +1618,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
             auto command = EnqueueWriteInterleavedBufferCommand(
                 this->id,
                 this->device,
+                this->noc_index,
                 buffer,
                 src,
                 this->manager,
@@ -1646,7 +1660,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
     // Snapshot of expected workers from previous programs, used for dispatch_wait cmd generation.
     uint32_t expected_workers_completed = this->manager.get_bypass_mode() ? this->trace_ctx->num_completion_worker_cores
                                                                           : this->expected_num_workers_completed;
-    auto command = EnqueueProgramCommand(this->id, this->device, program, this->manager, expected_workers_completed);
+    auto command = EnqueueProgramCommand(this->id, this->device, this->noc_index, program, this->manager, expected_workers_completed);
     this->enqueue_command(command, blocking);

     log_trace(
@@ -1677,7 +1691,7 @@ void HWCommandQueue::enqueue_record_event(std::shared_ptr<Event> event, bool cle
     event->ready = true;  // what does this mean???
     auto command = EnqueueRecordEventCommand(
-        this->id, this->device, this->manager, event->event_id, this->expected_num_workers_completed, clear_count);
+        this->id, this->device, this->noc_index, this->manager, event->event_id, this->expected_num_workers_completed, clear_count);
     this->enqueue_command(command, false);

     if (clear_count) {
@@ -2295,9 +2309,6 @@ void EnqueueProgramImpl(
 }

 void EnqueueRecordEvent(CommandQueue& cq, std::shared_ptr<Event> event) {
-    TT_ASSERT(event->device == nullptr, "EnqueueRecordEvent expected to be given an uninitialized event");
-    TT_ASSERT(event->event_id == -1, "EnqueueRecordEvent expected to be given an uninitialized event");
-    TT_ASSERT(event->cq_id == -1, "EnqueueRecordEvent expected to be given an uninitialized event");
     detail::DispatchStateCheck(true);

     cq.run_command(CommandInterface{
diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp
index 578724880f0..9809824eab5 100644
--- a/tt_metal/impl/dispatch/command_queue.hpp
+++ b/tt_metal/impl/dispatch/command_queue.hpp
@@ -55,9 +55,6 @@ string EnqueueCommandTypeToString(EnqueueCommandType ctype);
 #define NOC_X(x) x
 #define NOC_Y(y) y

-uint32_t get_noc_unicast_encoding(const CoreCoord& coord);
-uint32_t get_noc_multicast_encoding(const CoreCoord& start, const CoreCoord& end);
-
 class CommandQueue;
 class CommandInterface;

@@ -74,13 +71,14 @@ class EnqueueReadBufferCommand : public Command {
    private:
     SystemMemoryManager& manager;
     void* dst;
-    uint32_t command_queue_id;
     CoreType dispatch_core_type;

     virtual void add_prefetch_relay(HugepageDeviceCommand& command) = 0;

    protected:
     Device* device;
+    uint32_t command_queue_id;
+    NOC noc_index;
     uint32_t expected_num_workers_completed;
     uint32_t src_page_index;
     uint32_t pages_to_read;
@@ -90,6 +88,7 @@ class EnqueueReadBufferCommand : public Command {
     EnqueueReadBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Buffer& buffer,
         void* dst,
         SystemMemoryManager& manager,
@@ -112,6 +111,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
     EnqueueReadInterleavedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Buffer& buffer,
         void* dst,
         SystemMemoryManager& manager,
@@ -121,6 +121,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
         EnqueueReadBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             dst,
             manager,
@@ -139,6 +140,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
     EnqueueReadShardedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Buffer& buffer,
         void* dst,
         SystemMemoryManager& manager,
@@ -150,6 +152,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
         EnqueueReadBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             dst,
             manager,
@@ -165,7 +168,6 @@ class EnqueueWriteInterleavedBufferCommand;
 class EnqueueWriteBufferCommand : public Command {
    private:
     SystemMemoryManager& manager;
-    uint32_t command_queue_id;
     CoreType dispatch_core_type;

     virtual void add_dispatch_write(HugepageDeviceCommand& command) = 0;
@@ -173,6 +175,8 @@ class EnqueueWriteBufferCommand : public Command {
    protected:
     Device* device;
+    uint32_t command_queue_id;
+    NOC noc_index;
     const void* src;
     const Buffer& buffer;
     uint32_t expected_num_workers_completed;
@@ -186,6 +190,7 @@ class EnqueueWriteBufferCommand : public Command {
     EnqueueWriteBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         const Buffer& buffer,
         const void* src,
         SystemMemoryManager& manager,
@@ -212,6 +217,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
     EnqueueWriteInterleavedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         const Buffer& buffer,
         const void* src,
         SystemMemoryManager& manager,
@@ -224,6 +230,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
         EnqueueWriteBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             src,
             manager,
@@ -249,6 +256,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
     EnqueueWriteShardedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         const Buffer& buffer,
         const void* src,
         SystemMemoryManager& manager,
@@ -263,6 +271,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
         EnqueueWriteBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             src,
             manager,
@@ -282,6 +291,7 @@ class EnqueueProgramCommand : public Command {
    private:
     uint32_t command_queue_id;
     Device* device;
+    NOC noc_index;
     Program& program;
     SystemMemoryManager& manager;
     CoreType dispatch_core_type;
@@ -302,6 +312,7 @@ class EnqueueProgramCommand : public Command {
     EnqueueProgramCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Program& program,
         SystemMemoryManager& manager,
         uint32_t expected_num_workers_completed);
@@ -321,6 +332,7 @@ class EnqueueRecordEventCommand : public Command {
    private:
     uint32_t command_queue_id;
     Device* device;
+    NOC noc_index;
     SystemMemoryManager& manager;
     uint32_t event_id;
     uint32_t expected_num_workers_completed;
@@ -330,6 +342,7 @@ class EnqueueRecordEventCommand : public Command {
     EnqueueRecordEventCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         SystemMemoryManager& manager,
         uint32_t event_id,
         uint32_t expected_num_workers_completed,
@@ -474,11 +487,12 @@ struct RuntimeArgsMetadata {

 class HWCommandQueue {
    public:
-    HWCommandQueue(Device* device, uint32_t id);
+    HWCommandQueue(Device* device, uint32_t id, NOC noc_index);
     ~HWCommandQueue();
     CoreCoord completion_queue_writer_core;
+    NOC noc_index;
     volatile bool is_dprint_server_hung();
     volatile bool is_noc_hung();
diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
index a506c16df3e..8002bd01704 100644
--- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
+++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
@@ -43,7 +43,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(16);
 constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y));
 constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y));
 constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y));
-constexpr uint32_t pcie_noc_xy_encoding = uint32_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y));
+constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_PCIE_ENCODING(NOC_0_X(static_cast<uint8_t>(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast<uint8_t>(NOC_INDEX), noc_size_y, PCIE_NOC_Y), NOC_INDEX));
 constexpr uint32_t dispatch_cb_page_size = 1 << dispatch_cb_log_page_size;

 constexpr uint32_t completion_queue_end_addr = completion_queue_base_addr + completion_queue_size;
@@ -141,7 +141,7 @@ void completion_queue_reserve_back(uint32_t num_pages) {
 FORCE_INLINE
 void notify_host_of_completion_queue_write_pointer() {
     uint64_t completion_queue_write_ptr_addr = command_queue_base_addr + HOST_CQ_COMPLETION_WRITE_PTR;
-    uint64_t pcie_address = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_ptr_addr);  // For now, we are writing to host hugepages at offset
+    uint64_t pcie_address = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_ptr_addr);  // For now, we are writing to host hugepages at offset
     uint32_t completion_wr_ptr_and_toggle = cq_write_interface.completion_fifo_wr_ptr | (cq_write_interface.completion_fifo_wr_toggle << 31);
     volatile tt_l1_ptr uint32_t* completion_wr_ptr_addr = get_cq_completion_write_ptr();
     completion_wr_ptr_addr[0] = completion_wr_ptr_and_toggle;
@@ -208,7 +208,7 @@ void process_write_host_h() {
         uint32_t npages = (xfer_size + completion_queue_page_size - 1) / completion_queue_page_size;
         completion_queue_reserve_back(npages);
         uint32_t completion_queue_write_addr = cq_write_interface.completion_fifo_wr_ptr << 4;
-        uint64_t host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_addr);
+        uint64_t host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_addr);
         // completion_queue_write_addr will never be equal to completion_queue_end_addr due to completion_queue_push_back
         // wrap logic so we don't need to handle this case explicitly to avoid 0 sized transactions
         if (completion_queue_write_addr + xfer_size > completion_queue_end_addr) {
@@ -218,7 +218,7 @@ void process_write_host_h() {
             data_ptr += last_chunk_size;
             length -= last_chunk_size;
             xfer_size -= last_chunk_size;
-            host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_addr);
+            host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_addr);
             block_noc_writes_to_clear[rd_block_idx] += (last_chunk_size + NOC_MAX_BURST_SIZE - 1) / NOC_MAX_BURST_SIZE;  // XXXXX maybe just write the noc internal api counter
         }
         noc_async_write(data_ptr, host_completion_queue_write_addr, xfer_size);
@@ -783,7 +783,6 @@ static inline bool process_cmd_d(uint32_t& cmd_ptr) {
             DPRINT << "cmd_wait" << ENDL();
             process_wait();
             break;
-
         case CQ_DISPATCH_CMD_GO:
             DPRINT << "cmd_go" << ENDL();
             break;
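Both dispatch kernels now derive the PCIe endpoint's encoding from NOC_INDEX at compile time, using NOC_0_X/NOC_0_Y to translate coordinates given in NOC-0 terms into the active NOC's frame. A sketch of that translation, assuming (as the macro names suggest) that NOC 1 mirrors NOC 0 across the grid; the PCIe core position and grid size below are illustrative:

#include <cstdint>
#include <cstdio>

// Sketch of NOC_0_X / NOC_0_Y: map a NOC-0 coordinate into the frame of the
// NOC selected by noc_index. Assumes NOC 1 is a mirror image of NOC 0.
constexpr uint32_t noc_0_x(uint32_t noc_index, uint32_t noc_size_x, uint32_t x) {
    return noc_index == 0 ? x : noc_size_x - 1 - x;
}
constexpr uint32_t noc_0_y(uint32_t noc_index, uint32_t noc_size_y, uint32_t y) {
    return noc_index == 0 ? y : noc_size_y - 1 - y;
}

int main() {
    constexpr uint32_t noc_size_x = 12, noc_size_y = 10;
    constexpr uint32_t pcie_x = 0, pcie_y = 4;  // hypothetical PCIe core in NOC-0 terms
    // On NOC 0 the PCIe core is (0, 4); on NOC 1 the same core is (11, 5),
    // which is why the old NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y) was only valid on NOC 0.
    std::printf("NOC0: (%u,%u)  NOC1: (%u,%u)\n",
                noc_0_x(0, noc_size_x, pcie_x), noc_0_y(0, noc_size_y, pcie_y),
                noc_0_x(1, noc_size_x, pcie_x), noc_0_y(1, noc_size_y, pcie_y));
    return 0;
}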
diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp
index f990132a60c..0124d992b2c 100644
--- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp
+++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp
@@ -52,6 +52,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(22);
 constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y));
 constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y));
 constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y));
+constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_PCIE_ENCODING(NOC_0_X(static_cast<uint8_t>(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast<uint8_t>(NOC_INDEX), noc_size_y, PCIE_NOC_Y), NOC_INDEX));
 constexpr uint32_t downstream_cb_page_size = 1 << downstream_cb_log_page_size;
 constexpr uint32_t downstream_cb_end = downstream_cb_base + (1 << downstream_cb_log_page_size) * downstream_cb_pages;
 constexpr uint32_t prefetch_q_end = prefetch_q_base + prefetch_q_size;
@@ -146,7 +147,7 @@ void read_from_pcie(volatile tt_l1_ptr prefetch_q_entry_type *& prefetch_q_rd_pt
         pcie_read_ptr = pcie_base;
     }

-    uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr);
+    uint64_t host_src_addr = get_noc_addr_helper(pcie_noc_xy, pcie_read_ptr);
     DPRINT << "read_from_pcie: " << fence + preamble_size << " " << pcie_read_ptr << ENDL();
     noc_async_read(host_src_addr, fence + preamble_size, size);
     pending_read_size = size + preamble_size;
diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp
deleted file mode 100644
index f77c26d9f33..00000000000
--- a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp
+++ /dev/null
@@ -1,674 +0,0 @@
-// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - - -// SPDX-License-Identifier: Apache-2.0 - -// Common prefetch code for use by _hd, _h, _d prefetch variants - -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "tt_metal/impl/dispatch/kernels/cq_common.hpp" - -extern const uint32_t scratch_db_top[2]; - - -template -FORCE_INLINE -void write_downstream(uint32_t& data_ptr, - uint32_t& downstream_data_ptr, - uint32_t length) { - - uint32_t remaining = cb_end - downstream_data_ptr; - if (length > remaining) { - if (remaining > 0) { - noc_async_write(data_ptr, get_noc_addr_helper(downstream_noc_xy, downstream_data_ptr), remaining); - data_ptr += remaining; - length -= remaining; - } - downstream_data_ptr = cb_base; - } - - noc_async_write(data_ptr, get_noc_addr_helper(downstream_noc_xy, downstream_data_ptr), length); - downstream_data_ptr += length; -} - -template -FORCE_INLINE -void read_from_pcie(volatile tt_l1_ptr uint16_t *& prefetch_q_rd_ptr, - uint32_t& pending_read_size, - uint32_t& fence, - uint32_t& pcie_read_ptr, - uint32_t cmd_ptr, - uint32_t size) { - - // Wrap cmddat_q - if (fence + size + preamble_size > cmddat_q_base + cmddat_q_size) { - // only wrap if there are no commands ready, otherwise we'll leave some on the floor - // TODO: does this matter for perf?
- if (cmd_ptr != fence) { - return; - } - fence = cmddat_q_base; - } - - // Wrap pcie/hugepage - if (pcie_read_ptr + size > pcie_base + pcie_size) { - pcie_read_ptr = pcie_base; - } - - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); - noc_async_read(host_src_addr, fence + preamble_size, size); - pending_read_size = size + preamble_size; - pcie_read_ptr += size; - - *prefetch_q_rd_ptr = 0; - - // Tell host we read - *(volatile tt_l1_ptr uint32_t *) prefetch_q_rd_ptr_addr = (uint32_t)prefetch_q_rd_ptr; - - prefetch_q_rd_ptr++; - - // Wrap prefetch_q - if ((uint32_t)prefetch_q_rd_ptr == prefetch_q_end) { - prefetch_q_rd_ptr = (volatile tt_l1_ptr uint16_t*)prefetch_q_base; - } -} - -// This routine can be called in 8 states based on the boolean values cmd_ready, prefetch_q_ready, read_pending: -// - !cmd_ready, !prefetch_q_ready, !read_pending: stall on prefetch_q, issue read, read barrier -// - !cmd_ready, !prefetch_q_ready, read pending: read barrier (and re-evaluate prefetch_q_ready) -// - !cmd_ready, prefetch_q_ready, !read_pending: issue read, read barrier (XXXX +issue read after?) -// - !cmd_ready, prefetch_q_ready, read_pending: read barrier, issue read -// - cmd_ready, !prefetch_q_ready, !read_pending: exit -// - cmd_ready, !prefetch_q_ready, read_pending: exit (no barrier yet) -// - cmd_ready, prefetch_q_ready, !read_pending: issue read -// - cmd_ready, prefetch_q_ready, read_pending: exit (don't add latency to the in flight request) -// -// With WH tagging of reads: -// open question: should fetcher loop on prefetch_q_ready issuing reads until !prefetch_q_ready -// - !cmd_ready, !prefetch_q_ready, !read_pending: stall on prefetch_q, issue read, read barrier -// - !cmd_ready, !prefetch_q_ready, read pending: read barrier on oldest tag -// - !cmd_ready, prefetch_q_ready, !read_pending: issue read, read barrier (XXXX +retry after?) 
-// - !cmd_ready, prefetch_q_ready, read_pending: issue read, read barrier on oldest tag -// - cmd_ready, !prefetch_q_ready, !read_pending: exit -// - cmd_ready, !prefetch_q_ready, read_pending: exit (no barrier yet) -// - cmd_ready, prefetch_q_ready, !read_pending: issue and tag read -// - cmd_ready, prefetch_q_ready, read_pending: issue and tag read -template -void fetch_q_get_cmds(uint32_t& fence, uint32_t& cmd_ptr, uint32_t& pcie_read_ptr) { - - static uint32_t pending_read_size = 0; - static volatile tt_l1_ptr uint16_t* prefetch_q_rd_ptr = (volatile tt_l1_ptr uint16_t*)prefetch_q_base; - - if (fence < cmd_ptr) { - DPRINT << "wrap cmd ptr1 " << fence << " " << cmd_ptr << ENDL(); - cmd_ptr = fence; - } - - bool cmd_ready = (cmd_ptr != fence); - uint32_t fetch_size = (uint32_t)*prefetch_q_rd_ptr << prefetch_q_log_minsize; - - if (fetch_size != 0 && pending_read_size == 0) { - DPRINT << "read1: " << (uint32_t)prefetch_q_rd_ptr << " " << " " << fence << " " << fetch_size << ENDL(); - read_from_pcie - (prefetch_q_rd_ptr, pending_read_size, fence, pcie_read_ptr, cmd_ptr, fetch_size); - } - if (!cmd_ready) { - if (pending_read_size != 0) { - DPRINT << "barrier" << ENDL(); - noc_async_read_barrier(); - - // wrap the cmddat_q - if (fence < cmd_ptr) { - cmd_ptr = fence; - } - - fence += pending_read_size; - pending_read_size = 0; - // After the stall, re-check the host - fetch_size = (uint32_t)*prefetch_q_rd_ptr << prefetch_q_log_minsize; - if (fetch_size != 0) { - read_from_pcie - (prefetch_q_rd_ptr, pending_read_size, fence, pcie_read_ptr, cmd_ptr, fetch_size); - } - } else { - // By here, prefetch_q_ready must be false - // Nothing to fetch, nothing pending, nothing available, stall on host - DEBUG_STATUS("HQW"); - DPRINT << "prefetcher stall" << ENDL(); - while ((fetch_size = *prefetch_q_rd_ptr) == 0); - DPRINT << "recurse" << ENDL(); - fetch_q_get_cmds(fence, cmd_ptr, pcie_read_ptr); - DEBUG_STATUS("HQD"); - } - } -} - -template -uint32_t process_debug_cmd(uint32_t cmd_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t checksum = 0; - uint32_t data_start = (uint32_t)cmd + sizeof(CQPrefetchCmd); - uint32_t *data = (uint32_t *)data_start; - uint32_t size = cmd->debug.size; - - uint32_t front_size = (size <= cmddat_end - data_start) ? 
size : cmddat_end - data_start; - for (uint32_t i = 0; i < front_size / sizeof(uint32_t); i++) { - checksum += *data++; - } - uint32_t back_size = size - front_size; - if (back_size > 0) { - data = (uint32_t *)cmddat_base; - for (uint32_t i = 0; i < back_size / sizeof(uint32_t); i++) { - checksum += *data++; - } - } - - if (checksum != cmd->debug.checksum) { - DEBUG_STATUS("!CHK"); - ASSERT(0); - } - - return cmd->debug.stride; -} - -template -static uint32_t process_relay_inline_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - - uint32_t length = cmd->relay_inline.length; - uint32_t data_ptr = cmd_ptr + sizeof(CQPrefetchCmd); - - uint32_t npages = (length + cb_page_size - 1) >> cb_log_page_size; - - // Assume the downstream buffer is big relative to cmddat command size that we can - // grab what we need in one chunk - cb_acquire_pages(npages); - - uint32_t remaining = cmddat_end - data_ptr; - if (cmddat_wrap_enable && length > remaining) { - // wrap cmddat - write_downstream(data_ptr, dispatch_data_ptr, remaining); - length -= remaining; - data_ptr = cmddat_base; - } - - DPRINT << my_noc_xy << " " << dispatch_noc_xy << " " << cb_base << ENDL(); - write_downstream(data_ptr, dispatch_data_ptr, length); - - // Round to nearest page - dispatch_data_ptr += (cb_page_size - (dispatch_data_ptr & (cb_page_size - 1))) & (cb_page_size - 1); - - // XXXXX - painful syncing right now? move this into get_cmds - noc_async_writes_flushed(); - cb_release_pages(npages); - - return cmd->relay_inline.stride; -} - -// This version of inline sends inline data to the dispatcher but doesn't flush the page to the dispatcher -// This is used to assemble dispatcher commands when data comes out of band, eg, reading from DRAM -// That means this command is stateful, incorrect use will be...bad -// NOTE: this routine assumes we're sending a command header and that is LESS THAN A PAGE -template -static uint32_t process_relay_inline_noflush_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - - uint32_t length = sizeof(CQDispatchCmd); - uint32_t data_ptr = cmd_ptr + sizeof(CQPrefetchCmd); - - cb_acquire_pages(1); - if (dispatch_data_ptr == cb_end) { - dispatch_data_ptr = cb_base; - } - noc_async_write(data_ptr, get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr), length); - dispatch_data_ptr += length; - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -static uint32_t write_pages_to_dispatcher(uint32_t& dispatch_data_ptr, - uint32_t& scratch_write_addr, - uint32_t& amt_to_write) { - - uint32_t page_residual_space = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - uint32_t npages = (amt_to_write - page_residual_space + dispatch_cb_page_size + extra_space - 1) / dispatch_cb_page_size; - - // Grabbing all pages at once is ok if scratch_size < 3 * dispatch_cb_block_size - if (!test_for_nonzero || npages != 0) { - cb_acquire_pages(npages); - } - - uint64_t noc_addr = get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr); - if (dispatch_data_ptr == dispatch_cb_end) { - dispatch_data_ptr = dispatch_cb_base; - } else if (dispatch_data_ptr + amt_to_write > dispatch_cb_end) { // wrap - uint32_t last_chunk_size = dispatch_cb_end - dispatch_data_ptr; - noc_async_write(scratch_write_addr, noc_addr, last_chunk_size); - dispatch_data_ptr = dispatch_cb_base; - scratch_write_addr += last_chunk_size; - 
amt_to_write -= last_chunk_size; - noc_addr = get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr); - } - - noc_async_write(scratch_write_addr, noc_addr, amt_to_write); - dispatch_data_ptr += amt_to_write; - - return npages; -} - -// This fn prefetches data from DRAM memory and writes data to the dispatch core. -// Reading from DRAM has the following characteristics: -// - latency is moderately high ~400 cycles on WH -// - DRAM bw is ~maximized when page size reaches 2K -// - for kernel dispatch, it is expected that page sizes will often be <2K -// - for buffer writing, page sizes will vary -// - writing to dispatcher works best with 4K pages (2K pages cover overhead, 4K gives perf cushion) -// - writing a 4K page takes ~32*4=128 cycles -// - writing 4 4K pages is 512 cycles, close to parity w/ the latency of DRAM -// - to hide the latency (~12% overhead), assume we need to read ~32 pages=128K, double buffered -// - in other words, we'll never achieve high efficiency and always be (somewhat) latency bound -// Algorithm does: -// - read a batch from DRAM -// - loop: read a batch from DRAM while sending to dispatcher -// - send a batch to dispatcher -// The size of the first read should be based on latency. With small page sizes -// bandwidth will be low and we'll be DRAM bound (send to dispatcher is ~free). -// With larger pages we'll get closer to a bandwidth match -// The dispatch buffer is a ring buffer. -template -uint32_t process_relay_paged_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - // This ensures that a previous cmd using the scratch buf has finished - noc_async_writes_flushed(); - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t page_id = cmd->relay_paged.start_page; - uint32_t base_addr = cmd->relay_paged.base_addr; - uint32_t page_size = cmd->relay_paged.page_size; - uint32_t pages = cmd->relay_paged.pages; - uint32_t read_length = pages * page_size; - - InterleavedAddrGen addr_gen; - addr_gen.bank_base_address = base_addr; - addr_gen.page_size = page_size; - - // First step - read into DB0 - uint32_t scratch_read_addr = scratch_db_top[0]; - uint32_t amt_to_read = (scratch_db_half_size > read_length) ? read_length : scratch_db_half_size; - uint32_t amt_read = 0; - while (amt_to_read >= page_size) { - uint64_t noc_addr = addr_gen.get_noc_addr(page_id); // XXXX replace this w/ walking the banks to save mul on GS - noc_async_read(noc_addr, scratch_read_addr, page_size); - scratch_read_addr += page_size; - page_id++; - amt_to_read -= page_size; - amt_read += page_size; - } - noc_async_read_barrier(); - - // Second step - read into DB[x], write from DB[x], toggle x, iterate - // Writes are fast, reads are slow - uint32_t db_toggle = 0; - uint32_t scratch_write_addr; - read_length -= amt_read; - while (read_length != 0) { - // This ensures that writes from prior iteration are done - // TODO(pgk); we can do better on WH w/ tagging - noc_async_writes_flushed(); - - db_toggle ^= 1; - scratch_read_addr = scratch_db_top[db_toggle]; - scratch_write_addr = scratch_db_top[db_toggle ^ 1]; - - uint32_t amt_to_write = amt_read; - amt_to_read = (scratch_db_half_size > read_length) ? 
read_length : scratch_db_half_size; - amt_read = 0; - while (amt_to_read >= page_size) { - uint64_t noc_addr = addr_gen.get_noc_addr(page_id); // XXXX replace this w/ walking the banks to save mul on GS - noc_async_read(noc_addr, scratch_read_addr, page_size); - scratch_read_addr += page_size; - page_id++; - amt_to_read -= page_size; - amt_read += page_size; - } - - // Third step - write from DB - uint32_t npages = write_pages_to_dispatcher< - 0, - false, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - cb_release_pages(npages); - - read_length -= amt_read; - - // TODO(pgk); we can do better on WH w/ tagging - noc_async_read_barrier(); - } - - // Third step - write from DB - scratch_write_addr = scratch_db_top[db_toggle]; - uint32_t amt_to_write = amt_read; - uint32_t npages = write_pages_to_dispatcher< - CQ_DISPATCH_CMD_SIZE, - true, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - uint32_t pad_to_page = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - dispatch_data_ptr += pad_to_page; - - // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH - cb_release_pages(npages + 1); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -uint32_t process_relay_linear_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - // This ensures that a previous cmd using the scratch buf has finished - noc_async_writes_flushed(); - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t noc_xy_addr = cmd->relay_linear.noc_xy_addr; - uint32_t read_addr = cmd->relay_linear.addr; - uint32_t length = cmd->relay_linear.length; - uint32_t read_length = length; - - // First step - read into DB0 - uint32_t scratch_read_addr = scratch_db_top[0]; - uint32_t amt_to_read = (scratch_db_half_size > read_length) ? read_length : scratch_db_half_size; - uint64_t noc_addr = get_noc_addr_helper(noc_xy_addr, read_addr); - noc_async_read(noc_addr, scratch_read_addr, amt_to_read); - read_addr += amt_to_read; - noc_async_read_barrier(); - - // Second step - read into DB[x], write from DB[x], toggle x, iterate - // Writes are fast, reads are slow - uint32_t db_toggle = 0; - uint32_t scratch_write_addr; - read_length -= amt_to_read; - while (read_length != 0) { - // This ensures that writes from prior iteration are done - // TODO(pgk); we can do better on WH w/ tagging - noc_async_writes_flushed(); - - db_toggle ^= 1; - scratch_read_addr = scratch_db_top[db_toggle]; - scratch_write_addr = scratch_db_top[db_toggle ^ 1]; - - uint32_t amt_to_write = amt_to_read; - amt_to_read = (scratch_db_half_size > read_length) ? 
read_length : scratch_db_half_size; - noc_addr = get_noc_addr_helper(noc_xy_addr, read_addr); - noc_async_read(noc_addr, scratch_read_addr, amt_to_read); - read_addr += amt_to_read; - - // Third step - write from DB - uint32_t npages = write_pages_to_dispatcher< - 0, - false, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - cb_release_pages(npages); - - read_length -= amt_to_read; - - // TODO(pgk); we can do better on WH w/ tagging - noc_async_read_barrier(); - } - - // Third step - write from DB - scratch_write_addr = scratch_db_top[db_toggle]; - uint32_t amt_to_write = amt_to_read; - uint32_t npages = write_pages_to_dispatcher< - CQ_DISPATCH_CMD_SIZE, - true, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - uint32_t pad_to_page = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - dispatch_data_ptr += pad_to_page; - - // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH - cb_release_pages(npages + 1); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -uint32_t process_stall(uint32_t cmd_ptr) { - - static uint32_t count = 0; - - count++; - - DEBUG_STATUS("PSW"); - volatile tt_l1_ptr uint32_t* sem_addr = - reinterpret_cast(get_semaphore(dispatch_sync_sem_id)); - while (*sem_addr != count); - DEBUG_STATUS("PSD"); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -bool process_cmd(uint32_t cmd_ptr, - uint32_t& downstream_data_ptr, - uint32_t& stride) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - bool done = false; - - switch (cmd->base.cmd_id) { - case CQ_PREFETCH_CMD_RELAY_LINEAR: - DPRINT << "relay linear: " << cmd_ptr << ENDL(); - stride = process_relay_linear_cmd< - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_RELAY_PAGED: - DPRINT << "relay dram page: " << cmd_ptr << ENDL(); - if (cmd->relay_paged.is_dram) { - stride = process_relay_paged_cmd< - true, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - } else { - stride = process_relay_paged_cmd< - false, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - } - break; - - case CQ_PREFETCH_CMD_RELAY_INLINE: - DPRINT << "inline" << ENDL(); - stride = process_relay_inline_cmd< - cmddat_wrap_enable, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - cmddat_base, - cmddat_end, - downstream_cb_base, - downstream_cb_end, - downstream_cb_log_page_size, - downstream_cb_page_size>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH: - DPRINT << "inline no flush" << ENDL(); - stride = process_relay_inline_noflush_cmd< - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_base, - downstream_cb_end>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_STALL: - DPRINT << "stall" << ENDL(); - 
stride = process_stall(cmd_ptr); - break; - - case CQ_PREFETCH_CMD_DEBUG: - DPRINT << "debug" << ENDL(); - stride = process_debug_cmd(cmd_ptr); - break; - - case CQ_PREFETCH_CMD_TERMINATE: - DPRINT << "terminating\n"; - done = true; - break; - - default: - DPRINT << "prefetch invalid command:" << (uint32_t)cmd->base.cmd_id << " " << cmd_ptr << " " << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+4) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); - } - - return done; -}
diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp
index 1edcca12168..a507e2e2337 100644
--- a/tt_metal/impl/program/program.cpp
+++ b/tt_metal/impl/program/program.cpp
@@ -590,16 +590,14 @@ void Program::populate_dispatch_data(Device *device) {
         {RISCV::ERISC, eth_l1_mem::address_map::FIRMWARE_BASE}};

     auto extract_dst_noc_unicast_info =
-        [&device](const set<CoreRange> &ranges, const CoreType core_type) -> vector<pair<uint32_t, uint32_t>> {
+        [&device](const set<CoreRange> &ranges, const CoreType core_type) -> vector<pair<transfer_info_cores, uint32_t>> {
         // This API extracts all the pairs of noc multicast encodings given a set of core ranges
-        vector<pair<uint32_t, uint32_t>> dst_noc_unicast_info;
+        vector<pair<transfer_info_cores, uint32_t>> dst_noc_unicast_info;
         for (const CoreRange &core_range : ranges) {
             for (auto x = core_range.start.x; x <= core_range.end.x; x++) {
                 for (auto y = core_range.start.y; y <= core_range.end.y; y++) {
                     CoreCoord physical_coord = device->physical_core_from_logical_core(CoreCoord({x, y}), core_type);
-                    uint32_t dst_noc_unicast_encoding =
-                        NOC_XY_ENCODING(NOC_X(physical_coord.x), NOC_Y(physical_coord.y));
-                    dst_noc_unicast_info.push_back(std::make_pair(dst_noc_unicast_encoding, /*num_mcast_dests=*/0));
+                    dst_noc_unicast_info.push_back(std::make_pair(physical_coord, /*num_mcast_dests=*/0));
                 }
             }
         }
@@ -613,7 +611,7 @@ void Program::populate_dispatch_data(Device *device) {

         // TODO: use semaphore.core_type from main
         if (semaphore.core_type() == CoreType::WORKER) {
-            vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info =
+            vector<pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info =
                 extract_dst_noc_multicast_info<std::set<CoreRange>>(
                     device, semaphore.core_range_set().ranges(), semaphore.core_type());
             transfer_info_2 transfer_info = {
@@ -623,7 +621,7 @@ void Program::populate_dispatch_data(Device *device) {
                 .data = semaphore_data};
             this->program_transfer_info.multicast_semaphores[semaphore.address()].push_back(transfer_info);
         } else if (semaphore.core_type() == CoreType::ETH) {
-            vector<pair<uint32_t, uint32_t>> dst_noc_unicast_info =
+            vector<pair<transfer_info_cores, uint32_t>> dst_noc_unicast_info =
                 extract_dst_noc_unicast_info(semaphore.core_range_set().ranges(), semaphore.core_type());
             transfer_info_2 transfer_info = {
                 .dst_base_addr = semaphore.address(),
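The reworked extract_dst_noc_unicast_info lambda above records physical CoreCoords rather than ready-made encodings, leaving encoding to the command queue. A standalone sketch of the same enumeration with simplified types (the logical-to-physical mapping here is a dummy):

#include <cstdint>
#include <utility>
#include <vector>

struct CoreCoord { uint32_t x, y; };
struct CoreRange { CoreCoord start, end; };

// Dummy logical->physical mapping; the real one lives on Device.
CoreCoord physical_from_logical(const CoreCoord& c) { return {c.x + 1, c.y + 1}; }

// Mirrors the lambda above: one (core, num_mcast_dests=0) entry per core, since a
// unicast write has exactly one destination.
std::vector<std::pair<CoreCoord, uint32_t>> extract_unicast_info(const std::vector<CoreRange>& ranges) {
    std::vector<std::pair<CoreCoord, uint32_t>> info;
    for (const CoreRange& r : ranges) {
        for (uint32_t x = r.start.x; x <= r.end.x; x++) {
            for (uint32_t y = r.start.y; y <= r.end.y; y++) {
                info.emplace_back(physical_from_logical({x, y}), /*num_mcast_dests=*/0u);
            }
        }
    }
    return info;
}

int main() { return extract_unicast_info({{{0, 0}, {1, 1}}}).size() == 4 ? 0 : 1; }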
@@ -640,7 +638,7 @@ void Program::populate_dispatch_data(Device *device) {

     // Program Binaries and Go Signals
     // TODO: cleanup put the WORKERS and ETH logic together..
     for (KernelGroup &kernel_group : this->get_kernel_groups(CoreType::WORKER)) {
-        vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info = extract_dst_noc_multicast_info<std::set<CoreRange>>(
+        vector<pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info = extract_dst_noc_multicast_info<std::set<CoreRange>>(
             device, kernel_group.core_ranges.ranges(), kernel_group.get_core_type());

         // So far, we don't support linking optimizations for kernel groups
@@ -710,7 +708,7 @@ void Program::populate_dispatch_data(Device *device) {
         }
     }
     for (KernelGroup &kernel_group : this->get_kernel_groups(CoreType::ETH)) {
-        vector<pair<uint32_t, uint32_t>> dst_noc_unicast_info =
+        vector<pair<transfer_info_cores, uint32_t>> dst_noc_unicast_info =
             extract_dst_noc_unicast_info(kernel_group.core_ranges.ranges(), kernel_group.get_core_type());

         vector<KernelHandle> kernel_ids;
diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp
index 868a9c711e1..10e33f55591 100644
--- a/tt_metal/impl/program/program.hpp
+++ b/tt_metal/impl/program/program.hpp
@@ -54,19 +54,16 @@ struct KernelGroup {
 };

 template <typename CoreRangeContainer>
-vector<pair<uint32_t, uint32_t>> extract_dst_noc_multicast_info(Device* device, const CoreRangeContainer& ranges, const CoreType core_type) {
+vector<pair<transfer_info_cores, uint32_t>> extract_dst_noc_multicast_info(Device* device, const CoreRangeContainer& ranges, const CoreType core_type) {
     // This API extracts all the pairs of noc multicast encodings given a set of core ranges
-    vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info;
+    vector<pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info;
     dst_noc_multicast_info.reserve(ranges.size());
     for (const CoreRange& core_range : ranges) {
         CoreCoord physical_start = device->physical_core_from_logical_core(core_range.start, core_type);
         CoreCoord physical_end = device->physical_core_from_logical_core(core_range.end, core_type);
-        uint32_t dst_noc_multicast_encoding =
-            NOC_MULTICAST_ENCODING(physical_start.x, physical_start.y, physical_end.x, physical_end.y);
-
         uint32_t num_receivers = core_range.size();
-        dst_noc_multicast_info.push_back(std::make_pair(dst_noc_multicast_encoding, num_receivers));
+        dst_noc_multicast_info.push_back(std::make_pair(CoreRange(physical_start, physical_end), num_receivers));
     }
     return dst_noc_multicast_info;
 }
diff --git a/tt_metal/impl/program/program_device_map.hpp b/tt_metal/impl/program/program_device_map.hpp
index e5c6d5cfd5a..dc648887b13 100644
--- a/tt_metal/impl/program/program_device_map.hpp
+++ b/tt_metal/impl/program/program_device_map.hpp
@@ -16,9 +16,11 @@ struct transfer_info {
     bool linked;
 };

+using transfer_info_cores = std::variant<CoreCoord, CoreRange>;
+
 struct transfer_info_2 {
     std::uint32_t dst_base_addr;
-    vector<pair<uint32_t, uint32_t>> dst_noc_info;  // noc_encoding, num_mcast_dests
+    vector<pair<transfer_info_cores, uint32_t>> dst_noc_info;  // cores, num_mcast_dests
     bool linked;
     vector<uint32_t> data;
 };
@@ -26,7 +28,7 @@ struct kernel_bins_transfer_info {
     vector<uint32_t> dst_base_addrs;  // BRISC, NCRISC, TRISC etc..
     vector<uint32_t> page_offsets;    // offsets into paged buffer in DRAM
     vector<uint32_t> lengths;         // WriteLinear lengths
-    vector<pair<uint32_t, uint32_t>> dst_noc_info;  // noc_encoding, num_mcast_dests
+    vector<pair<transfer_info_cores, uint32_t>> dst_noc_info;  // cores, num_mcast_dests
     bool linked;
     vector<uint32_t> data;  // all binaries' data for kernel group
 };
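Taken together, the program_device_map.hpp changes make cached transfer info NOC-agnostic: producers store destinations as CoreCoord/CoreRange values, and each command queue resolves encodings for its own NOC at enqueue time. A small end-to-end sketch of that producer/consumer split, using simplified mirrors of the types above:

#include <cstdint>
#include <utility>
#include <variant>
#include <vector>

// Simplified mirrors of the program_device_map.hpp types.
struct CoreCoord { uint32_t x, y; };
struct CoreRange { CoreCoord start, end; };
using transfer_info_cores = std::variant<CoreCoord, CoreRange>;

struct transfer_info_sketch {
    uint32_t dst_base_addr;
    std::vector<std::pair<transfer_info_cores, uint32_t>> dst_noc_info;  // cores, num_mcast_dests
    bool linked;
    std::vector<uint32_t> data;
};

int main() {
    // Producer side (program load): record *where* data goes as cores, not encodings.
    transfer_info_sketch t{0x1000, {}, false, {1, 2, 3}};
    t.dst_noc_info.emplace_back(CoreCoord{2, 3}, 0u);             // unicast destination
    t.dst_noc_info.emplace_back(CoreRange{{0, 0}, {7, 7}}, 64u);  // multicast rectangle
    // Consumer side (enqueue): each command queue visits the variant and resolves
    // encodings against its own NOC index, so the same cached program can be
    // dispatched on either NOC.
    return t.dst_noc_info.size() == 2 ? 0 : 1;
}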