Skip to content

Commit

Permalink
#9956: Fix non-trace device ops
Browse files: browse the repository at this point in the history
  • Loading branch information
mo-tenstorrent committed Oct 31, 2024
1 parent 1b75907 commit e58c1f0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 2 additions & 4 deletions models/experimental/functional_unet/tests/test_unet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def test_unet_trace(
l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config, l1_input_tensor)
ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
outputs.append(output_tensor.cpu(blocking=False))
- ttnn.DumpDeviceProfiler(device)
ttnn.synchronize_device(device)
end = time.time()
logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps")
Expand Down Expand Up @@ -208,14 +207,13 @@ def test_unet_trace_2cq(

ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
outputs.append(output_tensor.cpu(blocking=False))

- ttnn.DumpDeviceProfiler(device)

ttnn.synchronize_device(device)
end = time.time()
logger.info(f"Average model time={1000.0 * (end-start) / iterations : .2f} ms")
logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps")

+ ttnn.DumpDeviceProfiler(device)

logger.info(f"Running sanity check against reference model output")
check_pcc_conv(torch_output_tensor, outputs[-1], UNET_FULL_MODEL_PCC)

Expand Down
8 changes: 6 additions & 2 deletions tt_metal/tools/profiler/process_ops_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,19 @@ def append_device_data(ops, traceReplays, logFolder):
for analysis, data in deviceOp["device_time"].items():
for sample in data:
sample["duration_ns"] = sample["duration_cycles"] * 1000 / freq

traceOps = {}

# Tag trace ops with a UID
for device in devicesOps:
for deviceOp in devicesOps[device]:
if "trace_runtime_id" in deviceOp.keys():
deviceOp["global_call_count"] = (
deviceOp["global_call_count"] | deviceOp["trace_runtime_id"] << 16
)
traceOps[deviceOp["global_call_count"]] = deviceOp
else:
# Update host reported device op with device populated version
ops[deviceOp["global_call_count"]] = deviceOp
return devicesOps, traceOps


Expand Down Expand Up @@ -539,7 +543,7 @@ def row_compare(row):
if row > ((1 << 16) - 1):
ret = traceOps[row]["tracy_time"]
else:
- ret = ops[row]["tracy_time"]
+ ret = ops[row]["host_time"]["ns_since_start"]
ret = int(ret)
return ret

Expand Down

0 comments on commit e58c1f0

Please sign in to comment.