diff --git a/tt_metal/tools/profiler/process_ops_logs.py b/tt_metal/tools/profiler/process_ops_logs.py index cab784e7a3b..2ad546825d8 100755 --- a/tt_metal/tools/profiler/process_ops_logs.py +++ b/tt_metal/tools/profiler/process_ops_logs.py @@ -129,7 +129,7 @@ def import_tracy_op_logs(logFolder): opData = cached_ops[deviceID][opHash].copy() opData["global_call_count"] = opID opData["tracy_time"] = opDataTime - opData["trace_id"] = "" + opData["trace_id"] = None if deviceID in traceIDs: opData["trace_id"] = traceIDs[deviceID] opsData.append(opData) @@ -151,11 +151,14 @@ def import_tracy_op_logs(logFolder): ) traceIDs[deviceID] = None elif "REPLAY" in opDataStr: - replayIDTime = (traceID, opDataTime) + replayIDTime = opDataTime if deviceID in traceReplays: - traceReplays[deviceID].append(replayIDTime) + if traceID in traceReplays[deviceID]: + traceReplays[deviceID][traceID].append(replayIDTime) + else: + traceReplays[deviceID][traceID] = [replayIDTime] else: - traceReplays[deviceID] = [replayIDTime] + traceReplays[deviceID] = {traceID: [replayIDTime]} if "TT_SIGNPOST" in opDataStr: signpostsCount += 1 @@ -185,7 +188,7 @@ def import_tracy_op_logs(logFolder): else: ops[parentOpID]["child_calls"] = {op["name"]: int(op["exec_time_ns"])} - return ops, signposts + return ops, signposts, traceReplays # Generate a map of OP reference list per device. @@ -225,7 +228,7 @@ def device_log_ops_compare(op): # Append device data to device ops and return the list of mapped device op ref list -def append_device_data(ops, logFolder): +def append_device_data(ops, traceReplays, logFolder): deviceOps, hasTraceRuns = get_device_op_data(ops) logger.info(f"Appending device data") deviceTimesLog = os.path.join(logFolder, PROFILER_DEVICE_SIDE_LOG) @@ -248,6 +251,7 @@ def append_device_data(ops, logFolder): ), f"Host op ID cannot be repeated: op ID {opID} was reported twice by the host" opIDHostDataDict[opID] = copy.deepcopy(deviceOp) + traceOps = {} for deviceOpTime in deviceOpsTime: if len(deviceOpTime["timeseries"]) > 0: timeID, ts, statData, risc, core = deviceOpTime["timeseries"][0] @@ -256,6 +260,23 @@ def append_device_data(ops, logFolder): assert ( deviceOpID in opIDHostDataDict ), f"Device op ID not present: Device op ID {deviceOpID} not present in host data" + traceID = opIDHostDataDict[deviceOpID]["trace_id"] + if traceID is not None: + if device in traceOps: + if traceID in traceOps[device]: + if device in traceOps[device][traceID]: + traceReplays[device][traceID].pop(0) + traceOps[device][traceID] = set([deviceOpID]) + else: + traceOps[device][traceID].add(deviceOpID) + else: + traceOps[device][traceID] = set([deviceOpID]) + else: + traceOps[device] = {traceID: set([deviceOpID])} + assert ( + len(traceReplays[device][traceID]) > 0 + ), "Wrong trace replay count: Device has more ops than trace replay issued commands" + opIDHostDataDict[deviceOpID]["tracy_time"] = traceReplays[device][traceID][0] generatedHostData.append(opIDHostDataDict[deviceOpID]) deviceOps[device] = generatedHostData @@ -629,10 +650,10 @@ def process_ops(output_folder, name_append, date): logFolder = generate_logs_folder(output_folder) reportFolder = generate_reports_folder(output_folder) - ops, signposts = import_tracy_op_logs(logFolder) + ops, signposts, traceReplays = import_tracy_op_logs(logFolder) if ops: - deviceOps = append_device_data(ops, logFolder) + deviceOps = append_device_data(ops, traceReplays, logFolder) generate_reports(ops, deviceOps, signposts, logFolder, reportFolder, date, name_append) else: