Skip to content

Commit

Permalink
#5337: Improve perf script for on-device collection
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed May 22, 2024
1 parent 23d2af2 commit 9841912
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
os.environ["TT_METAL_ASYNC_DEVICE_QUEUE"] = "1"

import ttnn
from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor
from ttnn import ConcatMeshToTensor
import tt_lib
from tracy import signpost

from models.demos.t3000.mixtral8x7b.tt.mixtral_common import (
prepare_inputs_ttnn,
Expand Down Expand Up @@ -56,7 +58,7 @@ def test_mixtral_model_perf(
dtype = ttnn.bfloat8_b

model_args = TtModelArgs(t3k_device_mesh.get_device(0))
model_args.n_layers = 1
model_args.n_layers = 32
tokenizer = Tokenizer(model_args.tokenizer_path)

# Clear global profiler state before starting measurements
Expand Down Expand Up @@ -101,12 +103,17 @@ def test_mixtral_model_perf(
profiler.end("TtMistral_model_setup")

# Call the function
signpost("Model warmup")
profiler.start(f"end_to_end_inference_with_compile")
run_inference(tt_model, embd, encoded_prompts, generation_start_pos, generation_length, rot_mat)
profiler.end(f"end_to_end_inference_with_compile")
profiler.print()
compile_and_iter_time = profiler.get("model_run_for_inference_0")

for device_id in t3k_device_mesh.get_device_ids():
tt_lib.device.DumpDeviceProfiler(t3k_device_mesh.get_device(device_id))

signpost("Model perf run")
profiler.clear()
profiler.start(f"end_to_end_inference")
run_inference(tt_model, embd, encoded_prompts, generation_start_pos, generation_length, rot_mat)
Expand Down Expand Up @@ -173,3 +180,8 @@ def run_inference(tt_model, embd, encoded_prompts, generation_start_pos, generat
tt_token_batch = tt_output_torch.squeeze().argmax(axis=-1)
tt_decode_input = embd(tt_token_batch).view(batch, seqlen, -1)
profiler.end(f"torch_argmax_and_embed_{i}")

profiler.start(f"deallocate_tt_tensors_{i}")
decode_input.deallocate(force=True)
tt_out.deallocate(force=True)
profiler.end(f"deallocate_tt_tensors_{i}")

0 comments on commit 9841912

Please sign in to comment.