Skip to content

Commit

Permalink
Move everything zmq related to it's own backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Dec 6, 2023
1 parent b2e3df1 commit 6a56d2d
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 205 deletions.
22 changes: 10 additions & 12 deletions posttroll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,28 @@
"""Posttroll packages."""

import logging
import os
import sys
from datetime import datetime

import zmq
from donfig import Config

from .version import get_versions

config = Config("posttroll")
context = {}
# context = {}
logger = logging.getLogger(__name__)


def get_context():
"""Provide the context to use.
# def get_context():
# """Provide the context to use.

This function takes care of creating new contexts in case of forks.
"""
pid = os.getpid()
if pid not in context:
context[pid] = zmq.Context()
logger.debug("renewed context for PID %d", pid)
return context[pid]
# This function takes care of creating new contexts in case of forks.
# """
# pid = os.getpid()
# if pid not in context:
# context[pid] = zmq.Context()
# logger.debug("renewed context for PID %d", pid)
# return context[pid]


def strp_isoformat(strg):
Expand Down
37 changes: 8 additions & 29 deletions posttroll/address_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,25 @@
/<server-name>/address info ... host:port
"""
import copy
import errno
import logging
import os
import threading
import errno
import time

from datetime import datetime, timedelta

import netifaces
from zmq import REP, LINGER

from posttroll import config
from posttroll.bbmcast import MulticastReceiver, SocketTimeout
from posttroll.message import Message
from posttroll.publisher import Publish
from posttroll import get_context


__all__ = ('AddressReceiver', 'getaddress')
__all__ = ("AddressReceiver", "getaddress")

LOGGER = logging.getLogger(__name__)

debug = os.environ.get('DEBUG', False)
debug = os.environ.get("DEBUG", False)
broadcast_port = 21200

default_publish_port = 16543
Expand All @@ -64,7 +61,7 @@ def get_local_ips():
for addr in inet_addrs:
if addr is not None:
for add in addr:
ips.append(add['addr'])
ips.append(add["addr"])
return ips

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -169,7 +166,9 @@ def _run(self):
break

else:
recv = _SimpleReceiver(port)
if config.get("backend", "unsecure_zmq") == "unsecure_zmq":
from posttroll.backends.zmq.address_receiver import SimpleReceiver
recv = SimpleReceiver(port)
nameservers = ["localhost"]

self._is_running = True
Expand Down Expand Up @@ -221,26 +220,6 @@ def _add(self, adr, metadata):
self._addresses[adr] = metadata


class _SimpleReceiver(object):

""" Simple listing on port for address messages."""

def __init__(self, port=None):
self._port = port or default_publish_port
self._socket = get_context().socket(REP)
self._socket.bind("tcp://*:" + str(port))

def __call__(self):
message = self._socket.recv_string()
self._socket.send_string("ok")
return message, None

def close(self):
"""Close the receiver."""
self._socket.setsockopt(LINGER, 1)
self._socket.close()


# -----------------------------------------------------------------------------
# default
getaddress = AddressReceiver
18 changes: 18 additions & 0 deletions posttroll/backends/zmq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import logging
import os

import zmq

from posttroll import config

logger = logging.getLogger(__name__)
context = {}


def get_context():
"""Provide the context to use.
This function takes care of creating new contexts in case of forks.
"""
pid = os.getpid()
if pid not in context:
context[pid] = zmq.Context()
logger.debug("renewed context for PID %d", pid)
return context[pid]

def _set_tcp_keepalive(socket):
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None))
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None))
Expand Down
22 changes: 22 additions & 0 deletions posttroll/backends/zmq/address_receiver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from posttroll.address_receiver import default_publish_port
from posttroll.backends.zmq import get_context
from zmq import REP, LINGER

class SimpleReceiver(object):

""" Simple listing on port for address messages."""

def __init__(self, port=None):
self._port = port or default_publish_port
self._socket = get_context().socket(REP)
self._socket.bind("tcp://*:" + str(port))

def __call__(self):
message = self._socket.recv_string()
self._socket.send_string("ok")
return message, None

def close(self):
"""Close the receiver."""
self._socket.setsockopt(LINGER, 1)
self._socket.close()
51 changes: 51 additions & 0 deletions posttroll/backends/zmq/message_broadcaster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import threading
from posttroll.backends.zmq import get_context

from zmq import LINGER, NOBLOCK, REQ, ZMQError

import logging

logger = logging.getLogger(__name__)


class UnsecureZMQDesignatedReceiversSender:
"""Sends message to multiple *receivers* on *port*."""

def __init__(self, default_port, receivers):
self.default_port = default_port

self.receivers = receivers
self._shutdown_event = threading.Event()

def __call__(self, data):
"""Send data."""
for receiver in self.receivers:
self._send_to_address(receiver, data)

def _send_to_address(self, address, data, timeout=10):
"""Send data to *address* and *port* without verification of response."""
# Socket to talk to server
socket = get_context().socket(REQ)
try:
socket.setsockopt(LINGER, timeout * 1000)
if address.find(":") == -1:
socket.connect("tcp://%s:%d" % (address, self.default_port))
else:
socket.connect("tcp://%s" % address)
socket.send_string(data)
while not self._shutdown_event.is_set():
try:
message = socket.recv_string(NOBLOCK)
except ZMQError:
self._shutdown_event.wait(.1)
continue
if message != "ok":
logger.warn("invalid acknowledge received: %s" % message)
break

finally:
socket.close()

def close(self):
"""Close the sender."""
self._shutdown_event.set()
90 changes: 90 additions & 0 deletions posttroll/backends/zmq/ns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""ZMQ implexentation of ns."""

import logging
from threading import Lock

from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller

from posttroll.backends.zmq import get_context
from posttroll.message import Message
from posttroll.ns import PORT, get_active_address

logger = logging.getLogger("__name__")

nslock = Lock()


def unsecure_zmq_get_pub_address(name, timeout=10, nameserver="localhost"):
"""Get the address of the publisher.
For a given publisher *name* from the nameserver on *nameserver* (localhost by default).
"""
# Socket to talk to server
socket = get_context().socket(REQ)
try:
socket.setsockopt(LINGER, int(timeout * 1000))
socket.connect("tcp://" + nameserver + ":" + str(PORT))
logger.debug("Connecting to %s",
"tcp://" + nameserver + ":" + str(PORT))
poller = Poller()
poller.register(socket, POLLIN)

message = Message("/oper/ns", "request", {"service": name})
socket.send_string(str(message))

# Get the reply.
sock = poller.poll(timeout=timeout * 1000)
if sock:
if sock[0][0] == socket:
message = Message.decode(socket.recv_string(NOBLOCK))
return message.data
else:
raise TimeoutError("Didn't get an address after %d seconds."
% timeout)
finally:
socket.close()


class UnsecureZMQNameServer:
"""The name server."""

def __init__(self):
"""Set up the nameserver."""
self.loop = True
self.listener = None

def run(self, arec):
"""Run the listener and answer to requests."""
port = PORT

try:
with nslock:
self.listener = get_context().socket(REP)
self.listener.bind("tcp://*:" + str(port))
logger.debug("Listening on port %s", str(port))
poller = Poller()
poller.register(self.listener, POLLIN)
while self.loop:
with nslock:
socks = dict(poller.poll(1000))
if socks:
if socks.get(self.listener) == POLLIN:
msg = self.listener.recv_string()
else:
continue
logger.debug("Replying to request: " + str(msg))
msg = Message.decode(msg)
active_address = get_active_address(msg.data["service"], arec)
self.listener.send_unicode(str(active_address))
except KeyboardInterrupt:
# Needed to stop the nameserver.
pass
finally:
self.stop()

def stop(self):
"""Stop the name server."""
self.listener.setsockopt(LINGER, 1)
self.loop = False
with nslock:
self.listener.close()
8 changes: 5 additions & 3 deletions posttroll/backends/zmq/publisher.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""ZMQ implementation of the publisher."""

import logging
from threading import Lock
from urllib.parse import urlsplit, urlunsplit

import zmq
import logging

from posttroll import get_context
from posttroll.backends.zmq import _set_tcp_keepalive
from posttroll.backends.zmq import _set_tcp_keepalive, get_context

LOGGER = logging.getLogger(__name__)

Expand Down
16 changes: 8 additions & 8 deletions posttroll/backends/zmq/subscriber.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
"""ZMQ implementation of the subscriber."""

import logging
from threading import Lock
from urllib.parse import urlsplit
from posttroll.message import Message
from zmq import Poller, SUB, SUBSCRIBE, POLLIN, PULL, ZMQError, NOBLOCK, LINGER
from time import sleep
import logging

from posttroll import get_context
from posttroll.backends.zmq import _set_tcp_keepalive
from urllib.parse import urlsplit

from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError

from posttroll.backends.zmq import _set_tcp_keepalive, get_context
from posttroll.message import Message

LOGGER = logging.getLogger(__name__)

class UnsecureZMQSubscriber:
"""Unsecure ZMQ implementation of the subscriber."""

def __init__(self, addresses, topics='', message_filter=None, translate=False):
def __init__(self, addresses, topics="", message_filter=None, translate=False):
"""Initialize the subscriber."""
self._topics = topics
self._filter = message_filter
Expand Down
Loading

0 comments on commit 6a56d2d

Please sign in to comment.