diff --git a/catkit2/testbed/__init__.py b/catkit2/testbed/__init__.py index a1979ab8e..6860c7b68 100644 --- a/catkit2/testbed/__init__.py +++ b/catkit2/testbed/__init__.py @@ -5,12 +5,19 @@ 'CatkitLogHandler', 'TestbedProxy', 'ServiceProxy', - 'Experiment' + 'Experiment', + 'TraceWriter', + 'trace_interval', + 'trace_instant', + 'trace_counter', + 'ZmqDistributor', ] from .testbed import * from .experiment import * from .service import * from .logging import * +from .tracing import * +from .distributor import * from .testbed_proxy import * from .service_proxy import * diff --git a/catkit2/testbed/distributor.py b/catkit2/testbed/distributor.py new file mode 100644 index 000000000..11eca469d --- /dev/null +++ b/catkit2/testbed/distributor.py @@ -0,0 +1,86 @@ +import zmq +import threading +import traceback + + +class ZmqDistributor: + '''Collects messages on a port and re-publish them on another. + + This operates on a separate thread after it is started. + + Parameters + ---------- + context : zmq.Context + A previously-created ZMQ context. All sockets will be created on this context. + input_port : integer + The port number for the incoming log messages. + output_port : integer + The port number for the outgoing log messages. + callback : function + A callback to call with each message. + ''' + def __init__(self, context, input_port, output_port, callback=None): + self.context = context + self.input_port = input_port + self.output_port = output_port + self.callback = callback + + self.shutdown_flag = threading.Event() + self.thread = None + + self.is_running = threading.Event() + + def start(self): + '''Start the proxy thread. + ''' + self.thread = threading.Thread(target=self._forwarder) + self.thread.start() + + self.is_running.wait() + + def stop(self): + '''Stop the proxy thread. + + This function waits until the thread is actually stopped. + ''' + self.shutdown_flag.set() + + if self.thread: + self.thread.join() + + self.is_running.clear() + + def _forwarder(self): + '''Create sockets and republish all received log messages. + + .. note:: + This function should not be called directly. Use + :func:`~catkit2.testbed.ZmqDistributor.start()` to start the proxy. + ''' + collector = self.context.socket(zmq.PULL) + collector.RCVTIMEO = 50 + collector.bind(f'tcp://*:{self.input_port}') + + publicist = self.context.socket(zmq.PUB) + publicist.bind(f'tcp://*:{self.output_port}') + + self.is_running.set() + + while not self.shutdown_flag.is_set(): + try: + try: + log_message = collector.recv_multipart() + publicist.send_multipart(log_message) + except zmq.ZMQError as e: + if e.errno == zmq.EAGAIN: + # Timed out. + continue + else: + raise RuntimeError('Error during receive') from e + + if self.callback: + self.callback(log_message) + except Exception: + # Something went wrong during handling of the log message. + # Let's ignore this error, but still print the exception. + print(traceback.format_exc()) diff --git a/catkit2/testbed/logging.py b/catkit2/testbed/logging.py index 92aeb6f40..1de7d12f3 100644 --- a/catkit2/testbed/logging.py +++ b/catkit2/testbed/logging.py @@ -3,7 +3,6 @@ import zmq import json import contextlib -import traceback from colorama import Fore, Back, Style from ..catkit_bindings import submit_log_entry, Severity @@ -27,79 +26,6 @@ def emit(self, record): submit_log_entry(filename, line, function, severity, message) -class LogDistributor: - '''Collects log messages on a port and re-publish them on another. - - This operates on a separate thread after it is started. - - Parameters - ---------- - context : zmq.Context - A previously-created ZMQ context. All sockets will be created on this context. - input_port : integer - The port number for the incoming log messages. - output_port : integer - The port number for the outgoing log messages. - ''' - def __init__(self, context, input_port, output_port): - self.context = context - self.input_port = input_port - self.output_port = output_port - - self.shutdown_flag = threading.Event() - self.thread = None - - def start(self): - '''Start the proxy thread. - ''' - self.thread = threading.Thread(target=self.forwarder) - self.thread.start() - - def stop(self): - '''Stop the proxy thread. - - This function waits until the thread is actually stopped. - ''' - self.shutdown_flag.set() - - if self.thread: - self.thread.join() - - def forwarder(self): - '''Create sockets and republish all received log messages. - - .. note:: - This function should not be called directly. Use - :func:`~catkit2.testbed.LoggingProxy.start` to start the proxy. - ''' - collector = self.context.socket(zmq.PULL) - collector.RCVTIMEO = 50 - collector.bind(f'tcp://*:{self.input_port}') - - publicist = self.context.socket(zmq.PUB) - publicist.bind(f'tcp://*:{self.output_port}') - - while not self.shutdown_flag.is_set(): - try: - try: - log_message = collector.recv_multipart() - publicist.send_multipart(log_message) - except zmq.ZMQError as e: - if e.errno == zmq.EAGAIN: - # Timed out. - continue - else: - raise RuntimeError('Error during receive') from e - - log_message = log_message[0].decode('utf-8') - log_message = json.loads(log_message) - - print(f'[{log_message["service_id"]}] {log_message["message"]}') - except Exception: - # Something went wrong during handling of the log message. - # Let's ignore this error, but still print the exception. - print(traceback.format_exc()) - class LogObserver: def __init__(self, host, port): self.context = zmq.Context() diff --git a/catkit2/testbed/testbed.py b/catkit2/testbed/testbed.py index fd36d4cc3..02cd517f4 100644 --- a/catkit2/testbed/testbed.py +++ b/catkit2/testbed/testbed.py @@ -13,6 +13,7 @@ from ..catkit_bindings import LogForwarder, Server, ServiceState, DataStream, get_timestamp, is_alive_state, Client from .logging import * +from .distributor import ZmqDistributor from ..proto import testbed_pb2 as testbed_proto from ..proto import service_pb2 as service_proto @@ -167,11 +168,13 @@ def __init__(self, port, is_simulated, config): self.config = config self.services = {} + self.launched_processes = [] self.log_distributor = None self.log_handler = None self.log_forwarder = None - self.launched_processes = [] + + self.tracing_distributor = None self.log = logging.getLogger(__name__) @@ -271,6 +274,9 @@ def run(self): self.start_log_distributor() self.setup_logging() + # Start tracing distributor. + self.start_tracing_distributor() + heartbeat_thread = threading.Thread(target=self.do_heartbeats) heartbeat_thread.start() @@ -316,6 +322,9 @@ def run(self): # Shut down the server. self.server.stop() + # Stop tracing distributor. + self.stop_tracing_distributor() + # Stop the logging. self.destroy_logging() self.stop_log_distributor() @@ -392,9 +401,15 @@ def destroy_logging(self): def start_log_distributor(self): '''Start the log distributor. ''' + def callback(log_message): + log_message = log_message[0].decode('utf-8') + log_message = json.loads(log_message) + + print(f'[{log_message["service_id"]}] {log_message["message"]}') + self.logging_ingress_port, self.logging_egress_port = get_unused_port(num_ports=2) - self.log_distributor = LogDistributor(self.context, self.logging_ingress_port, self.logging_egress_port) + self.log_distributor = ZmqDistributor(self.context, self.logging_ingress_port, self.logging_egress_port, callback) self.log_distributor.start() def stop_log_distributor(self): @@ -404,6 +419,19 @@ def stop_log_distributor(self): self.log_distributor.stop() self.log_distributor = None + def start_tracing_distributor(self): + '''Start the tracing distributor. + ''' + self.tracing_ingress_port, self.tracing_egress_port = get_unused_port(num_ports=2) + + self.tracing_distributor = ZmqDistributor(self.context, self.tracing_ingress_port, self.tracing_egress_port) + self.tracing_distributor.start() + + def stop_tracing_distributor(self): + if self.tracing_distributor: + self.tracing_distributor.stop() + self.tracing_distributor = None + def on_start_service(self, data): request = testbed_proto.StartServiceRequest() request.ParseFromString(data) diff --git a/catkit2/testbed/tracing.py b/catkit2/testbed/tracing.py new file mode 100644 index 000000000..ef2134c82 --- /dev/null +++ b/catkit2/testbed/tracing.py @@ -0,0 +1,263 @@ +import json +import threading +import zmq +import contextlib + +from ..proto import tracing_pb2 as tracing_proto +from .. import catkit_bindings + + +def write_json(f, data): + data = json.dumps(data, indent=None, separators=(',', ':')) + + # Write JSON to file. + f.write(data + ',\n') + + +class TraceWriter: + '''A writer for performance trace logs. + + This writer writes the trace log in Google JSON format, described by + https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU + This file can be read in by a number of trace log viewers for inspection. + + Parameters + ---------- + host : str + The host which distributes the trace messages. + port : int + The port on which the host distributes the trace messages. + ''' + def __init__(self, host, port): + self.f = None + + self.context = zmq.Context() + + self.host = host + self.port = port + + self.shutdown_flag = threading.Event() + self.thread = None + + def open(self, filename): + '''Open the writer. + + Parameters + ---------- + filename : str + The path to the file where to write the performance trace. + + Returns + ------- + TraceWriter + The current trace writer. This is for use as a context manager. + ''' + self.shutdown_flag.clear() + + self._filename = filename + + self.thread = threading.Thread(target=self._loop) + self.thread.start() + + return self + + def close(self): + '''Close the writer. + ''' + self.shutdown_flag.set() + + if self.thread: + self.thread.join() + + def _loop(self): + # Set up socket. + socket = self.context.socket(zmq.SUB) + socket.connect(f'tcp://{self.host}:{self.port}') + socket.subscribe('') + socket.RCVTIMEO = 50 + + # Set up cache for process names. + process_names = {} + thread_names = {} + + # Set initial time. + # Subtracting one second before current time to hopefully never have negative timestamps. + t_0 = catkit_bindings.get_timestamp() - 1_000_000_000 + + with open(self._filename, 'w') as f: + # Write JSON header. + f.write('[\n') + + # Main loop. + while not self.shutdown_flag.is_set(): + # Receive a new trace event. + try: + message = socket.recv_multipart() + except zmq.ZMQError as e: + if e.errno == zmq.EAGAIN: + # Timed out. + continue + else: + raise RuntimeError('Error during receive.') from e + + # Decode event. + proto_event = tracing_proto.TraceEvent() + proto_event.ParseFromString(message[0]) + + # Convert event to JSON format. + # All JSON timestamps are in us, our timestamps are in ns, so + # divide all times by 1000. + event_type = proto_event.WhichOneof('event') + if event_type == 'interval': + raw_data = proto_event.interval + data = { + 'name': raw_data.name, + 'cat': raw_data.category, + 'ph': 'X', + 'ts': (raw_data.timestamp - t_0) / 1000, + 'dur': raw_data.duration / 1000, + 'pid': raw_data.process_id, + 'tid': raw_data.thread_id + } + elif event_type == 'instant': + raw_data = proto_event.instant + data = { + 'name': raw_data.name, + 'ph': 'i', + 'ts': (raw_data.timestamp - t_0) / 1000, + 'pid': raw_data.process_id, + 'tid': raw_data.thread_id + } + elif event_type == 'counter': + raw_data = proto_event.counter + data = { + 'name': raw_data.name, + 'ph': 'C', + 'ts': (raw_data.timestamp - t_0) / 1000, + 'pid': raw_data.process_id, + 'args': { + raw_data.series: raw_data.counter + } + } + + if hasattr(raw_data, 'process_name') and raw_data.process_name: + # We haven't seen a message from this process before. + # Extract and store process name, and write metadata. + if raw_data.process_id not in process_names: + process_names[raw_data.process_id] = raw_data.process_name + + metadata = { + 'name': 'process_name', + 'ph': 'M', + 'pid': raw_data.process_id, + 'args': { + 'name': raw_data.process_name + } + } + + write_json(f, metadata) + + if hasattr(raw_data, 'thread_name') and raw_data.thread_name: + # We haven't seen a message from this process before. + # Extract and store process name, and write metadata. + if raw_data.process_id not in thread_names: + thread_names[raw_data.process_id] = raw_data.thread_name + + metadata = { + 'name': 'thread_name', + 'ph': 'M', + 'pid': raw_data.process_id, + 'tid': raw_data.thread_id, + 'args': { + 'name': raw_data.thread_name + } + } + + write_json(f, metadata) + + write_json(f, data) + + def __enter__(self): + '''Enter the context manager. + + Returns + ------- + TraceWriter + The current trace writer. + ''' + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + '''Exit the context manager. + + Parameters + ---------- + exc_type : class + The exception class. + exc_val : + The value of the exception. + exc_tb : traceback + The traceback of the exception. + ''' + self.close() + + +@contextlib.contextmanager +def trace_interval(name, category=''): + '''Trace an interval event. + + Both the start and end time will be logged and + the event will be shown as a bar in the trace log viewer. + + Parameters + ---------- + name : str + The name of the event. + category : str, optional + The category of interval. Events with the same category will be + colored the same in the log viewer. The default is empty. + ''' + start = catkit_bindings.get_timestamp() + + try: + yield + finally: + end = catkit_bindings.get_timestamp() + + catkit_bindings.trace_interval(name, category, start, end - start) + +def trace_instant(name): + '''Trace an instant event. + + The time at which this function is called will be logged and + the event will be shown as a single arrow or line in the trace + log viewer. + + Parameters + ---------- + name : str + The name of the event. + ''' + timestamp = catkit_bindings.get_timestamp() + + catkit_bindings.trace_instant(name, timestamp) + +def trace_counter(name, series, counter): + '''Trace a counter event. + + The time at which this function is called is logged, in addition + to a scalar that we want to keep track of. The event will be shown + as a line graph. + + Parameters + ---------- + name : str + The name of the event. + series : str + The series of the event, e.g. "contrast", "iteration". + counter : float, int + The contents of this event, i.e. what changes over time. + ''' + timestamp = catkit_bindings.get_timestamp() + + catkit_bindings.trace_counter(name, series, timestamp, counter) diff --git a/catkit_bindings/bindings.cpp b/catkit_bindings/bindings.cpp index 8488d0d00..00c0f8d38 100644 --- a/catkit_bindings/bindings.cpp +++ b/catkit_bindings/bindings.cpp @@ -19,6 +19,7 @@ #include "Server.h" #include "Client.h" #include "HostName.h" +#include "Tracing.h" #include "proto/testbed.pb.h" @@ -668,6 +669,22 @@ PYBIND11_MODULE(catkit_bindings, m) .def(py::init<>()) .def("connect", &LogForwarder::Connect); + m.def("trace_connect", [](std::string process_name, std::string host, int port) { + tracing_proxy.Connect(process_name, host, port); + }); + m.def("trace_disconnect", []() { + tracing_proxy.Disconnect(); + }); + m.def("trace_interval", [](std::string name, std::string category, uint64_t timestamp_start, uint64_t duration) { + tracing_proxy.TraceInterval(name, category, timestamp_start, duration); + }); + m.def("trace_instant", [](std::string name, uint64_t timestamp) { + tracing_proxy.TraceInstant(name, timestamp); + }); + m.def("trace_counter", [](std::string name, std::string series, uint64_t timestamp, double counter) { + tracing_proxy.TraceCounter(name, series, timestamp, counter); + }); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/catkit_core/CMakeLists.txt b/catkit_core/CMakeLists.txt index ef34abfc2..4f0b9107c 100644 --- a/catkit_core/CMakeLists.txt +++ b/catkit_core/CMakeLists.txt @@ -30,7 +30,7 @@ add_library(catkit_core STATIC ServiceProxy.cpp ServiceState.cpp Tensor.cpp - TracingProxy.cpp + Tracing.cpp Types.cpp Util.cpp proto/core.pb.cc diff --git a/catkit_core/DataStream.cpp b/catkit_core/DataStream.cpp index 6df46ea5b..26217c05c 100644 --- a/catkit_core/DataStream.cpp +++ b/catkit_core/DataStream.cpp @@ -3,6 +3,7 @@ //#include "Log.h" #include "Timing.h" #include "Util.h" +#include "Tracing.h" #include #include @@ -160,6 +161,9 @@ DataFrame DataStream::RequestNewFrame() frame.Set(m_Header->m_DataType, m_Header->m_NumDimensions, m_Header->m_Dimensions, m_Buffer + offset, false); + auto ts = GetTimeStamp(); + tracing_proxy.TraceInterval("DataStream::RequestNewFrame", GetStreamName(), ts, 0); + return frame; } @@ -203,16 +207,24 @@ void DataStream::SubmitFrame(size_t id) m_Header->m_FrameRateCounter = m_Header->m_FrameRateCounter * std::exp(-FRAMERATE_DECAY * time_delta) + FRAMERATE_DECAY; + + auto ts = GetTimeStamp(); + tracing_proxy.TraceInterval("DataStream::SubmitFrame", GetStreamName(), ts, 0); } void DataStream::SubmitData(const void *data) { + auto start = GetTimeStamp(); + DataFrame frame = RequestNewFrame(); char *source = (char *) data; std::copy(source, source + frame.GetSizeInBytes(), frame.m_Data); SubmitFrame(frame.m_Id); + + auto end = GetTimeStamp(); + tracing_proxy.TraceInterval("DataStream::SubmitData", GetStreamName(), start, end - start); } std::vector DataStream::GetDimensions() diff --git a/catkit_core/Service.cpp b/catkit_core/Service.cpp index 2a71d5c56..e472d2518 100644 --- a/catkit_core/Service.cpp +++ b/catkit_core/Service.cpp @@ -3,6 +3,7 @@ #include "Finally.h" #include "Timing.h" #include "TestbedProxy.h" +#include "Tracing.h" #include "proto/service.pb.h" #include @@ -33,6 +34,8 @@ Service::Service(string service_type, string service_id, int service_port, int t m_Heartbeat = DataStream::Create("heartbeat", service_id, DataType::DT_UINT64, {1}, 20); + tracing_proxy.Connect(service_id, "127.0.0.1", m_Testbed->GetTracingIngressPort()); + string state_stream_id = m_Testbed->RegisterService( service_id, service_type, diff --git a/catkit_core/Tracing.cpp b/catkit_core/Tracing.cpp new file mode 100644 index 000000000..0807b63e4 --- /dev/null +++ b/catkit_core/Tracing.cpp @@ -0,0 +1,219 @@ +#include "Tracing.h" + +#include "Timing.h" +#include "Util.h" +#include "Log.h" +#include "proto/tracing.pb.h" + +#include + +using namespace std; + +TracingProxy tracing_proxy; + +TracingProxy::TracingProxy() + : m_IsConnected(false) +{ +} + +TracingProxy::~TracingProxy() +{ + Disconnect(); +} + +void TracingProxy::Connect(string process_name, string host, int port) +{ + // Disconnect if we are already running. + if (IsConnected()) + { + if (m_Host == host && m_Port == port) + return; + + Disconnect(); + } + + SetProcessName(process_name); + + m_Host = host; + m_Port = port; + m_ShutDown = false; + + m_MessageLoopThread = std::thread(&TracingProxy::MessageLoop, this); + m_IsConnected = true; +} + +void TracingProxy::Disconnect() +{ + m_ShutDown = true; + m_ConditionVariable.notify_all(); + + // Wait for the thread to exit. + if (m_MessageLoopThread.joinable()) + m_MessageLoopThread.join(); + + m_IsConnected = false; +} + +bool TracingProxy::IsConnected() +{ + return m_IsConnected; +} + +void TracingProxy::TraceInterval(string name, string category, uint64_t timestamp, uint64_t duration) +{ + TraceEventInterval event; + + event.name = name; + event.category = category; + event.process_id = GetProcessId(); + event.process_name = m_ProcessName; + event.thread_id = GetThreadId(); + event.thread_name = m_ThreadName; + event.timestamp = timestamp; + event.duration = duration; + + AddTraceEvent(event); +} + +void TracingProxy::TraceInstant(string name, uint64_t timestamp) +{ + TraceEventInstant event; + + event.name = name; + event.process_id = GetProcessId(); + event.process_name = m_ProcessName; + event.thread_id = GetThreadId(); + event.thread_name = m_ThreadName; + event.timestamp = timestamp; + + AddTraceEvent(event); +} + +void TracingProxy::TraceCounter(string name, string series, uint64_t timestamp, double counter) +{ + TraceEventCounter event; + + event.name = name; + event.series = series; + event.process_id = GetProcessId(); + event.process_name = m_ProcessName; + event.timestamp = timestamp; + event.counter = counter; + + AddTraceEvent(event); +} + +struct BuildProtoEvent +{ + string operator()(TraceEventInterval &event) + { + auto *interval = new catkit_proto::tracing::TraceEventInterval(); + + interval->set_name(event.name); + interval->set_category(event.category); + interval->set_process_id(event.process_id); + interval->set_process_name(event.process_name); + interval->set_thread_id(event.thread_id); + interval->set_thread_name(event.thread_name); + interval->set_timestamp(event.timestamp); + interval->set_duration(event.duration); + + catkit_proto::tracing::TraceEvent proto; + proto.set_allocated_interval(interval); + + return proto.SerializeAsString(); + } + + string operator()(TraceEventInstant &event) + { + auto *instant = new catkit_proto::tracing::TraceEventInstant(); + + instant->set_name(event.name); + instant->set_process_id(event.process_id); + instant->set_process_name(event.process_name); + instant->set_thread_id(event.thread_id); + instant->set_thread_name(event.thread_name); + instant->set_timestamp(event.timestamp); + + catkit_proto::tracing::TraceEvent proto; + proto.set_allocated_instant(instant); + + return proto.SerializeAsString(); + } + + string operator()(TraceEventCounter &event) + { + auto *counter = new catkit_proto::tracing::TraceEventCounter(); + + counter->set_name(event.name); + counter->set_series(event.series); + counter->set_process_id(event.process_id); + counter->set_process_name(event.process_name); + counter->set_timestamp(event.timestamp); + counter->set_counter(event.counter); + + catkit_proto::tracing::TraceEvent proto; + proto.set_allocated_counter(counter); + + return proto.SerializeAsString(); + } +}; + +void TracingProxy::MessageLoop() +{ + zmq::context_t context; + zmq::socket_t socket(context, ZMQ_PUSH); + + socket.set(zmq::sockopt::linger, 0); + socket.set(zmq::sockopt::sndtimeo, 10); + + socket.connect("tcp://"s + m_Host + ":" + to_string(m_Port)); + + TraceEvent event; + + while (!m_ShutDown) + { + // Get next message from the queue. + { + std::unique_lock lock(m_Mutex); + + while (m_TraceMessages.empty() && !m_ShutDown) + { + m_ConditionVariable.wait(lock); + } + + if (m_ShutDown) + break; + + event = m_TraceMessages.front(); + m_TraceMessages.pop(); + } + + // Convert the TraceEvent to a ProtoBuf serialized string. + string message = std::visit(BuildProtoEvent{}, event); + + // Construct message. + zmq::message_t message_zmq(message.size()); + memcpy(message_zmq.data(), message.c_str(), message.size()); + + // Send message to socket. + zmq::send_result_t res; + do + { + res = socket.send(message_zmq, zmq::send_flags::none); + } + while (!res.has_value() && m_ShutDown); + } + + socket.close(); +} + +void TracingProxy::SetProcessName(string process_name) +{ + m_ProcessName = process_name; +} + +void TracingProxy::SetThreadName(string thread_name) +{ + m_ThreadName = thread_name; +} diff --git a/catkit_core/Tracing.h b/catkit_core/Tracing.h new file mode 100644 index 000000000..e45672457 --- /dev/null +++ b/catkit_core/Tracing.h @@ -0,0 +1,94 @@ +#ifndef TRACING_PROXY_H +#define TRACING_PROXY_H + +#include +#include +#include +#include +#include + +struct TraceEventInterval +{ + std::string name; + std::string category; + std::uint32_t process_id; + std::string process_name; + std::uint32_t thread_id; + std::string thread_name; + std::uint64_t timestamp; + std::uint64_t duration; +}; + +struct TraceEventInstant +{ + std::string name; + std::uint32_t process_id; + std::string process_name; + std::uint32_t thread_id; + std::string thread_name; + std::uint64_t timestamp; +}; + +struct TraceEventCounter +{ + std::string name; + std::string series; + std::uint32_t process_id; + std::string process_name; + std::uint64_t timestamp; + double counter; +}; + +typedef std::variant TraceEvent; + +class TracingProxy +{ +public: + TracingProxy(); + ~TracingProxy(); + + void Connect(std::string process_name, std::string host, int port); + void Disconnect(); + bool IsConnected(); + + void TraceInterval(std::string name, std::string category, uint64_t timestamp_start, uint64_t duration); + void TraceInstant(std::string name, uint64_t timestamp); + void TraceCounter(std::string name, std::string series, uint64_t timestamp, double counter); + +private: + template + void AddTraceEvent(T &event) + { + if (IsConnected()) + { + std::unique_lock lock(m_Mutex); + m_TraceMessages.emplace(event); + + m_ConditionVariable.notify_all(); + } + } + + void MessageLoop(); + + void SetProcessName(std::string process_name); + void SetThreadName(std::string thread_name); + + std::thread m_MessageLoopThread; + + std::atomic_bool m_IsConnected; + std::atomic_bool m_ShutDown; + + std::queue m_TraceMessages; + std::mutex m_Mutex; + std::condition_variable m_ConditionVariable; + + std::string m_Host; + int m_Port; + + std::string m_ProcessName; + std::string m_ThreadName; +}; + +extern TracingProxy tracing_proxy; + +#endif // TRACING_PROXY_H diff --git a/catkit_core/TracingProxy.cpp b/catkit_core/TracingProxy.cpp deleted file mode 100644 index adb604f72..000000000 --- a/catkit_core/TracingProxy.cpp +++ /dev/null @@ -1,174 +0,0 @@ -#include "TracingProxy.h" - -#include "Timing.h" -#include "Util.h" - -using namespace std; - -void AddProcessThreadIds(string &contents) -{ - contents += ",\"pid\":"; - contents += to_string(GetProcessId()); - contents += ",\"tid\":"; - contents += to_string(GetThreadId()); -} - -TracingProxy::TracingProxy(std::shared_ptr testbed) -{ - m_Host = testbed->GetHost(); - m_Port = testbed->GetTracingIngressPort(); - - m_MessageLoopThread = std::thread(&TracingProxy::MessageLoop, this); -} - -TracingProxy::~TracingProxy() -{ - ShutDown(); - - if (m_MessageLoopThread.joinable()) - m_MessageLoopThread.join(); -} - -void TracingProxy::TraceBegin(string func, string what) -{ - auto ts = GetTimeStamp(); - - string contents = "{\"name\":\""; - contents += func; - contents += "\",\"cat\":\"\",\"ph\":\"B\",\"ts\":"; - contents += to_string(double(ts) / 1000); - AddProcessThreadIds(contents); - contents += ",\"args\":{\"what\":\""; - contents += what; - contents += "\"}}"; - - SendTraceMessage(contents); -} - -void TracingProxy::TraceEnd(string func) -{ - auto ts = GetTimeStamp(); - - string contents = "{\"name\":\""; - contents += func; - contents += "\",\"cat\":\"\",\"ph\":\"E\",\"ts\":"; - contents += to_string(double(ts) / 1000); - AddProcessThreadIds(contents); - contents += ",}"; - - SendTraceMessage(contents); -} - -void TracingProxy::TraceInterval(string func, string what, uint64_t timestamp_start, uint64_t timestamp_end) -{ - string contents = "{\"name\":\""; - contents += func; - contents += "\",\"cat\":\"\",\"ph\":\"X\",\"ts\":"; - contents += to_string(double(timestamp_start) / 1000); - AddProcessThreadIds(contents); - contents += "\"dur\":"; - contents += to_string(double(timestamp_end - timestamp_start) / 1000); - contents += ",\"args\":{\"what\":\""; - contents += what; - contents += "\"}}"; - - SendTraceMessage(contents); -} - -void TracingProxy::TraceCounter(string func, string series, double counter) -{ - auto ts = GetTimeStamp(); - string contents = "{\"name\":\""; - contents += func; - contents += "\",\"cat\":\"\",\"ph\":\"C\",\"ts\":"; - contents += to_string(double(ts) / 1000); - AddProcessThreadIds(contents); - contents += ",\"args\":{\""; - contents += series; - contents += "\":"; - contents += to_string(counter); - contents += "}}"; - - SendTraceMessage(contents); -} - -void TracingProxy::TraceProcessName(string process_name) -{ - string contents = "{\"name\":\"process_name\",\"ph\":\"M\""; - AddProcessThreadIds(contents); - contents += ",\"args\":{\"name\":\""; - contents += process_name; - contents += "}}"; - - SendTraceMessage(contents); -} - -void TracingProxy::TraceThreadName(string thread_name) -{ - string contents = "{\"name\":\"thread_name\",\"ph\":\"M\""; - AddProcessThreadIds(contents); - contents += ",\"args\":{\"name\":\""; - contents += thread_name; - contents += "}}"; - - SendTraceMessage(contents); -} - -void TracingProxy::SendTraceMessage(string contents) -{ - std::unique_lock lock(m_Mutex); - m_TraceMessages.push(contents); - - m_ConditionVariable.notify_all(); -} - -void TracingProxy::MessageLoop() -{ - zmq::context_t context; - zmq::socket_t socket(context, ZMQ_PUSH); - - socket.set(zmq::sockopt::linger, 0); - socket.set(zmq::sockopt::sndtimeo, 10); - - socket.connect("tcp://"s + m_Host + ":" + to_string(m_Port)); - - std::string message; - - while (!m_ShutDown) - { - // Get next message from the queue. - { - std::unique_lock lock(m_Mutex); - - while (m_TraceMessages.empty() && !m_ShutDown) - { - m_ConditionVariable.wait(lock); - } - - if (m_ShutDown) - break; - - message = m_TraceMessages.front(); - m_TraceMessages.pop(); - } - - // Construct message. - zmq::message_t message_zmq(message.size()); - memcpy(message_zmq.data(), message.c_str(), message.size()); - - zmq::send_result_t res; - do - { - res = socket.send(message_zmq, zmq::send_flags::none); - } - while (!res.has_value() && m_ShutDown); - } - - socket.close(); -} - -void TracingProxy::ShutDown() -{ - m_ShutDown = true; - m_ConditionVariable.notify_all(); -} diff --git a/catkit_core/TracingProxy.h b/catkit_core/TracingProxy.h deleted file mode 100644 index 98162dbf3..000000000 --- a/catkit_core/TracingProxy.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef TRACING_PROXY_H -#define TRACING_PROXY_H - -#include "TestbedProxy.h" - -#include - -#include -#include -#include - -class TracingProxy -{ -public: - TracingProxy(std::shared_ptr testbed); - ~TracingProxy(); - - void TraceBegin(std::string func, std::string what); - void TraceEnd(std::string func); - - void TraceInterval(std::string func, std::string what, uint64_t timestamp_start, uint64_t timestamp_end); - - void TraceCounter(std::string func, std::string series, double counter); - - void TraceProcessName(std::string process_name); - void TraceThreadName(std::string thread_name); - -private: - void SendTraceMessage(std::string contents); - - void MessageLoop(); - void ShutDown(); - - std::thread m_MessageLoopThread; - std::atomic_bool m_ShutDown; - - std::queue m_TraceMessages; - std::mutex m_Mutex; - std::condition_variable m_ConditionVariable; - - std::string m_Host; - int m_Port; -}; - -#endif // TRACING_PROXY_H diff --git a/proto/tracing.proto b/proto/tracing.proto index 34d4de129..a1b172048 100644 --- a/proto/tracing.proto +++ b/proto/tracing.proto @@ -2,21 +2,44 @@ syntax = "proto3"; package catkit_proto.tracing; -enum TraceEventPhase +message TraceEventInterval { - BEGIN = 0; - END = 1; - INTERVAL = 2; - COUNTER = 3; + string name = 1; + string category = 2; + uint32 process_id = 3; + string process_name = 4; + uint32 thread_id = 5; + string thread_name = 6; + uint64 timestamp = 7; + uint64 duration = 8; } -message TraceEvent +message TraceEventInstant { string name = 1; - TraceEventPhase phase = 2; - uint32 process_id = 3; + uint32 process_id = 2; + string process_name = 3; uint32 thread_id = 4; + string thread_name = 5; + uint64 timestamp = 6; +} + +message TraceEventCounter +{ + string name = 1; + string series = 2; + uint32 process_id = 3; + string process_name = 4; uint64 timestamp = 5; - uint64 duration = 6; - double counter = 7; + double counter = 6; +} + +message TraceEvent +{ + oneof event + { + TraceEventInterval interval = 1; + TraceEventInstant instant = 2; + TraceEventCounter counter = 3; + } } diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 000000000..e94fcd358 --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,80 @@ +from catkit2 import TraceWriter, trace_interval, trace_instant, trace_counter, ZmqDistributor +from catkit2.catkit_bindings import trace_connect, trace_disconnect + +import time +import zmq +import json + + +PROCESS_NAME = 'our_process_name' +FNAME = 'trace.json' +INSTANT_NAME = 'blank' +COUNTER_NAME = 'counter' +SERIES_NAME = 'series' +INTERVAL_NAME_1 = 'a' +INTERVAL_NAME_2 = 'ab' + +def test_trace_writer(unused_port): + input_port = unused_port() + output_port = unused_port() + + writer = TraceWriter('127.0.0.1', output_port) + trace_connect(PROCESS_NAME, '127.0.0.1', input_port) + + context = zmq.Context() + + tracing_distributor = ZmqDistributor(context, input_port, output_port) + tracing_distributor.start() + + # Wait for slow-joiner of ZMQ sockets. + time.sleep(0.3) + + try: + with writer.open(FNAME): + + with trace_interval(INTERVAL_NAME_1): + with trace_interval(INTERVAL_NAME_2): + time.sleep(0.01) + + for i in range(10): + trace_counter(COUNTER_NAME, SERIES_NAME, i) + + if i % 2 == 0: + trace_instant(INSTANT_NAME) + + # Wait for all messages to pass through the system and be written out. + time.sleep(0.3) + + finally: + trace_disconnect() + + tracing_distributor.stop() + + # Check the written JSON file. + with open(FNAME) as f: + data = f.read()[:-2] + ']' + + entries = json.loads(data) + + for entry in entries: + assert entry['ph'] in ['M', 'X', 'C', 'i'] + + if entry['ph'] == 'M': + if entry['name'] == 'process_name': + assert entry['args']['name'] == PROCESS_NAME + elif entry['ph'] == 'X': + assert entry['name'] in [INTERVAL_NAME_1, INTERVAL_NAME_2] + assert 'dur' in entry + assert 'ts' in entry + assert 'pid' in entry + assert 'tid' in entry + elif entry['ph'] == 'C': + assert entry['name'] == COUNTER_NAME + assert 'ts' in entry + assert 'pid' in entry + assert SERIES_NAME in entry['args'] + elif entry['ph'] == 'i': + assert entry['name'] == INSTANT_NAME + assert 'ts' in entry + assert 'pid' in entry + assert 'tid' in entry