From b76a20fd7e0bbd80a26bea260fdb15b8f4ecd264 Mon Sep 17 00:00:00 2001 From: Mo Date: Fri, 6 Sep 2024 19:08:12 +0000 Subject: [PATCH] #9956: Include and process trace activity in tracy --- tt_metal/impl/device/device.cpp | 9 ++ tt_metal/tools/profiler/process_ops_logs.py | 135 ++++++++++++++------ tt_metal/tools/profiler/tt_metal_tracy.hpp | 25 ++++ ttnn/cpp/ttnn/operations/core/core.cpp | 7 + 4 files changed, 139 insertions(+), 37 deletions(-) create mode 100644 tt_metal/tools/profiler/tt_metal_tracy.hpp diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 2775a4ac380..e08b6c47acf 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -21,6 +21,7 @@ #include "noc/noc_parameters.h" #include "tt_metal/impl/device/device_pool.hpp" #include "tt_metal/detail/persistent_kernel_cache.hpp" +#include "tt_metal/tools/profiler/tt_metal_tracy.hpp" #include "llrt/hal.hpp" namespace tt { @@ -2488,6 +2489,8 @@ bool Device::using_slow_dispatch() const { } void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) { + ZoneScoped; + TracyTTMetalBeginTrace(this->id(), tid); TT_FATAL(this->trace_buffer_pool_.count(tid) == 0, "Trace already exists for tid {} on device", tid); TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid); this->EnableAllocs(); @@ -2497,6 +2500,8 @@ void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) { } void Device::end_trace(const uint8_t cq_id, const uint32_t tid) { + ZoneScoped; + TracyTTMetalEndTrace(this->id(), tid); TT_FATAL(this->hw_command_queues_[cq_id]->tid == tid, "CQ {} is not being used for tracing tid {}", (uint32_t)cq_id, tid); TT_FATAL(this->trace_buffer_pool_.count(tid) > 0, "Trace instance {} must exist on device", tid); this->hw_command_queues_[cq_id]->record_end(); @@ -2513,6 +2518,8 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) { } void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) { + ZoneScoped; + TracyTTMetalReplayTrace(this->id(), tid); constexpr bool check = false; TT_FATAL(this->trace_buffer_pool_.count(tid) > 0, "Trace instance {} must exist on device" , tid); if constexpr (check) { @@ -2526,6 +2533,8 @@ void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool bl } void Device::release_trace(const uint32_t tid) { + ZoneScoped; + TracyTTMetalReleaseTrace(this->id(), tid); uint32_t erased = this->trace_buffer_pool_.erase(tid); // Only enable allocations once all captured traces are released if (this->trace_buffer_pool_.empty()) { diff --git a/tt_metal/tools/profiler/process_ops_logs.py b/tt_metal/tools/profiler/process_ops_logs.py index 3159fa4b2cb..11feab2eed3 100755 --- a/tt_metal/tools/profiler/process_ops_logs.py +++ b/tt_metal/tools/profiler/process_ops_logs.py @@ -13,6 +13,7 @@ import json import yaml from datetime import datetime +import copy import click from loguru import logger @@ -39,6 +40,7 @@ "OP TYPE", "GLOBAL CALL COUNT", "DEVICE ID", + "TRACE ID", "ATTRIBUTES", "MATH FIDELITY", "CORE COUNT", @@ -93,37 +95,66 @@ def import_tracy_op_logs(): with open(tracyOpDataLog, "r", newline="") as csvFile: opDataDicts = csv.DictReader(csvFile, delimiter=";", quotechar="`") opsData = [] + traceIDs = {} + traceReplays = {} for opDataDict in opDataDicts: opDataStr = opDataDict["MessageName"] opDataTime = opDataDict["total_ns"] if "TT_DNN" in opDataStr: - tmpStrs = opDataStr.split(" ->\n", 1) - opData = {} - if len(tmpStrs) > 1: # uncached device op, host op, or fallback op - jsonStr = tmpStrs[-1] - opData = json.loads(jsonStr) - if "op_hash" in opData.keys(): - assert "device_id" in opData.keys() - deviceID = int(opData["device_id"]) - opHash = int(opData["op_hash"]) - if deviceID in cached_ops.keys(): - cached_ops[deviceID][opHash] = opData.copy() + if "OP" in opDataStr: + tmpStrs = opDataStr.split(" ->\n", 1) + opData = {} + if len(tmpStrs) > 1: # uncached device op, host op, or fallback op + jsonStr = tmpStrs[-1] + opData = json.loads(jsonStr) + if "op_hash" in opData.keys(): + assert "device_id" in opData.keys() + deviceID = int(opData["device_id"]) + opHash = int(opData["op_hash"]) + if deviceID in cached_ops.keys(): + cached_ops[deviceID][opHash] = opData.copy() + else: + cached_ops[deviceID] = {opHash: opData.copy()} + del cached_ops[deviceID][opHash]["global_call_count"] + else: # cached device op + opDataList = opDataStr.split(":", 1)[-1].split(",") + assert len(opDataList) > 3, "Wrong cached op info format" + opCode = opDataList[0].strip() + opHash = int(opDataList[1]) + deviceID = int(opDataList[2]) + opID = int(opDataList[3]) + assert deviceID in cached_ops.keys(), "Expected hashed op info is not found" + assert opHash in cached_ops[deviceID].keys(), "Expected hashed op info is not found" + opData = cached_ops[deviceID][opHash].copy() + opData["global_call_count"] = opID + opData["tracy_time"] = opDataTime + opData["trace_id"] = "" + if deviceID in traceIDs: + opData["trace_id"] = traceIDs[deviceID] + opsData.append(opData) + elif "TRACE" in opDataStr: + IDs = opDataStr.split(":")[-1].strip().split(",") + assert len(IDs) == 2, ( + "Wrong number of IDs is provided in trace message. " + "Device and trace are the two IDs that should be provided. " + f"But IDs {IDs} were provided" + ) + deviceID = int(IDs[0].strip()) + traceID = int(IDs[1].strip()) + if "BEGIN" in opDataStr: + traceIDs[deviceID] = traceID + elif "END" in opDataStr: + assert traceIDs[deviceID] == traceID, ( + f"Wrong trace ID, device {deviceID} should finish on trace ID " + f"{traceIDs[deviceID]} but it is finishing on trace ID {traceID}" + ) + traceIDs[deviceID] = None + elif "REPLAY" in opDataStr: + replayIDTime = (traceID, opDataTime) + if deviceID in traceReplays: + traceReplays[deviceID].append(replayIDTime) else: - cached_ops[deviceID] = {opHash: opData.copy()} - del cached_ops[deviceID][opHash]["global_call_count"] - else: # cached device op - opDataList = opDataStr.split(":", 1)[-1].split(",") - assert len(opDataList) > 3, "Wrong cached op info format" - opCode = opDataList[0].strip() - opHash = int(opDataList[1]) - deviceID = int(opDataList[2]) - opID = int(opDataList[3]) - assert deviceID in cached_ops.keys(), "Expected hashed op info is not found" - assert opHash in cached_ops[deviceID].keys(), "Expected hashed op info is not found" - opData = cached_ops[deviceID][opHash].copy() - opData["global_call_count"] = opID - opData["tracy_time"] = opDataTime - opsData.append(opData) + traceReplays[deviceID] = [replayIDTime] if "TT_SIGNPOST" in opDataStr: signpostsCount += 1 @@ -160,6 +191,7 @@ def import_tracy_op_logs(): def get_device_op_data(ops): logger.info(f"Getting device ops") deviceOps = {} + hasTraceRuns = False for opID, opData in ops.items(): if "device_id" in opData.keys(): deviceID = opData["device_id"] @@ -167,6 +199,8 @@ def get_device_op_data(ops): deviceOps[deviceID] = [opData] else: deviceOps[deviceID].append(opData) + if "trace_id" in opData.keys() and opData["trace_id"] is not None: + hasTraceRuns = True def device_ops_compare(op): return int(op["global_call_count"]) @@ -174,12 +208,12 @@ def device_ops_compare(op): for deviceID in deviceOps: deviceOps[deviceID].sort(key=device_ops_compare) - return deviceOps + return deviceOps, hasTraceRuns # Append device data to device ops and return the list of mapped device op ref list def append_device_data(ops, deviceLogFolder): - deviceOps = get_device_op_data(ops) + deviceOps, hasTraceRuns = get_device_op_data(ops) logger.info(f"Appending device data") deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG) if os.path.isfile(deviceTimesLog): @@ -190,6 +224,27 @@ def append_device_data(ops, deviceLogFolder): for device in deviceOps: assert device in deviceData["devices"].keys() deviceOpsTime = deviceData["devices"][device]["cores"]["DEVICE"]["riscs"]["TENSIX"]["ops"] + if hasTraceRuns: + generatedHostData = [] + opIDHostDataDict = {} + for deviceOp in deviceOps[device]: + opID = deviceOp["global_call_count"] + assert ( + opID not in opIDHostDataDict + ), f"Host op ID cannot be repeated: op ID {opID} was reported twice by the host" + opIDHostDataDict[opID] = copy.deepcopy(deviceOp) + + for deviceOpTime in deviceOpsTime: + if len(deviceOpTime["timeseries"]) > 0: + timeID, ts, statData, risc, core = deviceOpTime["timeseries"][0] + assert "run_host_id" in timeID.keys(), "Device op ID missing: Device data must provide op ID" + deviceOpID = timeID["run_host_id"] + assert ( + deviceOpID in opIDHostDataDict + ), f"Device op ID not present: Device op ID {deviceOpID} not present in host data" + generatedHostData.append(opIDHostDataDict[deviceOpID]) + deviceOps[device] = generatedHostData + if len(deviceOps[device]) != len(deviceOpsTime): deviceOPId = None hostOPId = None @@ -204,21 +259,27 @@ def append_device_data(ops, deviceLogFolder): break if deviceOPId and hostOPId: - assert ( - False - ), f"Device data mismatch: Expected {len(deviceOps[device])} but received {len(deviceOpsTime)} ops on device {device}. Device is showing op ID {deviceOPId} when host is showing op ID {hostOPId}" + assert False, ( + f"Device data mismatch: Expected {len(deviceOps[device])} " + f"but received {len(deviceOpsTime)} ops on device {device}. " + f"Device is showing op ID {deviceOPId} when host is showing op ID {hostOPId}" + ) else: - assert ( - True - ), f"Device data mismatch: Expected {len(deviceOps[device])} but received {len(deviceOpsTime)} ops on device {device}" + assert False, ( + f"Device data mismatch: Expected {len(deviceOps[device])} but " + f"received {len(deviceOpsTime)} ops on device {device}" + ) + for deviceOp, deviceOpTime in zip(deviceOps[device], deviceOpsTime): cores = set() for timeID, ts, statData, risc, core in deviceOpTime["timeseries"]: if "zone_name" in timeID.keys() and "FW" in timeID["zone_name"]: if "run_host_id" in timeID.keys(): - assert ( - timeID["run_host_id"] == deviceOp["global_call_count"] - ), f"op id {timeID['run_host_id']} reproted by device is not matching assigned op id {deviceOp['global_call_count']}" + assert timeID["run_host_id"] == deviceOp["global_call_count"], ( + f"Device and host op ID mismatch: " + f"op id {timeID['run_host_id']} reproted by device is " + f"not matching assigned op id {deviceOp['global_call_count']}" + ) if core not in cores: cores.add(core) deviceOp["core_usage"] = {"count": len(cores), "cores": [str(core) for core in cores]} diff --git a/tt_metal/tools/profiler/tt_metal_tracy.hpp b/tt_metal/tools/profiler/tt_metal_tracy.hpp new file mode 100644 index 00000000000..57045c4733f --- /dev/null +++ b/tt_metal/tools/profiler/tt_metal_tracy.hpp @@ -0,0 +1,25 @@ +#if defined(TRACY_ENABLE) + +#define TracyTTMetalBeginTrace( device_id, trace_id ) \ + std::string trace_message = fmt::format("`TT_DNN_TRACE_BEGIN: {}, {}`", device_id, trace_id); \ + TracyMessage(trace_message.c_str(), trace_message.size()); + +#define TracyTTMetalEndTrace( device_id, trace_id ) \ + std::string trace_message = fmt::format("`TT_DNN_TRACE_END: {}, {}`", device_id, trace_id); \ + TracyMessage(trace_message.c_str(), trace_message.size()); + +#define TracyTTMetalReplayTrace( device_id, trace_id ) \ + std::string trace_message = fmt::format("`TT_DNN_TRACE_REPLAY: {}, {}`", device_id, trace_id); \ + TracyMessage(trace_message.c_str(), trace_message.size()); + +#define TracyTTMetalReleaseTrace( device_id, trace_id ) \ + std::string trace_message = fmt::format("`TT_DNN_TRACE_RELEASE: {}, {}`", device_id, trace_id); \ + TracyMessage(trace_message.c_str(), trace_message.size()); +#else + +#define TracyTTMetalBeginTrace( device_id, trace_id ) +#define TracyTTMetalEndTrace( device_id, trace_id ) +#define TracyTTMetalReplayTrace( device_id, trace_id ) +#define TracyTTMetalRelkaseTrace( device_id, trace_id ) + +#endif diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index 71b13efb34e..855c6b9f580 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -145,16 +145,19 @@ Tensor reallocate(const Tensor& input_tensor, const std::optional& // Trace APIs - Single Device uint32_t begin_trace_capture(Device* device, const uint8_t cq_id) { + ZoneScoped; uint32_t tid = Trace::next_id(); device->push_work([device, cq_id, tid]() mutable { device->begin_trace(cq_id, tid); }); return tid; } void end_trace_capture(Device* device, const uint32_t tid, const uint8_t cq_id) { + ZoneScoped; device->push_work([device, cq_id, tid]() mutable { device->end_trace(cq_id, tid); }); } void execute_trace(Device* device, const uint32_t tid, const uint8_t cq_id, bool blocking) { + ZoneScoped; // If blocking, ensure that worker thread blocks until trace is completed device->push_work([device, cq_id, tid, blocking]() mutable { device->replay_trace(cq_id, tid, blocking); }); // If blocking, wait until worker threads have completed @@ -169,6 +172,7 @@ void release_trace(Device* device, const uint32_t tid) { // Trace APIs - Multi Device uint32_t begin_trace_capture(MeshDevice* device, const uint8_t cq_id) { + ZoneScoped; auto workers = device->get_devices(); uint32_t tid = Trace::next_id(); for (auto& worker : workers) { @@ -178,6 +182,7 @@ uint32_t begin_trace_capture(MeshDevice* device, const uint8_t cq_id) { } void end_trace_capture(MeshDevice* device, const uint32_t tid, const uint8_t cq_id) { + ZoneScoped; auto workers = device->get_devices(); for (auto& worker : workers) { worker->push_work([worker, cq_id, tid]() mutable { worker->end_trace(cq_id, tid); }); @@ -185,6 +190,7 @@ void end_trace_capture(MeshDevice* device, const uint32_t tid, const uint8_t cq_ } void execute_trace(MeshDevice* device, const uint32_t tid, const uint8_t cq_id, bool blocking) { + ZoneScoped; auto workers = device->get_devices(); // If blocking, ensure that each worker thread blocks until device-local trace is completed for (auto& worker : workers) { @@ -199,6 +205,7 @@ void execute_trace(MeshDevice* device, const uint32_t tid, const uint8_t cq_id, } void release_trace(MeshDevice* device, const uint32_t tid) { + ZoneScoped; auto workers = device->get_devices(); for (auto& worker : workers) { worker->push_work([worker, tid]() mutable { worker->release_trace(tid); });