Skip to content

Commit

Permalink
#9956: single device trace run perf gen passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mo-tenstorrent committed Sep 10, 2024
1 parent bd1c108 commit 2c3ed65
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 44 deletions.
99 changes: 61 additions & 38 deletions tt_metal/tools/profiler/process_ops_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"GLOBAL CALL COUNT",
"DEVICE ID",
"TRACE ID",
"TRACE RUNTIME ID",
"ATTRIBUTES",
"MATH FIDELITY",
"CORE COUNT",
Expand Down Expand Up @@ -100,10 +101,11 @@ def import_tracy_op_logs():
for opDataDict in opDataDicts:
opDataStr = opDataDict["MessageName"]
opDataTime = opDataDict["total_ns"]
if "TT_DNN" in opDataStr:
if "TT_DNN" in opDataStr or "TT_METAL" in opDataStr:
if "OP" in opDataStr:
tmpStrs = opDataStr.split(" ->\n", 1)
opData = {}
opData["trace_id"] = None
if len(tmpStrs) > 1: # uncached device op, host op, or fallback op
jsonStr = tmpStrs[-1]
opData = json.loads(jsonStr)
Expand All @@ -116,6 +118,8 @@ def import_tracy_op_logs():
else:
cached_ops[deviceID] = {opHash: opData.copy()}
del cached_ops[deviceID][opHash]["global_call_count"]
if deviceID in traceIDs:
opData["trace_id"] = traceIDs[deviceID]
else: # cached device op
opDataList = opDataStr.split(":", 1)[-1].split(",")
assert len(opDataList) > 3, "Wrong cached op info format"
Expand All @@ -127,10 +131,9 @@ def import_tracy_op_logs():
assert opHash in cached_ops[deviceID].keys(), "Expected hashed op info is not found"
opData = cached_ops[deviceID][opHash].copy()
opData["global_call_count"] = opID
if deviceID in traceIDs:
opData["trace_id"] = traceIDs[deviceID]
opData["tracy_time"] = opDataTime
opData["trace_id"] = None
if deviceID in traceIDs:
opData["trace_id"] = traceIDs[deviceID]
opsData.append(opData)
elif "TRACE" in opDataStr:
IDs = opDataStr.split(":")[-1].strip().split(",")
Expand Down Expand Up @@ -168,7 +171,7 @@ def import_tracy_op_logs():
with open(tracyOpTimesLog, "r") as csvFile:
csvReader = csv.DictReader(csvFile)
for op in csvReader:
if "TT_DNN" in op["name"]:
if "TT_DNN" in op["name"] or "TT_METAL" in op["name"]:
opID = int(op["zone_text"].split(":")[-1])
assert opID in ops.keys(), f"Op time for op {opID} must present"
ops[opID]["host_time"] = op
Expand Down Expand Up @@ -216,21 +219,26 @@ def device_ops_compare(op):

# Append device data to device ops and return the list of mapped device op ref list
def append_device_data(ops, traceReplays, deviceLogFolder):
deviceOps, hasTraceRuns = get_device_op_data(ops)
traceReplayCounts = {}
for deviceID in traceReplays:
traceReplayCounts[deviceID] = {}
for traceID in traceReplays[deviceID]:
traceReplayCounts[deviceID][traceID] = len(traceReplays[deviceID][traceID])
devicesOps, hasTraceRuns = get_device_op_data(ops)
logger.info(f"Appending device data")
deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG)
if os.path.isfile(deviceTimesLog):
setup = device_post_proc_config.default_setup()
setup.deviceInputLog = deviceTimesLog
deviceData = import_log_run_stats(setup)
freq = deviceData["deviceInfo"]["freq"]
for device in deviceOps:
for device in devicesOps:
assert device in deviceData["devices"].keys()
deviceOpsTime = deviceData["devices"][device]["cores"]["DEVICE"]["riscs"]["TENSIX"]["ops"]
if hasTraceRuns:
generatedHostData = []
opIDHostDataDict = {}
for deviceOp in deviceOps[device]:
for deviceOp in devicesOps[device]:
opID = deviceOp["global_call_count"]
assert (
opID not in opIDHostDataDict
Expand All @@ -250,7 +258,7 @@ def append_device_data(ops, traceReplays, deviceLogFolder):
if traceID is not None:
if device in traceOps:
if traceID in traceOps[device]:
if device in traceOps[device][traceID]:
if deviceOpID in traceOps[device][traceID]:
traceReplays[device][traceID].pop(0)
traceOps[device][traceID] = set([deviceOpID])
else:
Expand All @@ -263,13 +271,16 @@ def append_device_data(ops, traceReplays, deviceLogFolder):
len(traceReplays[device][traceID]) > 0
), "Wrong trace replay count: Device has more ops than trace replay issued commands"
opIDHostDataDict[deviceOpID]["tracy_time"] = traceReplays[device][traceID][0]
generatedHostData.append(opIDHostDataDict[deviceOpID])
deviceOps[device] = generatedHostData
opIDHostDataDict[deviceOpID]["trace_runtime_id"] = (
traceReplayCounts[device][traceID] - len(traceReplays[device][traceID]) + 1
)
generatedHostData.append(copy.deepcopy(opIDHostDataDict[deviceOpID]))
devicesOps[device] = generatedHostData

if len(deviceOps[device]) != len(deviceOpsTime):
if len(devicesOps[device]) != len(deviceOpsTime):
deviceOPId = None
hostOPId = None
for deviceOp, deviceOpTime in zip(deviceOps[device], deviceOpsTime):
for deviceOp, deviceOpTime in zip(devicesOps[device], deviceOpsTime):
if len(deviceOpTime["timeseries"]) > 0:
timeID, ts, statData, risc, core = deviceOpTime["timeseries"][0]
if "zone_name" in timeID.keys() and "FW" in timeID["zone_name"]:
Expand All @@ -281,17 +292,17 @@ def append_device_data(ops, traceReplays, deviceLogFolder):

if deviceOPId and hostOPId:
assert False, (
f"Device data mismatch: Expected {len(deviceOps[device])} "
f"Device data mismatch: Expected {len(devicesOps[device])} "
f"but received {len(deviceOpsTime)} ops on device {device}. "
f"Device is showing op ID {deviceOPId} when host is showing op ID {hostOPId}"
)
else:
assert False, (
f"Device data mismatch: Expected {len(deviceOps[device])} but "
f"Device data mismatch: Expected {len(devicesOps[device])} but "
f"received {len(deviceOpsTime)} ops on device {device}"
)

for deviceOp, deviceOpTime in zip(deviceOps[device], deviceOpsTime):
for deviceOp, deviceOpTime in zip(devicesOps[device], deviceOpsTime):
cores = set()
for timeID, ts, statData, risc, core in deviceOpTime["timeseries"]:
if "zone_name" in timeID.keys() and "FW" in timeID["zone_name"]:
Expand All @@ -310,7 +321,16 @@ def append_device_data(ops, traceReplays, deviceLogFolder):
for analysis, data in deviceOp["device_time"].items():
for sample in data:
sample["duration_ns"] = sample["duration_cycles"] * 1000 / freq
return deviceOps

traceOps = {}
for device in devicesOps:
for deviceOp in devicesOps[device]:
if "trace_runtime_id" in deviceOp.keys():
deviceOp["global_call_count"] = (
deviceOp["global_call_count"] | deviceOp["trace_runtime_id"] << 16
)
traceOps[deviceOp["global_call_count"]] = deviceOp
return devicesOps, traceOps


def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameAppend):
Expand Down Expand Up @@ -408,7 +428,7 @@ def get_device_data_generate_report(deviceLogFolder, outputFolder, date, nameApp
return deviceOps


def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend):
def generate_reports(ops, deviceOps, traceOps, signposts, outputFolder, date, nameAppend):
logger.info(f"OPs' perf analysis is finished! Generating reports ...")
outFolder = PROFILER_OUTPUT_DIR
if outputFolder:
Expand All @@ -434,18 +454,6 @@ def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend):
if os.path.isfile(f"{PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG}"):
os.system(f"cp {PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG} {outFolder}")

# logger.info(f"Generating OPs yaml")
# allOpsYAMLPath = os.path.join(outFolder, f"{name}_all_ops.yaml")
# with open(allOpsYAMLPath, "w") as allOpsYAML:
# yaml.safe_dump(ops, allOpsYAML, default_flow_style=False)
# logger.info(f"OPs yaml generated at: {allOpsYAMLPath}")

# logger.info(f"Generating Device OPs yaml")
# deviceOpsYAMLPath = os.path.join(outFolder, f"{name}_devices_ops.yaml")
# with open(deviceOpsYAMLPath, "w") as deviceOpsYAML:
# yaml.safe_dump(deviceOps, deviceOpsYAML, default_flow_style=False)
# logger.info(f"Device OPs yaml generated at: {deviceOpsYAMLPath}")

logger.info(f"Generating OPs CSV")
allOpsCSVPath = os.path.join(outFolder, f"{name}.csv")
with open(allOpsCSVPath, "w") as allOpsCSV:
Expand Down Expand Up @@ -514,19 +522,26 @@ def row_compare(row):
if type(row) is str and "sp" in row:
ret = signposts[row]["tracy_time"]
elif type(row) is int:
ret = ops[row]["tracy_time"]
if row > ((1 << 16) - 1):
ret = traceOps[row]["tracy_time"]
else:
ret = ops[row]["tracy_time"]
ret = int(ret)
return ret

rowKeys = list(ops.keys()) + list(signposts.keys())
rowKeys = list(ops.keys()) + list(traceOps.keys()) + list(signposts.keys())
rowKeys.sort(key=row_compare)
childCallKeys = set()
for row in rowKeys:
if type(row) is int:
op = ops[row]
if "child_calls" in op.keys():
for childCall in op["child_calls"]:
if row > ((1 << 16) - 1):
opData = traceOps[row]
else:
opData = ops[row]
if "child_calls" in opData.keys():
for childCall in opData["child_calls"]:
childCallKeys.add(f"{childCall}_TT_HOST_FUNC [ns]")

for row in rowKeys:
rowDict = {}
if type(row) is str and "sp" in row:
Expand All @@ -538,7 +553,15 @@ def row_compare(row):
rowDict["HOST START TS"] = int(signposts[row]["tracy_time"])
elif type(row) is int:
op = row
opData = ops[op]
if op > ((1 << 16) - 1):
opData = traceOps[op]
opData["global_call_count"] = ((1 << 16) - 1) & op
else:
opData = ops[op]
opData["trace_runtime_id"] = ""
if "trac_id" not in opData.keys() or opData["trace_id"] is None:
opData["trace_id"] = ""

for field, fieldData in opData.items():
headerField = csv_header_format(field)
if headerField in OPS_CSV_HEADER:
Expand Down Expand Up @@ -631,8 +654,8 @@ def process_ops(output_folder, name_append, date):
ops, signposts, traceReplays = import_tracy_op_logs()

if ops:
deviceOps = append_device_data(ops, traceReplays, PROFILER_LOGS_DIR)
generate_reports(ops, deviceOps, signposts, output_folder, date, name_append)
deviceOps, traceOps = append_device_data(ops, traceReplays, PROFILER_LOGS_DIR)
generate_reports(ops, deviceOps, traceOps, signposts, output_folder, date, name_append)

else:
deviceOps = get_device_data_generate_report(PROFILER_LOGS_DIR, output_folder, date, name_append)
Expand Down
14 changes: 9 additions & 5 deletions tt_metal/tools/profiler/tt_metal_tracy.hpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#if defined(TRACY_ENABLE)

#define TracyTTMetalBeginTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_BEGIN: {}, {}`", device_id, trace_id); \
std::string trace_message = fmt::format("`TT_METAL_TRACE_BEGIN: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());

#define TracyTTMetalEndTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_END: {}, {}`", device_id, trace_id); \
std::string trace_message = fmt::format("`TT_METAL_TRACE_END: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());

#define TracyTTMetalReplayTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_REPLAY: {}, {}`", device_id, trace_id); \
std::string trace_message = fmt::format("`TT_METAL_TRACE_REPLAY: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());

#define TracyTTMetalReleaseTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_RELEASE: {}, {}`", device_id, trace_id); \
std::string trace_message = fmt::format("`TT_METAL_TRACE_RELEASE: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());
#else

#define TracyTTMetalBeginTrace( device_id, trace_id )
#define TracyTTMetalEndTrace( device_id, trace_id )
#define TracyTTMetalReplayTrace( device_id, trace_id )
#define TracyTTMetalRelkaseTrace( device_id, trace_id )
#define TracyTTMetalReleaseTrace( device_id, trace_id )

#endif
2 changes: 1 addition & 1 deletion ttnn/tracy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def generate_report(outFolder, nameAppend, childCalls):
if childCallsList:
childCallStr = f"-x {','.join(childCallsList)}"
subprocess.run(
f"{PROFILER_BIN_DIR / TRACY_CSVEXPROT_TOOL} -u -p TT_DNN {childCallStr} {PROFILER_LOGS_DIR / TRACY_FILE_NAME}",
f"{PROFILER_BIN_DIR / TRACY_CSVEXPROT_TOOL} -u -p TT_ {childCallStr} {PROFILER_LOGS_DIR / TRACY_FILE_NAME}",
shell=True,
check=True,
stdout=csvFile,
Expand Down

0 comments on commit 2c3ed65

Please sign in to comment.