#9956: single device trace run perf gen passing
mo-tenstorrent committed Sep 10, 2024
1 parent bd1c108 commit bb3ecba
Showing 1 changed file with 54 additions and 33 deletions.
87 changes: 54 additions & 33 deletions tt_metal/tools/profiler/process_ops_logs.py
@@ -41,6 +41,7 @@
"GLOBAL CALL COUNT",
"DEVICE ID",
"TRACE ID",
"TRACE RUNTIME ID",
"ATTRIBUTES",
"MATH FIDELITY",
"CORE COUNT",
@@ -216,21 +217,26 @@ def device_ops_compare(op):

# Append device data to device ops and return the list of mapped device op ref list
def append_device_data(ops, traceReplays, deviceLogFolder):
deviceOps, hasTraceRuns = get_device_op_data(ops)
traceReplayCounts = {}
for deviceID in traceReplays:
traceReplayCounts[deviceID] = {}
for traceID in traceReplays[deviceID]:
traceReplayCounts[deviceID][traceID] = len(traceReplays[deviceID][traceID])
devicesOps, hasTraceRuns = get_device_op_data(ops)
logger.info(f"Appending device data")
deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG)
if os.path.isfile(deviceTimesLog):
setup = device_post_proc_config.default_setup()
setup.deviceInputLog = deviceTimesLog
deviceData = import_log_run_stats(setup)
freq = deviceData["deviceInfo"]["freq"]
for device in deviceOps:
for device in devicesOps:
assert device in deviceData["devices"].keys()
deviceOpsTime = deviceData["devices"][device]["cores"]["DEVICE"]["riscs"]["TENSIX"]["ops"]
if hasTraceRuns:
generatedHostData = []
opIDHostDataDict = {}
for deviceOp in deviceOps[device]:
for deviceOp in devicesOps[device]:
opID = deviceOp["global_call_count"]
assert (
opID not in opIDHostDataDict
@@ -250,7 +256,7 @@ def append_device_data(ops, traceReplays, deviceLogFolder):
if traceID is not None:
if device in traceOps:
if traceID in traceOps[device]:
if device in traceOps[device][traceID]:
if deviceOpID in traceOps[device][traceID]:
traceReplays[device][traceID].pop(0)
traceOps[device][traceID] = set([deviceOpID])
else:
@@ -263,13 +269,16 @@ def append_device_data(ops, traceReplays, deviceLogFolder):
len(traceReplays[device][traceID]) > 0
), "Wrong trace replay count: Device has more ops than trace replay issued commands"
opIDHostDataDict[deviceOpID]["tracy_time"] = traceReplays[device][traceID][0]
generatedHostData.append(opIDHostDataDict[deviceOpID])
deviceOps[device] = generatedHostData
opIDHostDataDict[deviceOpID]["trace_runtime_id"] = (
traceReplayCounts[device][traceID] - len(traceReplays[device][traceID]) + 1
)
generatedHostData.append(copy.deepcopy(opIDHostDataDict[deviceOpID]))
devicesOps[device] = generatedHostData
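# Illustrative note, not part of the diff: with traceReplayCounts[device][traceID] == 3 replays
# issued, the first replay still has all 3 entries pending, so trace_runtime_id = 3 - 3 + 1 = 1;
# after one pop it becomes 3 - 2 + 1 = 2, i.e. a 1-based index of the replay each op belongs to.
# copy.deepcopy is needed because the same host-side record is re-emitted once per replay and
# each copy carries its own tracy_time and trace_runtime_id.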

if len(deviceOps[device]) != len(deviceOpsTime):
if len(devicesOps[device]) != len(deviceOpsTime):
deviceOPId = None
hostOPId = None
for deviceOp, deviceOpTime in zip(deviceOps[device], deviceOpsTime):
for deviceOp, deviceOpTime in zip(devicesOps[device], deviceOpsTime):
if len(deviceOpTime["timeseries"]) > 0:
timeID, ts, statData, risc, core = deviceOpTime["timeseries"][0]
if "zone_name" in timeID.keys() and "FW" in timeID["zone_name"]:
@@ -281,17 +290,17 @@ def append_device_data(ops, traceReplays, deviceLogFolder):

if deviceOPId and hostOPId:
assert False, (
f"Device data mismatch: Expected {len(deviceOps[device])} "
f"Device data mismatch: Expected {len(devicesOps[device])} "
f"but received {len(deviceOpsTime)} ops on device {device}. "
f"Device is showing op ID {deviceOPId} when host is showing op ID {hostOPId}"
)
else:
assert False, (
f"Device data mismatch: Expected {len(deviceOps[device])} but "
f"Device data mismatch: Expected {len(devicesOps[device])} but "
f"received {len(deviceOpsTime)} ops on device {device}"
)

for deviceOp, deviceOpTime in zip(deviceOps[device], deviceOpsTime):
for deviceOp, deviceOpTime in zip(devicesOps[device], deviceOpsTime):
cores = set()
for timeID, ts, statData, risc, core in deviceOpTime["timeseries"]:
if "zone_name" in timeID.keys() and "FW" in timeID["zone_name"]:
@@ -310,7 +319,16 @@ def append_device_data(ops, traceReplays, deviceLogFolder):
for analysis, data in deviceOp["device_time"].items():
for sample in data:
sample["duration_ns"] = sample["duration_cycles"] * 1000 / freq
return deviceOps

traceOps = {}
for device in devicesOps:
for deviceOp in devicesOps[device]:
if "trace_runtime_id" in deviceOp.keys():
deviceOp["global_call_count"] = (
deviceOp["global_call_count"] | deviceOp["trace_runtime_id"] << 16
)
traceOps[deviceOp["global_call_count"]] = deviceOp
return devicesOps, traceOps
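
A note on the key packing above, offered as a minimal sketch rather than part of the diff: an op that came from a trace replay gets its replay index folded into the upper bits of its key while the original 16-bit global call count stays in the lower bits, so a single integer can index traceOps and still be decoded for the report. The helper names below are hypothetical; the constants mirror the masks used in this commit, and the scheme assumes non-trace global call counts stay below 1 << 16.

CALL_COUNT_BITS = 16
CALL_COUNT_MASK = (1 << CALL_COUNT_BITS) - 1  # 65535

def pack_trace_op_key(global_call_count, trace_runtime_id):
    # mirrors: deviceOp["global_call_count"] | deviceOp["trace_runtime_id"] << 16
    return global_call_count | (trace_runtime_id << CALL_COUNT_BITS)

def unpack_trace_op_key(key):
    # mirrors the report path: ((1 << 16) - 1) & op recovers the original op ID
    return key & CALL_COUNT_MASK, key >> CALL_COUNT_BITS

# e.g. op 42 on its 3rd replay -> key 196650; decoding restores (42, 3)
assert pack_trace_op_key(42, 3) == 196650
assert unpack_trace_op_key(196650) == (42, 3)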


def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameAppend):
@@ -408,7 +426,7 @@ def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameAppend):
return deviceOps


def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend):
def generate_reports(ops, deviceOps, traceOps, signposts, outputFolder, date, nameAppend):
logger.info(f"OPs' perf analysis is finished! Generating reports ...")
outFolder = PROFILER_OUTPUT_DIR
if outputFolder:
Expand All @@ -434,18 +452,6 @@ def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend):
if os.path.isfile(f"{PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG}"):
os.system(f"cp {PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG} {outFolder}")

# logger.info(f"Generating OPs yaml")
# allOpsYAMLPath = os.path.join(outFolder, f"{name}_all_ops.yaml")
# with open(allOpsYAMLPath, "w") as allOpsYAML:
# yaml.safe_dump(ops, allOpsYAML, default_flow_style=False)
# logger.info(f"OPs yaml generated at: {allOpsYAMLPath}")

# logger.info(f"Generating Device OPs yaml")
# deviceOpsYAMLPath = os.path.join(outFolder, f"{name}_devices_ops.yaml")
# with open(deviceOpsYAMLPath, "w") as deviceOpsYAML:
# yaml.safe_dump(deviceOps, deviceOpsYAML, default_flow_style=False)
# logger.info(f"Device OPs yaml generated at: {deviceOpsYAMLPath}")

logger.info(f"Generating OPs CSV")
allOpsCSVPath = os.path.join(outFolder, f"{name}.csv")
with open(allOpsCSVPath, "w") as allOpsCSV:
@@ -514,19 +520,26 @@ def row_compare(row):
if type(row) is str and "sp" in row:
ret = signposts[row]["tracy_time"]
elif type(row) is int:
ret = ops[row]["tracy_time"]
if row > ((1 << 16) - 1):
ret = traceOps[row]["tracy_time"]
else:
ret = ops[row]["tracy_time"]
ret = int(ret)
return ret

rowKeys = list(ops.keys()) + list(signposts.keys())
rowKeys = list(ops.keys()) + list(traceOps.keys()) + list(signposts.keys())
rowKeys.sort(key=row_compare)
childCallKeys = set()
for row in rowKeys:
if type(row) is int:
op = ops[row]
if "child_calls" in op.keys():
for childCall in op["child_calls"]:
if row > ((1 << 16) - 1):
opData = traceOps[row]
else:
opData = ops[row]
if "child_calls" in opData.keys():
for childCall in opData["child_calls"]:
childCallKeys.add(f"{childCall}_TT_HOST_FUNC [ns]")

for row in rowKeys:
rowDict = {}
if type(row) is str and "sp" in row:
Expand All @@ -538,7 +551,15 @@ def row_compare(row):
rowDict["HOST START TS"] = int(signposts[row]["tracy_time"])
elif type(row) is int:
op = row
opData = ops[op]
if op > ((1 << 16) - 1):
opData = traceOps[op]
opData["global_call_count"] = ((1 << 16) - 1) & op
else:
opData = ops[op]
opData["trace_runtime_id"] = ""
if opData["trace_id"] is None:
opData["trace_id"] = ""
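# Illustrative note, not part of the diff: the row > ((1 << 16) - 1) test relies on plain global
# call counts fitting in 16 bits, so any larger key can only be a packed trace-op key, e.g.
# 196650 & 65535 == 42 (original op ID) and 196650 >> 16 == 3 (replay index). Non-trace rows get
# empty TRACE RUNTIME ID (and, when unset, TRACE ID) cells so the CSV columns stay aligned.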

for field, fieldData in opData.items():
headerField = csv_header_format(field)
if headerField in OPS_CSV_HEADER:
@@ -631,8 +652,8 @@ def process_ops(output_folder, name_append, date):
ops, signposts, traceReplays = import_tracy_op_logs()

if ops:
deviceOps = append_device_data(ops, traceReplays, PROFILER_LOGS_DIR)
generate_reports(ops, deviceOps, signposts, output_folder, date, name_append)
deviceOps, traceOps = append_device_data(ops, traceReplays, PROFILER_LOGS_DIR)
generate_reports(ops, deviceOps, traceOps, signposts, output_folder, date, name_append)

else:
deviceOps = get_device_data_generate_report(PROFILER_LOGS_DIR, output_folder, date, name_append)
