Skip to content

Commit

Permalink
#13454: Refactor API for MeshDevice::enable_async (#13455)
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu authored Oct 4, 2024
1 parent 824c167 commit d78ed37
Show file tree
Hide file tree
Showing 57 changed files with 111 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def test_FalconCausalLM_end_to_end_with_program_cache(

input_shape = [batch, seq_len]
model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
devices = t3k_mesh_device.get_devices()
for device in devices:
device.enable_async(async_mode)
compute_grid_size = devices[0].compute_with_storage_grid_size()
t3k_mesh_device.enable_async(async_mode)
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def test_FalconCausalLM_prefill_end_to_end_t3000_ci_loops_10(
input_shape = [batch, seq_len]
model_config_str = f"{data_type}-{memcfg}"
model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
devices = t3k_mesh_device.get_devices()
for device in devices:
device.enable_async(async_mode)
compute_grid_size = devices[0].compute_with_storage_grid_size()
t3k_mesh_device.enable_async(async_mode)
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
4 changes: 1 addition & 3 deletions models/demos/t3000/falcon40b/tests/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def test_demo(
use_program_cache,
):
input_file = "models/demos/t3000/falcon40b/demo/input_data.json"
# Enable async mode
for device in t3k_mesh_device.get_devices():
device.enable_async(True)
t3k_mesh_device.enable_async(True)

generated_text, measurements = run_falcon_demo_kv(
user_input=input_file,
Expand Down
6 changes: 2 additions & 4 deletions models/demos/t3000/falcon40b/tests/test_perf_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,8 @@ def test_perf_bare_metal(

input_shape = [batch, seq_len]
model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
devices = t3k_mesh_device.get_devices()
for device in devices:
device.enable_async(async_mode)
compute_grid_size = devices[0].compute_with_storage_grid_size()
t3k_mesh_device.enable_async(async_mode)
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
4 changes: 1 addition & 3 deletions models/demos/t3000/llama2_70b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ def test_LlamaModel_demo(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,7 @@ def test_LlamaModel_demo(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,7 @@ def test_LlamaModel_demo(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
4 changes: 1 addition & 3 deletions models/demos/t3000/llama2_70b/demo/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,7 @@ def test_LlamaModel_demo(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
4 changes: 1 addition & 3 deletions models/demos/t3000/llama2_70b/demo/eval_t3000.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ def test_LlamaModel_demo(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
5 changes: 3 additions & 2 deletions models/demos/t3000/llama2_70b/tests/test_llama_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ def test_LlamaModel_inference(
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
t3k_mesh_device.enable_async(True)
for device_id in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(device_id)
device.enable_program_cache()

args = construct_arg(
Expand Down
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tests/test_llama_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def test_Llama_perf_host(
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

t3k_mesh_device.enable_async(True)
for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_program_cache()
device.enable_async(True)
disable_compilation_reports()

run_test_LlamaModel_end_to_end(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ def test_Llama_perf_host(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

disable_compilation_reports()

Expand Down
4 changes: 1 addition & 3 deletions models/demos/t3000/llama3_70b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def test_LlamaModel_demo(

check_mesh_device(t3k_mesh_device, model_config)

for i in t3k_mesh_device.get_device_ids():
device = t3k_mesh_device.get_device(i)
device.enable_async(True)
t3k_mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ def test_mixtral8x7b_demo(t3k_mesh_device, use_program_cache, input_prompts, ins
if is_ci_env and instruct_weights == True:
pytest.skip("CI demo test only runs general weights to reduce CI pipeline load (both are supported)")

for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

return run_mixtral_demo(
user_input=input_prompts,
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,7 @@ def test_mixtral8x7b_demo(t3k_mesh_device, use_program_cache, input_prompts, ins
else:
batch_size = 32

for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

return run_mixtral_demo(
user_input=input_prompts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@


def test_mixtral_attention_inference(t3k_mesh_device, use_program_cache, reset_seeds):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

pcc = 0.99
dtype = ttnn.bfloat8_b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
)
@torch.no_grad()
def test_mixtral_attention_inference(t3k_mesh_device, use_program_cache, reset_seeds, seq_len):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

pcc = 0.99
dtype = ttnn.bfloat8_b
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def test_mixtral_decoder_inference(t3k_mesh_device, use_program_cache, reset_see
s: sequence length
h: hidden size
"""
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

pcc = 0.99
dtype = ttnn.bfloat8_b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def test_mixtral_decoder_inference(t3k_mesh_device, use_program_cache, reset_see
s: sequence length
h: hidden size
"""
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

pcc = 0.99
dtype = ttnn.bfloat8_b
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@


def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

# Specify different dtypes for each feedForward weights
dtypes = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
),
)
def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds, seq_len):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

# Specify different dtypes for each feedForward weights
dtypes = {
Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def forward(self, x):
),
)
def test_mixtral_model_inference(t3k_mesh_device, use_program_cache, reset_seeds, batch):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

valid_pcc = 0.97
dtype = ttnn.bfloat8_b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def test_mixtral_model_inference_CI(t3k_mesh_device, use_program_cache, reset_se
if is_ci_env:
os.environ["MIXTRAL_REF_OUTPUT_PATH"] = "/mnt/MLPerf/tt_dnn-models/Mistral/Mixtral-8x7B-v0.1/prefill/"

for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

n_layers = 32

Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@


def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

pcc = 0.99
iterations = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
),
)
def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds, seq_len):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

pcc = 0.99
iterations = 1
Expand Down
6 changes: 2 additions & 4 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def test_mixtral_model_perf(
reset_seeds,
is_ci_env,
):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

if not is_ci_env: # Enable tracy signpost support in local runs only
from tracy import signpost
Expand Down Expand Up @@ -168,8 +167,7 @@ def test_mixtral_model_with_prefill_perf(
reset_seeds,
is_ci_env,
):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

if not is_ci_env: # Enable tracy signpost support in local runs only
from tracy import signpost
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ def test_mixtral_perplexity(
llm_mode == "decode"
), "Only decode mode is supported for now" # TODO Add prefill support when it reaches main

for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

# Adjust the batch size based on the max prefill length
if max_seq_len >= 16 * 1024:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@


def test_mixtral_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_seeds):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

dtype = ttnn.bfloat8_b

Expand Down
3 changes: 1 addition & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def forward(self, x):
def test_mixtral_model_inference(
t3k_mesh_device, use_program_cache, reset_seeds, iterations, expected_top1, expected_top5
):
for device in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device).enable_async(True)
t3k_mesh_device.enable_async(True)

# TODO Currently topk test is supporting decode-only mode. Add prefill support.

Expand Down
4 changes: 1 addition & 3 deletions models/demos/tg/llama3_70b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,7 @@ def test_LlamaModel_demo(

check_mesh_device(mesh_device, model_config)

for i in mesh_device.get_device_ids():
device = mesh_device.get_device(i)
device.enable_async(True)
mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
4 changes: 1 addition & 3 deletions models/demos/tg/llama3_70b/tests/test_llama_demo_nightly.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ def test_llama3_tg_nightly_demo(
check_mesh_device(mesh_device, model_config)

# TODO: Renable when issue #11089 is resolved
for i in mesh_device.get_device_ids():
device = mesh_device.get_device(i)
device.enable_async(True)
mesh_device.enable_async(True)

args = construct_arg(
implementation=implementation,
Expand Down
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tests/test_llama_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def test_Llama_perf_host(
)

check_mesh_device(mesh_device, model_config)
mesh_device.enable_async(True)

for device in mesh_device.get_devices():
device.enable_program_cache()
device.enable_async(True)
disable_compilation_reports()

run_test_LlamaModel_end_to_end(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def test_falcon_attention(
torch_model,
enable_async,
):
for device in mesh_device.get_device_ids():
mesh_device.get_device(device).enable_async(enable_async)
mesh_device.enable_async(enable_async)

torch.manual_seed(0)
batch = device_batch_size * mesh_device.get_num_devices()
Expand Down Expand Up @@ -190,5 +189,4 @@ def test_falcon_attention(
pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc
)

for device in mesh_device.get_device_ids():
mesh_device.get_device(device).enable_async(False)
mesh_device.enable_async(False)
Loading

0 comments on commit d78ed37

Please sign in to comment.