Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix testbed port numbers #239

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion catkit2/testbed/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def run(self):
logging.getLogger('matplotlib').setLevel(logging.WARNING)

# Set up log forwarder.
log_forwarder = LogForwarder('experiment', f'tcp://{self.testbed.host}:{self.testbed.logging_ingress_port}')
log_forwarder = LogForwarder()
log_forwarder.connect('experiment', f'tcp://{self.testbed.host}:{self.testbed.logging_ingress_port}')

# Set up log writer.
self._log_writer = LogWriter(self.testbed.host, self.testbed.logging_egress_port)
Expand Down
45 changes: 35 additions & 10 deletions catkit2/testbed/testbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import subprocess
import socket
import threading
import contextlib

import psutil
import zmq
Expand All @@ -18,10 +19,22 @@

SERVICE_LIVELINESS = 5

def get_unused_port():
with socket.socket() as sock:
sock.bind(('', 0))
return sock.getsockname()[1]
def get_unused_port(num_ports=1):
ports = []

with contextlib.ExitStack() as stack:
for i in range(num_ports):
sock = socket.socket()
stack.enter_context(sock)

sock.bind(('', 0))
ports.append(sock.getsockname()[1])

if num_ports == 1:
return ports[0]

return ports


class ServiceReference:
'''A reference to a service running on another process.
Expand Down Expand Up @@ -143,6 +156,13 @@ def __init__(self, port, is_simulated, config):
self.host = '127.0.0.1'
self.port = port

self.logging_ingress_port = 0
self.logging_egress_port = 0
self.data_logging_ingress_port = 0
self.data_logging_egress_port = 0
self.tracing_ingress_port = 0
self.tracing_egress_port = 0

self.is_simulated = is_simulated
self.config = config

Expand Down Expand Up @@ -355,7 +375,8 @@ def setup_logging(self):
logging.getLogger().addHandler(self.log_handler)
logging.getLogger().setLevel(logging.DEBUG)

self.log_forwarder = LogForwarder('testbed', f'tcp://localhost:{self.port + 1}')
self.log_forwarder = LogForwarder()
self.log_forwarder.connect('testbed', f'tcp://localhost:{self.logging_ingress_port}')

def destroy_logging(self):
'''Shut down all logging.
Expand All @@ -371,7 +392,9 @@ def destroy_logging(self):
def start_log_distributor(self):
'''Start the log distributor.
'''
self.log_distributor = LogDistributor(self.context, self.port + 1, self.port + 2)
self.logging_ingress_port, self.logging_egress_port = get_unused_port(num_ports=2)

self.log_distributor = LogDistributor(self.context, self.logging_ingress_port, self.logging_egress_port)
self.log_distributor.start()

def stop_log_distributor(self):
Expand Down Expand Up @@ -433,10 +456,12 @@ def on_get_info(self, data):
reply.config = json.dumps(self.config)
reply.is_simulated = self.is_simulated
reply.heartbeat_stream_id = self.heartbeat_stream.stream_id
reply.logging_ingress_port = self.port + 1
reply.logging_egress_port = self.port + 2
reply.data_logging_ingress_port = 0
reply.tracing_ingress_port = 0
reply.logging_ingress_port = self.logging_ingress_port
reply.logging_egress_port = self.logging_egress_port
reply.data_logging_ingress_port = self.data_logging_ingress_port
reply.data_logging_egress_port = self.data_logging_egress_port
reply.tracing_ingress_port = self.tracing_ingress_port
reply.tracing_egress_port = self.tracing_egress_port

return reply.SerializeToString()

Expand Down
3 changes: 2 additions & 1 deletion catkit_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ PYBIND11_MODULE(catkit_bindings, m)
py::arg("print_context") = true);

py::class_<LogForwarder>(m, "LogForwarder")
.def(py::init<std::string, std::string>());
.def(py::init<>())
.def("connect", &LogForwarder::Connect);

#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
Expand Down
14 changes: 11 additions & 3 deletions catkit_core/LogForwarder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
using namespace zmq;
using json = nlohmann::json;

LogForwarder::LogForwarder(std::string service_id, std::string host)
: m_ServiceId(service_id), m_Host(host), m_ShutDown(false)
LogForwarder::LogForwarder()
: m_ShutDown(false)
{
m_MessageLoopThread = std::thread(&LogForwarder::MessageLoop, this);
}

LogForwarder::~LogForwarder()
Expand All @@ -20,6 +19,15 @@ LogForwarder::~LogForwarder()
m_MessageLoopThread.join();
}

void LogForwarder::Connect(std::string service_id, std::string host)
{
m_ServiceId = service_id;
m_Host = host;

m_MessageLoopThread = std::thread(&LogForwarder::MessageLoop, this);
}


void LogForwarder::AddLogEntry(const LogEntry &entry)
{
json message = {
Expand Down
4 changes: 3 additions & 1 deletion catkit_core/LogForwarder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
class LogForwarder : LogListener
{
public:
LogForwarder(std::string service_id, std::string host);
LogForwarder();
~LogForwarder();

void Connect(std::string service_id, std::string host);

void AddLogEntry(const LogEntry &entry);

private:
Expand Down
4 changes: 3 additions & 1 deletion catkit_core/Service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ const double SAFETY_INTERVAL = 60; // seconds.

Service::Service(string service_type, string service_id, int service_port, int testbed_port)
: m_Server(service_port), m_ServiceId(service_id), m_ServiceType(service_type),
m_LoggerConsole(), m_LoggerPublish(service_id, "tcp://127.0.0.1:"s + to_string(testbed_port + 1)),
m_LoggerConsole(), m_LoggerPublish(),
m_Heartbeat(nullptr), m_State(nullptr), m_Safety(nullptr), m_Testbed(nullptr),
m_IsRunning(false), m_ShouldShutDown(false), m_FailSafe(false)
{
m_Testbed = make_shared<TestbedProxy>("127.0.0.1", testbed_port);
m_Config = m_Testbed->GetConfig()["services"][service_id];

m_LoggerPublish.Connect(service_id, "tcp://127.0.0.1:"s + to_string(m_Testbed->GetLoggingIngressPort()));

m_Heartbeat = DataStream::Create("heartbeat", service_id, DataType::DT_UINT64, {1}, 20);

string state_stream_id = m_Testbed->RegisterService(
Expand Down
Loading