Skip to content

Commit

Permalink
#9956: Include and process trace activity in tracy
Browse files Browse the repository at this point in the history
  • Loading branch information
mo-tenstorrent committed Oct 1, 2024
1 parent 064f80b commit b76a20f
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 37 deletions.
9 changes: 9 additions & 0 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "noc/noc_parameters.h"
#include "tt_metal/impl/device/device_pool.hpp"
#include "tt_metal/detail/persistent_kernel_cache.hpp"
#include "tt_metal/tools/profiler/tt_metal_tracy.hpp"
#include "llrt/hal.hpp"

namespace tt {
Expand Down Expand Up @@ -2488,6 +2489,8 @@ bool Device::using_slow_dispatch() const {
}

void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
ZoneScoped;
TracyTTMetalBeginTrace(this->id(), tid);
TT_FATAL(this->trace_buffer_pool_.count(tid) == 0, "Trace already exists for tid {} on device", tid);
TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid);
this->EnableAllocs();
Expand All @@ -2497,6 +2500,8 @@ void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
}

void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
ZoneScoped;
TracyTTMetalEndTrace(this->id(), tid);
TT_FATAL(this->hw_command_queues_[cq_id]->tid == tid, "CQ {} is not being used for tracing tid {}", (uint32_t)cq_id, tid);
TT_FATAL(this->trace_buffer_pool_.count(tid) > 0, "Trace instance {} must exist on device", tid);
this->hw_command_queues_[cq_id]->record_end();
Expand All @@ -2513,6 +2518,8 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
ZoneScoped;
TracyTTMetalReplayTrace(this->id(), tid);
constexpr bool check = false;
TT_FATAL(this->trace_buffer_pool_.count(tid) > 0, "Trace instance {} must exist on device" , tid);
if constexpr (check) {
Expand All @@ -2526,6 +2533,8 @@ void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool bl
}

void Device::release_trace(const uint32_t tid) {
ZoneScoped;
TracyTTMetalReleaseTrace(this->id(), tid);
uint32_t erased = this->trace_buffer_pool_.erase(tid);
// Only enable allocations once all captured traces are released
if (this->trace_buffer_pool_.empty()) {
Expand Down
135 changes: 98 additions & 37 deletions tt_metal/tools/profiler/process_ops_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import json
import yaml
from datetime import datetime
import copy

import click
from loguru import logger
Expand All @@ -39,6 +40,7 @@
"OP TYPE",
"GLOBAL CALL COUNT",
"DEVICE ID",
"TRACE ID",
"ATTRIBUTES",
"MATH FIDELITY",
"CORE COUNT",
Expand Down Expand Up @@ -93,37 +95,66 @@ def import_tracy_op_logs():
with open(tracyOpDataLog, "r", newline="") as csvFile:
opDataDicts = csv.DictReader(csvFile, delimiter=";", quotechar="`")
opsData = []
traceIDs = {}
traceReplays = {}
for opDataDict in opDataDicts:
opDataStr = opDataDict["MessageName"]
opDataTime = opDataDict["total_ns"]
if "TT_DNN" in opDataStr:
tmpStrs = opDataStr.split(" ->\n", 1)
opData = {}
if len(tmpStrs) > 1: # uncached device op, host op, or fallback op
jsonStr = tmpStrs[-1]
opData = json.loads(jsonStr)
if "op_hash" in opData.keys():
assert "device_id" in opData.keys()
deviceID = int(opData["device_id"])
opHash = int(opData["op_hash"])
if deviceID in cached_ops.keys():
cached_ops[deviceID][opHash] = opData.copy()
if "OP" in opDataStr:
tmpStrs = opDataStr.split(" ->\n", 1)
opData = {}
if len(tmpStrs) > 1: # uncached device op, host op, or fallback op
jsonStr = tmpStrs[-1]
opData = json.loads(jsonStr)
if "op_hash" in opData.keys():
assert "device_id" in opData.keys()
deviceID = int(opData["device_id"])
opHash = int(opData["op_hash"])
if deviceID in cached_ops.keys():
cached_ops[deviceID][opHash] = opData.copy()
else:
cached_ops[deviceID] = {opHash: opData.copy()}
del cached_ops[deviceID][opHash]["global_call_count"]
else: # cached device op
opDataList = opDataStr.split(":", 1)[-1].split(",")
assert len(opDataList) > 3, "Wrong cached op info format"
opCode = opDataList[0].strip()
opHash = int(opDataList[1])
deviceID = int(opDataList[2])
opID = int(opDataList[3])
assert deviceID in cached_ops.keys(), "Expected hashed op info is not found"
assert opHash in cached_ops[deviceID].keys(), "Expected hashed op info is not found"
opData = cached_ops[deviceID][opHash].copy()
opData["global_call_count"] = opID
opData["tracy_time"] = opDataTime
opData["trace_id"] = ""
if deviceID in traceIDs:
opData["trace_id"] = traceIDs[deviceID]
opsData.append(opData)
elif "TRACE" in opDataStr:
IDs = opDataStr.split(":")[-1].strip().split(",")
assert len(IDs) == 2, (
"Wrong number of IDs is provided in trace message. "
"Device and trace are the two IDs that should be provided. "
f"But IDs {IDs} were provided"
)
deviceID = int(IDs[0].strip())
traceID = int(IDs[1].strip())
if "BEGIN" in opDataStr:
traceIDs[deviceID] = traceID
elif "END" in opDataStr:
assert traceIDs[deviceID] == traceID, (
f"Wrong trace ID, device {deviceID} should finish on trace ID "
f"{traceIDs[deviceID]} but it is finishing on trace ID {traceID}"
)
traceIDs[deviceID] = None
elif "REPLAY" in opDataStr:
replayIDTime = (traceID, opDataTime)
if deviceID in traceReplays:
traceReplays[deviceID].append(replayIDTime)
else:
cached_ops[deviceID] = {opHash: opData.copy()}
del cached_ops[deviceID][opHash]["global_call_count"]
else: # cached device op
opDataList = opDataStr.split(":", 1)[-1].split(",")
assert len(opDataList) > 3, "Wrong cached op info format"
opCode = opDataList[0].strip()
opHash = int(opDataList[1])
deviceID = int(opDataList[2])
opID = int(opDataList[3])
assert deviceID in cached_ops.keys(), "Expected hashed op info is not found"
assert opHash in cached_ops[deviceID].keys(), "Expected hashed op info is not found"
opData = cached_ops[deviceID][opHash].copy()
opData["global_call_count"] = opID
opData["tracy_time"] = opDataTime
opsData.append(opData)
traceReplays[deviceID] = [replayIDTime]

if "TT_SIGNPOST" in opDataStr:
signpostsCount += 1
Expand Down Expand Up @@ -160,26 +191,29 @@ def import_tracy_op_logs():
def get_device_op_data(ops):
logger.info(f"Getting device ops")
deviceOps = {}
hasTraceRuns = False
for opID, opData in ops.items():
if "device_id" in opData.keys():
deviceID = opData["device_id"]
if deviceID not in deviceOps.keys():
deviceOps[deviceID] = [opData]
else:
deviceOps[deviceID].append(opData)
if "trace_id" in opData.keys() and opData["trace_id"] is not None:
hasTraceRuns = True

def device_ops_compare(op):
return int(op["global_call_count"])

for deviceID in deviceOps:
deviceOps[deviceID].sort(key=device_ops_compare)

return deviceOps
return deviceOps, hasTraceRuns


# Append device data to device ops and return the list of mapped device op ref list
def append_device_data(ops, deviceLogFolder):
deviceOps = get_device_op_data(ops)
deviceOps, hasTraceRuns = get_device_op_data(ops)
logger.info(f"Appending device data")
deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG)
if os.path.isfile(deviceTimesLog):
Expand All @@ -190,6 +224,27 @@ def append_device_data(ops, deviceLogFolder):
for device in deviceOps:
assert device in deviceData["devices"].keys()
deviceOpsTime = deviceData["devices"][device]["cores"]["DEVICE"]["riscs"]["TENSIX"]["ops"]
if hasTraceRuns:
generatedHostData = []
opIDHostDataDict = {}
for deviceOp in deviceOps[device]:
opID = deviceOp["global_call_count"]
assert (
opID not in opIDHostDataDict
), f"Host op ID cannot be repeated: op ID {opID} was reported twice by the host"
opIDHostDataDict[opID] = copy.deepcopy(deviceOp)

for deviceOpTime in deviceOpsTime:
if len(deviceOpTime["timeseries"]) > 0:
timeID, ts, statData, risc, core = deviceOpTime["timeseries"][0]
assert "run_host_id" in timeID.keys(), "Device op ID missing: Device data must provide op ID"
deviceOpID = timeID["run_host_id"]
assert (
deviceOpID in opIDHostDataDict
), f"Device op ID not present: Device op ID {deviceOpID} not present in host data"
generatedHostData.append(opIDHostDataDict[deviceOpID])
deviceOps[device] = generatedHostData

if len(deviceOps[device]) != len(deviceOpsTime):
deviceOPId = None
hostOPId = None
Expand All @@ -204,21 +259,27 @@ def append_device_data(ops, deviceLogFolder):
break

if deviceOPId and hostOPId:
assert (
False
), f"Device data mismatch: Expected {len(deviceOps[device])} but received {len(deviceOpsTime)} ops on device {device}. Device is showing op ID {deviceOPId} when host is showing op ID {hostOPId}"
assert False, (
f"Device data mismatch: Expected {len(deviceOps[device])} "
f"but received {len(deviceOpsTime)} ops on device {device}. "
f"Device is showing op ID {deviceOPId} when host is showing op ID {hostOPId}"
)
else:
assert (
True
), f"Device data mismatch: Expected {len(deviceOps[device])} but received {len(deviceOpsTime)} ops on device {device}"
assert False, (
f"Device data mismatch: Expected {len(deviceOps[device])} but "
f"received {len(deviceOpsTime)} ops on device {device}"
)

for deviceOp, deviceOpTime in zip(deviceOps[device], deviceOpsTime):
cores = set()
for timeID, ts, statData, risc, core in deviceOpTime["timeseries"]:
if "zone_name" in timeID.keys() and "FW" in timeID["zone_name"]:
if "run_host_id" in timeID.keys():
assert (
timeID["run_host_id"] == deviceOp["global_call_count"]
), f"op id {timeID['run_host_id']} reproted by device is not matching assigned op id {deviceOp['global_call_count']}"
assert timeID["run_host_id"] == deviceOp["global_call_count"], (
f"Device and host op ID mismatch: "
f"op id {timeID['run_host_id']} reproted by device is "
f"not matching assigned op id {deviceOp['global_call_count']}"
)
if core not in cores:
cores.add(core)
deviceOp["core_usage"] = {"count": len(cores), "cores": [str(core) for core in cores]}
Expand Down
25 changes: 25 additions & 0 deletions tt_metal/tools/profiler/tt_metal_tracy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#if defined(TRACY_ENABLE)

#define TracyTTMetalBeginTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_BEGIN: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());

#define TracyTTMetalEndTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_END: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());

#define TracyTTMetalReplayTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_REPLAY: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());

#define TracyTTMetalReleaseTrace( device_id, trace_id ) \
std::string trace_message = fmt::format("`TT_DNN_TRACE_RELEASE: {}, {}`", device_id, trace_id); \
TracyMessage(trace_message.c_str(), trace_message.size());
#else

#define TracyTTMetalBeginTrace( device_id, trace_id )
#define TracyTTMetalEndTrace( device_id, trace_id )
#define TracyTTMetalReplayTrace( device_id, trace_id )
#define TracyTTMetalRelkaseTrace( device_id, trace_id )

#endif
7 changes: 7 additions & 0 deletions ttnn/cpp/ttnn/operations/core/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,19 @@ Tensor reallocate(const Tensor& input_tensor, const std::optional<MemoryConfig>&

// Trace APIs - Single Device
uint32_t begin_trace_capture(Device* device, const uint8_t cq_id) {
ZoneScoped;
uint32_t tid = Trace::next_id();
device->push_work([device, cq_id, tid]() mutable { device->begin_trace(cq_id, tid); });
return tid;
}

void end_trace_capture(Device* device, const uint32_t tid, const uint8_t cq_id) {
ZoneScoped;
device->push_work([device, cq_id, tid]() mutable { device->end_trace(cq_id, tid); });
}

void execute_trace(Device* device, const uint32_t tid, const uint8_t cq_id, bool blocking) {
ZoneScoped;
// If blocking, ensure that worker thread blocks until trace is completed
device->push_work([device, cq_id, tid, blocking]() mutable { device->replay_trace(cq_id, tid, blocking); });
// If blocking, wait until worker threads have completed
Expand All @@ -169,6 +172,7 @@ void release_trace(Device* device, const uint32_t tid) {

// Trace APIs - Multi Device
uint32_t begin_trace_capture(MeshDevice* device, const uint8_t cq_id) {
ZoneScoped;
auto workers = device->get_devices();
uint32_t tid = Trace::next_id();
for (auto& worker : workers) {
Expand All @@ -178,13 +182,15 @@ uint32_t begin_trace_capture(MeshDevice* device, const uint8_t cq_id) {
}

void end_trace_capture(MeshDevice* device, const uint32_t tid, const uint8_t cq_id) {
ZoneScoped;
auto workers = device->get_devices();
for (auto& worker : workers) {
worker->push_work([worker, cq_id, tid]() mutable { worker->end_trace(cq_id, tid); });
}
}

void execute_trace(MeshDevice* device, const uint32_t tid, const uint8_t cq_id, bool blocking) {
ZoneScoped;
auto workers = device->get_devices();
// If blocking, ensure that each worker thread blocks until device-local trace is completed
for (auto& worker : workers) {
Expand All @@ -199,6 +205,7 @@ void execute_trace(MeshDevice* device, const uint32_t tid, const uint8_t cq_id,
}

void release_trace(MeshDevice* device, const uint32_t tid) {
ZoneScoped;
auto workers = device->get_devices();
for (auto& worker : workers) {
worker->push_work([worker, tid]() mutable { worker->release_trace(tid); });
Expand Down

0 comments on commit b76a20f

Please sign in to comment.