From d78ed37c3149d3e2237725f14b3dd127f32ea6a5 Mon Sep 17 00:00:00 2001
From: Joseph Chu <122298491+cfjchu@users.noreply.github.com>
Date: Thu, 3 Oct 2024 18:07:16 -0700
Subject: [PATCH] #13454: Refactor API for MeshDevice::enable_async (#13455)

---
 .../test_falcon_end_to_end_1_layer_t3000.py   |  6 +--
 ..._to_end_60_layer_t3000_prefill_10_loops.py |  6 +--
 .../demos/t3000/falcon40b/tests/test_demo.py  |  4 +-
 .../t3000/falcon40b/tests/test_perf_falcon.py |  6 +--
 models/demos/t3000/llama2_70b/demo/demo.py    |  4 +-
 .../demo/demo_continuous_batching.py          |  4 +-
 ...emo_continuous_batching_paged_attention.py |  4 +-
 models/demos/t3000/llama2_70b/demo/eval.py    |  4 +-
 .../demos/t3000/llama2_70b/demo/eval_t3000.py |  4 +-
 .../llama2_70b/tests/test_llama_generation.py |  5 ++-
 .../t3000/llama2_70b/tests/test_llama_perf.py |  2 +-
 .../tests/test_llama_perf_decode.py           |  4 +-
 models/demos/t3000/llama3_70b/demo/demo.py    |  4 +-
 models/demos/t3000/mixtral8x7b/demo/demo.py   |  3 +-
 .../mixtral8x7b/demo/demo_with_prefill.py     |  3 +-
 .../tests/test_mixtral_attention.py           |  3 +-
 .../tests/test_mixtral_attention_prefill.py   |  3 +-
 .../mixtral8x7b/tests/test_mixtral_decoder.py |  3 +-
 .../tests/test_mixtral_decoder_prefill.py     |  3 +-
 .../mixtral8x7b/tests/test_mixtral_mlp.py     |  3 +-
 .../tests/test_mixtral_mlp_prefill.py         |  3 +-
 .../mixtral8x7b/tests/test_mixtral_model.py   |  3 +-
 .../tests/test_mixtral_model_prefill.py       |  3 +-
 .../mixtral8x7b/tests/test_mixtral_moe.py     |  3 +-
 .../tests/test_mixtral_moe_prefill.py         |  3 +-
 .../mixtral8x7b/tests/test_mixtral_perf.py    |  6 +--
 .../tests/test_mixtral_perplexity.py          |  3 +-
 .../tests/test_mixtral_rms_norm.py            |  3 +-
 .../mixtral8x7b/tests/test_mixtral_topk.py    |  3 +-
 models/demos/tg/llama3_70b/demo/demo.py       |  4 +-
 .../tests/test_llama_demo_nightly.py          |  4 +-
 .../tg/llama3_70b/tests/test_llama_perf.py    |  2 +-
 .../tests/multi_chip/test_falcon_attention.py |  6 +--
 .../tests/multi_chip/test_falcon_causallm.py  | 11 ++---
 .../tests/multi_chip/test_falcon_decoder.py   |  6 +--
 .../tests/multi_chip/test_falcon_mlp.py       |  6 +--
 .../grok/tests/test_grok_attention.py         |  3 +-
 .../grok/tests/test_grok_decoder.py           |  3 +-
 .../grok/tests/test_grok_embedding.py         |  3 +-
 .../experimental/grok/tests/test_grok_mlp.py  |  3 +-
 .../grok/tests/test_grok_model.py             |  3 +-
 .../experimental/grok/tests/test_grok_moe.py  |  3 +-
 .../experimental/grok/tests/test_grok_perf.py |  3 +-
 .../grok/tests/test_grok_rms_norm.py          |  6 +--
 .../sweeps/ccl/line_all_gather.py             |  3 +-
 .../unit_tests/operations/test_all_gather.py  |  9 ++--
 .../test_all_gather_TG_post_commit.py         |  3 +-
 .../operations/test_all_gather_nightly.py     |  3 +-
 .../test_reduce_scatter_post_commit.py        |  6 +--
 .../unit_tests/test_multi_device_async.py     | 43 +++++++------------
 .../unit_tests/test_multi_device_events.py    |  5 +--
 .../unit_tests/test_multi_device_trace.py     | 10 ++---
 .../unit_tests/test_multi_device_trace_TG.py  | 10 ++---
 .../unit_tests/test_multi_device_trace_tgg.py | 11 ++---
 tt_metal/impl/device/mesh_device.cpp          |  6 +++
 tt_metal/impl/device/mesh_device.hpp          |  1 +
 ttnn/cpp/pybind11/multi_device.hpp            | 10 +++++
 57 files changed, 111 insertions(+), 183 deletions(-)

diff --git a/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_1_layer_t3000.py b/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_1_layer_t3000.py
index 6f280de234b..f256afb0b86 100644
--- a/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_1_layer_t3000.py
+++ b/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_1_layer_t3000.py
@@ -97,10 +97,8 @@ def test_FalconCausalLM_end_to_end_with_program_cache(
     input_shape = [batch, seq_len]
     model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
-    devices = t3k_mesh_device.get_devices()
-    for device in devices:
-        device.enable_async(async_mode)
-    compute_grid_size = devices[0].compute_with_storage_grid_size()
+    t3k_mesh_device.enable_async(async_mode)
+    compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
     if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
         pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
diff --git a/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_60_layer_t3000_prefill_10_loops.py b/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_60_layer_t3000_prefill_10_loops.py
index 7e0cba26fa0..bd090b0fa54 100644
--- a/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_60_layer_t3000_prefill_10_loops.py
+++ b/models/demos/t3000/falcon40b/tests/ci/test_falcon_end_to_end_60_layer_t3000_prefill_10_loops.py
@@ -88,10 +88,8 @@ def test_FalconCausalLM_prefill_end_to_end_t3000_ci_loops_10(
     input_shape = [batch, seq_len]
     model_config_str = f"{data_type}-{memcfg}"
     model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
-    devices = t3k_mesh_device.get_devices()
-    for device in devices:
-        device.enable_async(async_mode)
-    compute_grid_size = devices[0].compute_with_storage_grid_size()
+    t3k_mesh_device.enable_async(async_mode)
+    compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
     if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
         pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
diff --git a/models/demos/t3000/falcon40b/tests/test_demo.py b/models/demos/t3000/falcon40b/tests/test_demo.py
index bbae745b0d6..8a133625e1d 100644
--- a/models/demos/t3000/falcon40b/tests/test_demo.py
+++ b/models/demos/t3000/falcon40b/tests/test_demo.py
@@ -48,9 +48,7 @@ def test_demo(
     use_program_cache,
 ):
     input_file = "models/demos/t3000/falcon40b/demo/input_data.json"
-    # Enable async mode
-    for device in t3k_mesh_device.get_devices():
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     generated_text, measurements = run_falcon_demo_kv(
         user_input=input_file,
diff --git a/models/demos/t3000/falcon40b/tests/test_perf_falcon.py b/models/demos/t3000/falcon40b/tests/test_perf_falcon.py
index 55d34d715a6..36762830e91 100644
--- a/models/demos/t3000/falcon40b/tests/test_perf_falcon.py
+++ b/models/demos/t3000/falcon40b/tests/test_perf_falcon.py
@@ -382,10 +382,8 @@ def test_perf_bare_metal(
     input_shape = [batch, seq_len]
     model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
-    devices = t3k_mesh_device.get_devices()
-    for device in devices:
-        device.enable_async(async_mode)
-    compute_grid_size = devices[0].compute_with_storage_grid_size()
+    t3k_mesh_device.enable_async(async_mode)
+    compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
     if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
         pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
diff --git a/models/demos/t3000/llama2_70b/demo/demo.py b/models/demos/t3000/llama2_70b/demo/demo.py
index 52ef7da2a02..003127fc97f 100644
--- a/models/demos/t3000/llama2_70b/demo/demo.py
+++ b/models/demos/t3000/llama2_70b/demo/demo.py
@@ -457,9 +457,7 @@ def test_LlamaModel_demo(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching.py b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching.py
index 2cc1353c6df..419ea2f0b96 100644
--- a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching.py
+++ b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching.py
@@ -374,9 +374,7 @@ def test_LlamaModel_demo(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py
index 4b0ccd77b0c..02a6684d838 100644
--- a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py
+++ b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py
@@ -413,9 +413,7 @@ def test_LlamaModel_demo(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/llama2_70b/demo/eval.py b/models/demos/t3000/llama2_70b/demo/eval.py
index 67704734abc..f8b2c534b1b 100644
--- a/models/demos/t3000/llama2_70b/demo/eval.py
+++ b/models/demos/t3000/llama2_70b/demo/eval.py
@@ -379,9 +379,7 @@ def test_LlamaModel_demo(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/llama2_70b/demo/eval_t3000.py b/models/demos/t3000/llama2_70b/demo/eval_t3000.py
index 56b898a2bb9..d7419c62ebd 100644
--- a/models/demos/t3000/llama2_70b/demo/eval_t3000.py
+++ b/models/demos/t3000/llama2_70b/demo/eval_t3000.py
@@ -186,9 +186,7 @@ def test_LlamaModel_demo(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py
index 1d92fe1916f..b5b2286aa81 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py
@@ -151,8 +151,9 @@ def test_LlamaModel_inference(
     if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
         pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
+    t3k_mesh_device.enable_async(True)
+    for device_id in t3k_mesh_device.get_device_ids():
+        device = t3k_mesh_device.get_device(device_id)
         device.enable_program_cache()
     args = construct_arg(
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py
index b4e81524ec6..131f8abf965 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py
@@ -314,10 +314,10 @@ def test_Llama_perf_host(
     if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
         pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
+    t3k_mesh_device.enable_async(True)
     for i in t3k_mesh_device.get_device_ids():
         device = t3k_mesh_device.get_device(i)
         device.enable_program_cache()
-        device.enable_async(True)
     disable_compilation_reports()
     run_test_LlamaModel_end_to_end(
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
index 32f4b857857..fbccd4176c3 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
@@ -234,9 +234,7 @@ def test_Llama_perf_host(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     disable_compilation_reports()
diff --git a/models/demos/t3000/llama3_70b/demo/demo.py b/models/demos/t3000/llama3_70b/demo/demo.py
index a69571d0c19..55e740bbfec 100644
--- a/models/demos/t3000/llama3_70b/demo/demo.py
+++ b/models/demos/t3000/llama3_70b/demo/demo.py
@@ -100,9 +100,7 @@ def test_LlamaModel_demo(
     check_mesh_device(t3k_mesh_device, model_config)
-    for i in t3k_mesh_device.get_device_ids():
-        device = t3k_mesh_device.get_device(i)
-        device.enable_async(True)
+    t3k_mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/mixtral8x7b/demo/demo.py b/models/demos/t3000/mixtral8x7b/demo/demo.py
index c81b8b8aaa4..be02adcf491 100644
--- a/models/demos/t3000/mixtral8x7b/demo/demo.py
+++ b/models/demos/t3000/mixtral8x7b/demo/demo.py
@@ -273,8 +273,7 @@ def test_mixtral8x7b_demo(t3k_mesh_device, use_program_cache, input_prompts, ins
     if is_ci_env and instruct_weights == True:
         pytest.skip("CI demo test only runs general weights to reduce CI pipeline load (both are supported)")
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     return run_mixtral_demo(
         user_input=input_prompts,
diff --git a/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py b/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py
index ea95208ea52..69d10470565 100644
--- a/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py
+++ b/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py
@@ -509,8 +509,7 @@ def test_mixtral8x7b_demo(t3k_mesh_device, use_program_cache, input_prompts, ins
     else:
         batch_size = 32
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     return run_mixtral_demo(
         user_input=input_prompts,
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py
index ba52dd6eda1..957be57c7de 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py
@@ -20,8 +20,7 @@ def test_mixtral_attention_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     dtype = ttnn.bfloat8_b
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py
index 5e20200b88a..d4e50a5f5cb 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py
@@ -34,8 +34,7 @@
 )
 @torch.no_grad()
 def test_mixtral_attention_inference(t3k_mesh_device, use_program_cache, reset_seeds, seq_len):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     dtype = ttnn.bfloat8_b
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py
index 0f63a2f002f..36b035b536e 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py
@@ -28,8 +28,7 @@ def test_mixtral_decoder_inference(t3k_mesh_device, use_program_cache, reset_see
     s: sequence length
     h: hidden size
     """
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     dtype = ttnn.bfloat8_b
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py
index e1c9c5041f7..dc4b84ba4ef 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py
@@ -35,8 +35,7 @@ def test_mixtral_decoder_inference(t3k_mesh_device, use_program_cache, reset_see
     s: sequence length
     h: hidden size
     """
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     dtype = ttnn.bfloat8_b
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py
index 1fb815c6651..932a60af16f 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py
@@ -19,8 +19,7 @@ def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     # Specify different dtypes for each feedForward weights
     dtypes = {
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py
index d4735cd7950..7e952a57d98 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py
@@ -28,8 +28,7 @@
     ),
 )
 def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds, seq_len):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     # Specify different dtypes for each feedForward weights
     dtypes = {
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py
index b9830431c16..b79b2fce80e 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py
@@ -36,8 +36,7 @@ def forward(self, x):
     ),
 )
 def test_mixtral_model_inference(t3k_mesh_device, use_program_cache, reset_seeds, batch):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     valid_pcc = 0.97
     dtype = ttnn.bfloat8_b
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py
index 246d7012186..b954fc7da68 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py
@@ -45,8 +45,7 @@ def test_mixtral_model_inference_CI(t3k_mesh_device, use_program_cache, reset_se
     if is_ci_env:
         os.environ["MIXTRAL_REF_OUTPUT_PATH"] = "/mnt/MLPerf/tt_dnn-models/Mistral/Mixtral-8x7B-v0.1/prefill/"
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     n_layers = 32
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py
index fd08c8179e0..10a1e2e0bc9 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py
@@ -21,8 +21,7 @@ def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     iterations = 1
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py
index 0158d99a229..5e8df333fd7 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py
@@ -30,8 +30,7 @@
     ),
 )
 def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds, seq_len):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     iterations = 1
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
index 4ddb91b159b..25a2af5c8b4 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
@@ -51,8 +51,7 @@ def test_mixtral_model_perf(
     reset_seeds,
     is_ci_env,
 ):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     if not is_ci_env:
         # Enable tracy signpost support in local runs only
         from tracy import signpost
@@ -168,8 +167,7 @@ def test_mixtral_model_with_prefill_perf(
     reset_seeds,
     is_ci_env,
 ):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     if not is_ci_env:
         # Enable tracy signpost support in local runs only
         from tracy import signpost
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py
index d1c33c17683..1418f44c19d 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py
@@ -283,8 +283,7 @@ def test_mixtral_perplexity(
         llm_mode == "decode"
     ), "Only decode mode is supported for now"  # TODO Add prefill support when it reaches main
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     # Adjust the batch size based on the max prefill length
     if max_seq_len >= 16 * 1024:
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py
index 3e3ee285feb..6beebeeed32 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py
@@ -20,8 +20,7 @@ def test_mixtral_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     dtype = ttnn.bfloat8_b
diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py
index def00cdb2c4..2a67931c922 100644
--- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py
+++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py
@@ -42,8 +42,7 @@ def forward(self, x):
 def test_mixtral_model_inference(
     t3k_mesh_device, use_program_cache, reset_seeds, iterations, expected_top1, expected_top5
 ):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     # TODO Currently topk test is supporting decode-only mode. Add prefill support.
diff --git a/models/demos/tg/llama3_70b/demo/demo.py b/models/demos/tg/llama3_70b/demo/demo.py
index 62c28d348fd..663d51478d9 100644
--- a/models/demos/tg/llama3_70b/demo/demo.py
+++ b/models/demos/tg/llama3_70b/demo/demo.py
@@ -444,9 +444,7 @@ def test_LlamaModel_demo(
     check_mesh_device(mesh_device, model_config)
-    for i in mesh_device.get_device_ids():
-        device = mesh_device.get_device(i)
-        device.enable_async(True)
+    mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/tg/llama3_70b/tests/test_llama_demo_nightly.py b/models/demos/tg/llama3_70b/tests/test_llama_demo_nightly.py
index 73dfd015abd..54c986c6c1b 100644
--- a/models/demos/tg/llama3_70b/tests/test_llama_demo_nightly.py
+++ b/models/demos/tg/llama3_70b/tests/test_llama_demo_nightly.py
@@ -93,9 +93,7 @@ def test_llama3_tg_nightly_demo(
     check_mesh_device(mesh_device, model_config)
     # TODO: Renable when issue #11089 is resolved
-    for i in mesh_device.get_device_ids():
-        device = mesh_device.get_device(i)
-        device.enable_async(True)
+    mesh_device.enable_async(True)
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/tg/llama3_70b/tests/test_llama_perf.py b/models/demos/tg/llama3_70b/tests/test_llama_perf.py
index ce97159a883..ce9a16095ce 100644
--- a/models/demos/tg/llama3_70b/tests/test_llama_perf.py
+++ b/models/demos/tg/llama3_70b/tests/test_llama_perf.py
@@ -196,10 +196,10 @@ def test_Llama_perf_host(
     )
     check_mesh_device(mesh_device, model_config)
+    mesh_device.enable_async(True)
     for device in mesh_device.get_devices():
         device.enable_program_cache()
-        device.enable_async(True)
     disable_compilation_reports()
     run_test_LlamaModel_end_to_end(
diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py
index 56dc67c011a..98322a8f0c6 100644
--- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py
+++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py
@@ -81,8 +81,7 @@ def test_falcon_attention(
     torch_model,
     enable_async,
 ):
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     torch.manual_seed(0)
     batch = device_batch_size * mesh_device.get_num_devices()
@@ -190,5 +189,4 @@ def test_falcon_attention(
         pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc
     )
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(False)
+    mesh_device.enable_async(False)
diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py
index d84f35370cb..c52d9d4fc28 100644
--- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py
+++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py
@@ -80,8 +80,7 @@ def test_falcon_causal_lm(
     enable_async,
     num_loops,
 ):
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     torch.manual_seed(0)
     batch = device_batch_size * mesh_device.get_num_devices()
@@ -247,8 +246,7 @@ def convert_to_ttnn(model, name):
     logger.info("Falcon CausalLM Passed!")
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(False)
+    mesh_device.enable_async(False)
 @pytest.mark.parametrize(
@@ -297,8 +295,8 @@ def test_t3k_falcon_causal_lm_with_trace(
     enable_async,
     num_loops,
 ):
+    t3k_mesh_device.enable_async(enable_async)
     for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(enable_async)
         t3k_mesh_device.get_device(device).enable_program_cache()
     torch.manual_seed(0)
@@ -509,5 +507,4 @@ def convert_to_ttnn(model, name):
     logger.info("Falcon CausalLM Passed!")
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(False)
+    t3k_mesh_device.enable_async(False)
diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py
index 774494adf2d..40676143a5b 100644
--- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py
+++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py
@@ -81,8 +81,7 @@ def test_falcon_decoder(
     torch_model,
     enable_async,
 ):
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     torch.manual_seed(0)
     batch = device_batch_size * mesh_device.get_num_devices()
@@ -185,5 +184,4 @@ def test_falcon_decoder(
         pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc
     )
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(False)
+    mesh_device.enable_async(False)
diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py
index f6d5997d6ff..c118f9a9b15 100644
--- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py
+++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py
@@ -71,8 +71,7 @@ def test_falcon_mlp(
     torch_model,
     enable_async,
 ):
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     torch.manual_seed(0)
@@ -112,5 +111,4 @@ def test_falcon_mlp(
     )
     logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}")
-    for device in mesh_device.get_device_ids():
-        mesh_device.get_device(device).enable_async(False)
+    mesh_device.enable_async(False)
diff --git a/models/experimental/grok/tests/test_grok_attention.py b/models/experimental/grok/tests/test_grok_attention.py
index 9ece57e4f96..28a09d42583 100644
--- a/models/experimental/grok/tests/test_grok_attention.py
+++ b/models/experimental/grok/tests/test_grok_attention.py
@@ -25,8 +25,7 @@ def test_grok_attention_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.99
     dtype = ttnn.bfloat8_b
     model_args = TtModelArgs(t3k_mesh_device.get_device(0), dummy_weights=os.getenv("CI") == "true")
diff --git a/models/experimental/grok/tests/test_grok_decoder.py b/models/experimental/grok/tests/test_grok_decoder.py
index 2ab0b77d570..aa3b1c6ce00 100644
--- a/models/experimental/grok/tests/test_grok_decoder.py
+++ b/models/experimental/grok/tests/test_grok_decoder.py
@@ -28,8 +28,7 @@ def test_grok_decoder_inference(t3k_mesh_device, use_program_cache, reset_seeds)
     s: sequence length
     h: hidden size
     """
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.98
     dtype = ttnn.bfloat8_b
     model_args = TtModelArgs(t3k_mesh_device.get_device(0), dummy_weights=os.getenv("CI") == "true")
diff --git a/models/experimental/grok/tests/test_grok_embedding.py b/models/experimental/grok/tests/test_grok_embedding.py
index 36ee9fa0d8a..24ce2117f1d 100644
--- a/models/experimental/grok/tests/test_grok_embedding.py
+++ b/models/experimental/grok/tests/test_grok_embedding.py
@@ -32,8 +32,7 @@ def forward(self, x):
 def test_grok_embedding(device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     dtype = ttnn.bfloat16
     model_args = TtModelArgs(device, dummy_weights=os.getenv("CI") == "true")
diff --git a/models/experimental/grok/tests/test_grok_mlp.py b/models/experimental/grok/tests/test_grok_mlp.py
index cdb98f53b61..d5a154ce8fe 100644
--- a/models/experimental/grok/tests/test_grok_mlp.py
+++ b/models/experimental/grok/tests/test_grok_mlp.py
@@ -26,8 +26,7 @@
 @pytest.mark.timeout(500)
 def test_grok_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     # Specify different dtypes for each feedForward weights
     dtypes = {
         "linear": ttnn.bfloat4_b,
diff --git a/models/experimental/grok/tests/test_grok_model.py b/models/experimental/grok/tests/test_grok_model.py
index b7a31a6d20e..a16d60e6f86 100644
--- a/models/experimental/grok/tests/test_grok_model.py
+++ b/models/experimental/grok/tests/test_grok_model.py
@@ -42,8 +42,7 @@
     (1, 2, 10),
 )
 def test_grok_model_inference(t3k_mesh_device, use_program_cache, reset_seeds, iterations, n_layers, validation_type):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.97
     dtype = ttnn.bfloat8_b
diff --git a/models/experimental/grok/tests/test_grok_moe.py b/models/experimental/grok/tests/test_grok_moe.py
index 8a228293db5..ee6c77e6553 100644
--- a/models/experimental/grok/tests/test_grok_moe.py
+++ b/models/experimental/grok/tests/test_grok_moe.py
@@ -27,8 +27,7 @@
 @pytest.mark.timeout(600)
 def test_grok_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     pcc = 0.87  # real weights = 0.99
     iterations = 1
     dtype = ttnn.bfloat8_b
diff --git a/models/experimental/grok/tests/test_grok_perf.py b/models/experimental/grok/tests/test_grok_perf.py
index e3871177dbe..90df781dcad 100644
--- a/models/experimental/grok/tests/test_grok_perf.py
+++ b/models/experimental/grok/tests/test_grok_perf.py
@@ -48,8 +48,7 @@ def test_grok_model_perf(
     reset_seeds,
 ):
     dtype = ttnn.bfloat8_b
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     # Can use dummy_weights=True correctness is not tested, but it is much slower
     model_args = TtModelArgs(t3k_mesh_device.get_device(0), dummy_weights=False)
diff --git a/models/experimental/grok/tests/test_grok_rms_norm.py b/models/experimental/grok/tests/test_grok_rms_norm.py
index 7bd5e7d7f43..5f220b9eb2b 100644
--- a/models/experimental/grok/tests/test_grok_rms_norm.py
+++ b/models/experimental/grok/tests/test_grok_rms_norm.py
@@ -26,8 +26,7 @@
 def test_grok_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_seeds):
     dtype = ttnn.bfloat8_b
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     model_args = TtModelArgs(t3k_mesh_device.get_device(0), dummy_weights=os.getenv("CI") == "true")
     model_args.n_layers = 1
@@ -75,8 +74,7 @@ def test_grok_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_seeds
 def test_grok_rms_norm_sharded_inference(t3k_mesh_device, use_program_cache, reset_seeds):
-    for device in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device).enable_async(True)
+    t3k_mesh_device.enable_async(True)
     dtype = ttnn.bfloat8_b
     model_args = TtModelArgs(t3k_mesh_device.get_device(0), dummy_weights=os.getenv("CI") == "true")
diff --git a/tests/sweep_framework/sweeps/ccl/line_all_gather.py b/tests/sweep_framework/sweeps/ccl/line_all_gather.py
index 8604f06ed2f..52e02f34823 100644
--- a/tests/sweep_framework/sweeps/ccl/line_all_gather.py
+++ b/tests/sweep_framework/sweeps/ccl/line_all_gather.py
@@ -94,8 +94,7 @@ def run(
     device,
 ) -> list:
     t3k_mesh_device = device
-    for device in t3k_mesh_device.get_devices():
-        device.enable_async(enable_async)
+    t3k_mesh_device.enable_async(enable_async)
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
diff --git a/tests/ttnn/unit_tests/operations/test_all_gather.py b/tests/ttnn/unit_tests/operations/test_all_gather.py
index 1e27f02ecb5..233c2b6bfe1 100644
--- a/tests/ttnn/unit_tests/operations/test_all_gather.py
+++ b/tests/ttnn/unit_tests/operations/test_all_gather.py
@@ -138,8 +138,7 @@ def run_all_gather_impl(
     if num_iters < 1:
         pytest.fail("num_iters must be >= 1")
     # Use Async mode based on test input config
-    for device in mesh_device.get_devices():
-        device.enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     if enable_async:
         logger.info(f"Using Async Mode for All Gather Op Dispatch")
@@ -1280,8 +1279,7 @@ def run_all_gather_sharded_t3k(
     if t3k_mesh_device.get_num_devices() < num_devices:
         pytest.skip("Not T3000!")
-    for d in t3k_mesh_device.get_devices():
-        d.enable_async(enable_async)
+    t3k_mesh_device.enable_async(enable_async)
     return run_all_gather_sharded(
         t3k_mesh_device,
@@ -1332,8 +1330,7 @@ def run_all_gather_sharded_n300(
     if mesh_device.get_num_devices() != 2:
         pytest.skip("Not N300!")
-    for device in mesh_device.get_devices():
-        device.enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     return run_all_gather_sharded(
         mesh_device,
diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/test_all_gather_TG_post_commit.py
index face8c4d41c..2f940250eba 100644
--- a/tests/ttnn/unit_tests/operations/test_all_gather_TG_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/test_all_gather_TG_post_commit.py
@@ -62,8 +62,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
 ):
     if len(mesh_device.get_devices()) != 32:
         pytest.skip("Not TG!")
-    for device in mesh_device.get_devices():
-        device.enable_async(enable_async)
+    mesh_device.enable_async(enable_async)
     input_shape_per_chip = list(input_shape_per_all_gather)
     input_shape_per_chip[2 if cluster_axis == 0 else 3] //= num_devices_per_line
diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py
index 511e13e30ad..fda60985e5a 100644
--- a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py
+++ b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py
@@ -156,8 +156,7 @@ def run_line_all_gather_instances(
     if t3k_mesh_device.get_num_devices() != 8:
         pytest.skip("Not T3000!")
-    for device in t3k_mesh_device.get_devices():
-        device.enable_async(enable_async)
+    t3k_mesh_device.enable_async(enable_async)
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py
index 8749f14f5b6..bd9e81bb9f0 100644
--- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_post_commit.py
@@ -107,8 +107,7 @@ def run_reduce_scatter_test(
     if is_known_failure:
         pytest.skip(f"Skipping unsupported case {message}.")
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
+    t3k_mesh_device.enable_async(enable_async)
     if enable_async:
         logger.info(f"Using Async Mode for Reduce Scatter Op Dispatch")
@@ -339,8 +338,7 @@ def run_reduce_scatter_sharded_test(
     debug = False
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
+    t3k_mesh_device.enable_async(enable_async)
     # Generate input tensors
     input_shard_shape = list(output_shard_shape)
diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py
index abf8dcbeae1..62be5ba5b63 100644
--- a/tests/ttnn/unit_tests/test_multi_device_async.py
+++ b/tests/ttnn/unit_tests/test_multi_device_async.py
@@ -26,8 +26,7 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co
     if dtype == ttnn.bfloat8_b and layout == ttnn.ROW_MAJOR_LAYOUT:
         pytest.skip("Unsupported test permutation: bfloat8_b with ROW_MAJOR_LAYOUT")
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
+    pcie_mesh_device.enable_async(True)
     for i in range(100):
         torch_tensor = torch.rand((1, 1, 256, 512), dtype=torch.bfloat16)
@@ -41,8 +40,7 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co
         )
         assert_with_pcc(torch_tensor, torch_loop_back_tensor, pcc=0.9999)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
@@ -55,8 +53,7 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co
     if dtype == ttnn.bfloat8_b and layout == ttnn.ROW_MAJOR_LAYOUT:
         pytest.skip("Unsupported test permutation: bfloat8_b with ROW_MAJOR_LAYOUT")
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
+    pcie_mesh_device.enable_async(True)
     num_loops = 50
     if dtype == ttnn.bfloat8_b:
@@ -79,8 +76,7 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co
             )
             shard_offset += shard_size
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 1040, 1040)])
@@ -90,8 +86,7 @@ def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config):
     """Test ReplicateTensorToMesh to broadcast a tensor across multiple devices"""
     from ttnn import ReplicateTensorToMesh, ListMeshToTensor
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
+    pcie_mesh_device.enable_async(True)
     for i in range(100):
         full_tensor = torch.rand(shape, dtype=torch.bfloat16)
@@ -111,8 +106,7 @@ def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config):
     for loopback_replicated_tensor in loopback_replicated_tensors:
         assert torch.all(full_tensor == loopback_replicated_tensor)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT])
@@ -123,8 +117,7 @@ def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_
     from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ListMeshToTensor
     shard_dim = 3
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
+    pcie_mesh_device.enable_async(True)
     for loop in range(20):
         torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16)
         ttnn_tensor = ttnn.from_torch(
@@ -146,8 +139,7 @@ def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_
         )
         readback_tensor = torch.cat(readback_tensors, dim=shard_dim)
         assert torch.all(readback_tensor == torch_tensor)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize("program_cache", [False, True])
@@ -156,8 +148,8 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha
     """Multidevice API test: Running tensor-parallel multi-device chain of eltwise ops"""
     from ttnn import ShardTensorToMesh, ConcatMeshToTensor
+    pcie_mesh_device.enable_async(True)
     for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
         if program_cache:
             pcie_mesh_device.get_device(device).enable_program_cache()
@@ -188,8 +180,7 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha
         )
         assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.98)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize("program_cache", [False, True])
@@ -198,8 +189,8 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in
     """Multidevice API: Running data-parallel chain of ops with matmul"""
     from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh
+    pcie_mesh_device.enable_async(True)
     for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
         if program_cache:
             pcie_mesh_device.get_device(device).enable_program_cache()
@@ -239,8 +230,7 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in
         )
         assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.97)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize(
@@ -251,8 +241,7 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in
 )
 @pytest.mark.parametrize("mem_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
 def test_multi_device_argmax(pcie_mesh_device, layout, mem_config):
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
+    pcie_mesh_device.enable_async(True)
     torch.manual_seed(0)
     torch_input = torch.randn(1, 1, 32, 4096)
@@ -273,8 +262,7 @@ def test_multi_device_argmax(pcie_mesh_device, layout, mem_config):
     assert_with_pcc(tt_out_1B, reference_output, pcc=0.97)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(False)
+    pcie_mesh_device.enable_async(False)
 @pytest.mark.parametrize("pcie_mesh_device", [2], indirect=True)
@@ -322,8 +310,7 @@ def test_multi_device_explicit_dealloc(pcie_mesh_device):
 def test_add_1D_tensor_and_scalar(pcie_mesh_device, scalar, size):
     torch.manual_seed(0)
-    for device in pcie_mesh_device.get_device_ids():
-        pcie_mesh_device.get_device(device).enable_async(True)
+    pcie_mesh_device.enable_async(True)
     torch_input_tensor = torch.rand((size,), dtype=torch.bfloat16)
     torch_output_tensor = torch_input_tensor + scalar
diff --git a/tests/ttnn/unit_tests/test_multi_device_events.py b/tests/ttnn/unit_tests/test_multi_device_events.py
index 6cd17f3e3ca..c83d5c693a2 100644
--- a/tests/ttnn/unit_tests/test_multi_device_events.py
+++ b/tests/ttnn/unit_tests/test_multi_device_events.py
@@ -20,8 +20,8 @@ def test_multi_device_events(t3k_mesh_device, shape):
         pytest.skip("This test requires multiple devices")
     # Enable Program Cache and Async Mode
+    t3k_mesh_device.enable_async(True)
     for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(True)
         t3k_mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors.
@@ -85,5 +85,4 @@ def run_op_chain(input_0, input_1, workload_cq):
     )
     assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.96)
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(False)
+    t3k_mesh_device.enable_async(False)
diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py
index 93281524725..1fc07590d90 100644
--- a/tests/ttnn/unit_tests/test_multi_device_trace.py
+++ b/tests/ttnn/unit_tests/test_multi_device_trace.py
@@ -27,8 +27,8 @@ def test_multi_device_single_trace(t3k_mesh_device, shape, use_all_gather, enabl
         pytest.skip("This test requires multiple devices")
     # Trace requires program cache to be enabled
+    t3k_mesh_device.enable_async(enable_async)
     for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
         t3k_mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors. These will be used when capturing and executing the trace
@@ -124,8 +124,7 @@ def event_sync(event, record_cq, wait_cq):
     # Release trace buffer once workload is complete
     ttnn.release_trace(t3k_mesh_device, tid)
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(False)
+    t3k_mesh_device.enable_async(False)
 @pytest.mark.parametrize(
@@ -142,8 +141,8 @@ def test_multi_device_multi_trace(t3k_mesh_device, shape, use_all_gather, enable
         pytest.skip("This test requires multiple devices")
     # Trace requires program cache to be enabled
+    t3k_mesh_device.enable_async(enable_async)
     for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
         t3k_mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors. These will be used when capturing and executing the trace
@@ -323,5 +322,4 @@ def event_sync(event, record_cq, wait_cq):
     ttnn.release_trace(t3k_mesh_device, tid_1)
     ttnn.release_trace(t3k_mesh_device, tid_2)
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(False)
+    t3k_mesh_device.enable_async(False)
diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py
index 26b220a920d..fc5a056455e 100644
--- a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py
+++ b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py
@@ -26,8 +26,8 @@ def test_multi_device_single_trace(mesh_device, shape, enable_async, enable_mult
     if mesh_device.get_num_devices() < 32:
         pytest.skip("Test is only valid on Galaxy")
     # Trace requires program cache to be enabled
+    mesh_device.enable_async(True)
     for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(enable_async)
         mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors. These will be used when capturing and executing the trace
@@ -111,8 +111,7 @@ def event_sync(event, record_cq, wait_cq):
     # Release trace buffer once workload is complete
     ttnn.release_trace(mesh_device, tid)
-    for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(False)
+    mesh_device.enable_async(False)
 @pytest.mark.parametrize(
@@ -129,8 +128,8 @@ def test_multi_device_multi_trace(mesh_device, shape, enable_async, enable_multi
         pytest.skip("Test is only valid on Galaxy")
     # Trace requires program cache to be enabled
+    mesh_device.enable_async(True)
     for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(enable_async)
         mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors. These will be used when capturing and executing the trace
@@ -282,5 +281,4 @@ def event_sync(event, record_cq, wait_cq):
     ttnn.release_trace(mesh_device, tid_1)
     ttnn.release_trace(mesh_device, tid_2)
-    for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(False)
+    mesh_device.enable_async(False)
diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py
index 3b269a49121..ddb354dc365 100644
--- a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py
+++ b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py
@@ -26,8 +26,8 @@ def test_multi_device_single_trace(mesh_device, shape, enable_async, enable_mult
     if mesh_device.get_num_devices() < 64:
         pytest.skip("Test is only valid on TGG")
     # Trace requires program cache to be enabled
+    mesh_device.enable_async(True)
     for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(enable_async)
         mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors. These will be used when capturing and executing the trace
@@ -110,9 +110,7 @@ def event_sync(event, record_cq, wait_cq):
     # Release trace buffer once workload is complete
     ttnn.release_trace(mesh_device, tid)
-
-    for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(False)
+    mesh_device.enable_async(False)
 @pytest.mark.parametrize(
@@ -129,8 +127,8 @@ def test_multi_device_multi_trace(mesh_device, shape, enable_async, enable_multi
     pytest.skip("Test is only valid on TGG")
     # Trace requires program cache to be enabled
+    mesh_device.enable_async(True)
     for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(enable_async)
         mesh_device.get_device(device_id).enable_program_cache()
     # Preallocate activation tensors. These will be used when capturing and executing the trace
@@ -282,5 +280,4 @@ def event_sync(event, record_cq, wait_cq):
     ttnn.release_trace(mesh_device, tid_1)
     ttnn.release_trace(mesh_device, tid_2)
-    for device_id in mesh_device.get_device_ids():
-        mesh_device.get_device(device_id).enable_async(False)
+    mesh_device.enable_async(False)
diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp
index e1cc8228c0b..dfe8926c517 100644
--- a/tt_metal/impl/device/mesh_device.cpp
+++ b/tt_metal/impl/device/mesh_device.cpp
@@ -401,6 +401,12 @@ bool validate_worker_modes(const std::vector& workers) {
     return worker_modes_match;
 }
+void MeshDevice::enable_async(bool enable) {
+    for (auto device : this->devices) {
+        device->enable_async(enable);
+    }
+}
+
 std::vector get_t3k_physical_device_ids_ring() {
     auto& instance = SystemMesh::instance();
     auto num_devices = instance.get_num_devices();
diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp
index 91f1d12f9cf..f65e095f6d8 100644
--- a/tt_metal/impl/device/mesh_device.hpp
+++ b/tt_metal/impl/device/mesh_device.hpp
@@ -152,6 +152,7 @@ class MeshDevice : public std::enable_shared_from_this {
     CoreCoord dram_grid_size() const;
     tt::ARCH arch() const;
+    void enable_async(bool enable);
     void close_devices();
     std::shared_ptr get_view() const;
diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp
index d339fec8ac8..7dc90e202ec 100644
--- a/ttnn/cpp/pybind11/multi_device.hpp
+++ b/ttnn/cpp/pybind11/multi_device.hpp
@@ -103,6 +103,16 @@ void py_module(py::module& module) {
                 Returns:
                     Arch: The arch of the first device in the device mesh.
             )doc")
+            .def(
+                "enable_async",
+                &MeshDevice::enable_async,
+                py::arg("enable"),
+                R"doc(
+                Enable or disable async mode across all devices in the mesh.
+
+                Args:
+                    enable (bool): True to enable async mode, False to disable it.
+            )doc")
             .def_property_readonly("shape", &MeshDevice::shape, R"doc(
                 Get the shape of the device mesh.
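
Reviewer note (not part of the patch): a minimal Python sketch of the before/after call pattern this refactor applies across the tests above. The helper names and the default argument are illustrative; `t3k_mesh_device` stands for the pytest MeshDevice fixture used throughout the diff, and the only calls used are ones that appear in the patch itself.

    # Sketch only: contrasts the removed per-device loop with the new MeshDevice-level call.

    def enable_async_old(t3k_mesh_device, enable=True):
        # Pre-refactor pattern removed by this patch: toggle async mode on each device individually.
        for device_id in t3k_mesh_device.get_device_ids():
            t3k_mesh_device.get_device(device_id).enable_async(enable)

    def enable_async_new(t3k_mesh_device, enable=True):
        # Post-refactor pattern added by this patch: one call on the mesh forwards to every device
        # (see MeshDevice::enable_async in tt_metal/impl/device/mesh_device.cpp above).
        t3k_mesh_device.enable_async(enable)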