From caa36c21670d04bdfdd3c38ae2d70207dcc42d21 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 29 Nov 2023 20:57:03 +0100 Subject: [PATCH 01/45] Refactor zmq pub and sub into a zmq backend --- posttroll/__init__.py | 12 - posttroll/address_receiver.py | 2 +- posttroll/backends/__init__.py | 0 posttroll/backends/zmq/__init__.py | 14 + posttroll/backends/zmq/publisher.py | 61 +++ posttroll/backends/zmq/subscriber.py | 202 +++++++++ posttroll/message_broadcaster.py | 35 +- posttroll/publisher.py | 90 ++-- posttroll/subscriber.py | 168 +------- posttroll/testing.py | 2 +- posttroll/tests/__init__.py | 23 +- posttroll/tests/test_bbmcast.py | 10 - posttroll/tests/test_pubsub.py | 605 +++++++++++++++------------ 13 files changed, 689 insertions(+), 535 deletions(-) create mode 100644 posttroll/backends/__init__.py create mode 100644 posttroll/backends/zmq/__init__.py create mode 100644 posttroll/backends/zmq/publisher.py create mode 100644 posttroll/backends/zmq/subscriber.py diff --git a/posttroll/__init__.py b/posttroll/__init__.py index b77c34b..30f40d9 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -75,17 +75,5 @@ def strp_isoformat(strg): return dat.replace(microsecond=mis) -def _set_tcp_keepalive(socket): - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None)) - - -def _set_int_sockopt(socket, param, value): - if value is not None: - socket.setsockopt(param, int(value)) - - __version__ = get_versions()['version'] del get_versions diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 27106ef..3cedf79 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -149,7 +149,7 @@ def _check_age(self, pub, min_interval=zero_seconds): def _run(self): """Run the receiver.""" port = broadcast_port - nameservers = [] + nameservers = False if self._multicast_enabled: while True: try: diff --git a/posttroll/backends/__init__.py b/posttroll/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py new file mode 100644 index 0000000..5596169 --- /dev/null +++ b/posttroll/backends/zmq/__init__.py @@ -0,0 +1,14 @@ +import zmq + +from posttroll import config + +def _set_tcp_keepalive(socket): + _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) + _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) + _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) + _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None)) + + +def _set_int_sockopt(socket, param, value): + if value is not None: + socket.setsockopt(param, int(value)) diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py new file mode 100644 index 0000000..5ba7fbe --- /dev/null +++ b/posttroll/backends/zmq/publisher.py @@ -0,0 +1,61 @@ +from threading import Lock +from urllib.parse import urlsplit, urlunsplit +import zmq +import logging + +from posttroll import get_context +from posttroll.backends.zmq import _set_tcp_keepalive + +LOGGER = logging.getLogger(__name__) + + +class UnsecureZMQPublisher: + """Unsecure ZMQ implementation of the publisher class.""" + + def __init__(self, address, name="", min_port=None, max_port=None): + """Bind the publisher class to a port.""" + self.name = name + self.destination = address + self.publish_socket = None + self.min_port = min_port + self.max_port = max_port + self.port_number = None + self._pub_lock = Lock() + + def start(self): + """Start the publisher. + """ + self.publish_socket = get_context().socket(zmq.PUB) + _set_tcp_keepalive(self.publish_socket) + + self._bind() + LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") + return self + + def _bind(self): + # Check for port 0 (random port) + u__ = urlsplit(self.destination) + port = u__.port + if port == 0: + dest = urlunsplit((u__.scheme, u__.hostname, + u__.path, u__.query, u__.fragment)) + self.port_number = self.publish_socket.bind_to_random_port( + dest, + min_port=self.min_port, + max_port=self.max_port) + netloc = u__.hostname + ":" + str(self.port_number) + self.destination = urlunsplit((u__.scheme, netloc, u__.path, + u__.query, u__.fragment)) + else: + self.publish_socket.bind(self.destination) + self.port_number = port + + def send(self, msg): + """Send the given message.""" + with self._pub_lock: + self.publish_socket.send_string(msg) + + def stop(self): + """Stop the publisher.""" + self.publish_socket.setsockopt(zmq.LINGER, 1) + self.publish_socket.close() diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py new file mode 100644 index 0000000..980a2b0 --- /dev/null +++ b/posttroll/backends/zmq/subscriber.py @@ -0,0 +1,202 @@ +from threading import Lock +from urllib.parse import urlsplit +from posttroll.message import _MAGICK, Message +from zmq import Poller, SUB, SUBSCRIBE, POLLIN, PULL, ZMQError, NOBLOCK, LINGER +from time import sleep +import logging + +from posttroll import get_context +from posttroll.backends.zmq import _set_tcp_keepalive + + + +LOGGER = logging.getLogger(__name__) + +class UnsecureZMQSubscriber: + """Unsecure ZMQ implementation of the subscriber.""" + + def __init__(self, addresses, topics='', message_filter=None, translate=False): + """Initialize the subscriber.""" + self._topics = topics + self._filter = message_filter + self._translate = translate + + self.sub_addr = {} + self.addr_sub = {} + + self._hooks = [] + self._hooks_cb = {} + + self.poller = Poller() + self._lock = Lock() + + self.update(addresses) + + self._loop = None + + def add(self, address, topics=None): + """Add *address* to the subscribing list for *topics*. + + It topics is None we will subscribe to already specified topics. + """ + with self._lock: + if address in self.addresses: + return + + topics = topics or self._topics + LOGGER.info("Subscriber adding address %s with topics %s", + str(address), str(topics)) + subscriber = self._add_sub_socket(address, topics) + self.sub_addr[subscriber] = address + self.addr_sub[address] = subscriber + + def _add_sub_socket(self, address, topics): + subscriber = get_context().socket(SUB) + _set_tcp_keepalive(subscriber) + for t__ in topics: + subscriber.setsockopt_string(SUBSCRIBE, str(t__)) + subscriber.connect(address) + + if self.poller: + self.poller.register(subscriber, POLLIN) + return subscriber + + def remove(self, address): + """Remove *address* from the subscribing list for *topics*.""" + with self._lock: + try: + subscriber = self.addr_sub[address] + except KeyError: + return + LOGGER.info("Subscriber removing address %s", str(address)) + del self.addr_sub[address] + del self.sub_addr[subscriber] + self._remove_sub_socket(subscriber) + + def _remove_sub_socket(self, subscriber): + if self.poller: + self.poller.unregister(subscriber) + subscriber.close() + + def update(self, addresses): + """Update with a set of addresses.""" + if isinstance(addresses, str): + addresses = [addresses, ] + current_addresses, new_addresses = set(self.addresses), set(addresses) + addresses_to_remove = current_addresses.difference(new_addresses) + addresses_to_add = new_addresses.difference(current_addresses) + for addr in addresses_to_remove: + self.remove(addr) + for addr in addresses_to_add: + self.add(addr) + return bool(addresses_to_remove or addresses_to_add) + + def add_hook_sub(self, address, topics, callback): + """Specify a SUB *callback* in the same stream (thread) as the main receive loop. + + The callback will be called with the received messages from the + specified subscription. + + Good for operations, which is required to be done in the same thread as + the main recieve loop (e.q operations on the underlying sockets). + """ + topics = topics + LOGGER.info("Subscriber adding SUB hook %s for topics %s", + str(address), str(topics)) + socket = self._add_sub_socket(address, topics) + self._add_hook(socket, callback) + + def add_hook_pull(self, address, callback): + """Specify a PULL *callback* in the same stream (thread) as the main receive loop. + + The callback will be called with the received messages from the + specified subscription. Good for pushed 'inproc' messages from another thread. + """ + LOGGER.info("Subscriber adding PULL hook %s", str(address)) + socket = get_context().socket(PULL) + socket.connect(address) + if self.poller: + self.poller.register(socket, POLLIN) + self._add_hook(socket, callback) + + def _add_hook(self, socket, callback): + """Add a generic hook. The passed socket has to be "receive only".""" + self._hooks.append(socket) + self._hooks_cb[socket] = callback + + + @property + def addresses(self): + """Get the addresses.""" + return self.sub_addr.values() + + @property + def subscribers(self): + """Get the subscribers.""" + return self.sub_addr.keys() + + def recv(self, timeout=None): + """Receive, optionally with *timeout* in seconds.""" + if timeout: + timeout *= 1000. + + for sub in list(self.subscribers) + self._hooks: + self.poller.register(sub, POLLIN) + self._loop = True + try: + while self._loop: + sleep(0) + try: + socks = dict(self.poller.poll(timeout=timeout)) + if socks: + for sub in self.subscribers: + if sub in socks and socks[sub] == POLLIN: + received = sub.recv_string(NOBLOCK) + m__ = Message.decode(received) + if not self._filter or self._filter(m__): + if self._translate: + url = urlsplit(self.sub_addr[sub]) + host = url[1].split(":")[0] + m__.sender = (m__.sender.split("@")[0] + + "@" + host) + yield m__ + + for sub in self._hooks: + if sub in socks and socks[sub] == POLLIN: + m__ = Message.decode(sub.recv_string(NOBLOCK)) + self._hooks_cb[sub](m__) + else: + # timeout + yield None + except ZMQError as err: + if self._loop: + LOGGER.exception("Receive failed: %s", str(err)) + finally: + for sub in list(self.subscribers) + self._hooks: + self.poller.unregister(sub) + + def __call__(self, **kwargs): + """Handle calls with class instance.""" + return self.recv(**kwargs) + + def stop(self): + """Stop the subscriber.""" + self._loop = False + + def close(self): + """Close the subscriber: stop it and close the local subscribers.""" + self.stop() + for sub in list(self.subscribers) + self._hooks: + try: + sub.setsockopt(LINGER, 1) + sub.close() + except ZMQError: + pass + + def __del__(self): + """Clean up after the instance is deleted.""" + for sub in list(self.subscribers) + self._hooks: + try: + sub.close() + except Exception: # noqa: E722 + pass diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 53bfe52..1912c3b 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -29,7 +29,7 @@ from posttroll import message from posttroll.bbmcast import MulticastSender, MC_GROUP from posttroll import get_context -from zmq import REQ, LINGER +from zmq import REQ, LINGER, NOBLOCK, ZMQError __all__ = ('MessageBroadcaster', 'AddressBroadcaster', 'sendaddress') @@ -45,6 +45,7 @@ def __init__(self, default_port, receivers): self.default_port = default_port self.receivers = receivers + self._shutdown_event = threading.Event() def __call__(self, data): for receiver in self.receivers: @@ -61,16 +62,22 @@ def _send_to_address(self, address, data, timeout=10): else: socket.connect("tcp://%s" % address) socket.send_string(data) - message = socket.recv_string() - if message != "ok": - LOGGER.warn("invalid acknowledge received: %s" % message) + while not self._shutdown_event.is_set(): + try: + message = socket.recv_string(NOBLOCK) + except ZMQError: + self._shutdown_event.wait(.1) + continue + if message != "ok": + LOGGER.warn("invalid acknowledge received: %s" % message) + break finally: socket.close() def close(self): """Close the sender.""" - pass + self._shutdown_event.set() #----------------------------------------------------------------------------- # # General thread to broadcast messages. @@ -95,33 +102,32 @@ def __init__(self, msg, port, interval, designated_receivers=None): self._interval = interval self._message = msg - self._do_run = False - self._is_running = False + self._shutdown_event = threading.Event() self._thread = threading.Thread(target=self._run) def start(self): """Start the broadcasting.""" if self._interval > 0: - if not self._is_running: - self._do_run = True + if not self._thread.is_alive(): self._thread.start() return self def is_running(self): """Are we running.""" - return self._is_running + return self._thread.is_alive() def stop(self): """Stop the broadcasting.""" - self._do_run = False + self._shutdown_event.set() + self._sender.close() + self._thread.join() return self def _run(self): """Broadcasts forever.""" - self._is_running = True network_fail = False try: - while self._do_run: + while not self._shutdown_event.is_set(): try: if network_fail is True: LOGGER.info("Network connection re-established!") @@ -135,9 +141,8 @@ def _run(self): network_fail = True else: raise - time.sleep(self._interval) + self._shutdown_event.wait(self._interval) finally: - self._is_running = False self._sender.close() #----------------------------------------------------------------------------- diff --git a/posttroll/publisher.py b/posttroll/publisher.py index befa5e3..036d387 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -23,16 +23,10 @@ """The publisher module gives high-level tools to publish messages on a port.""" -import os import logging import socket from datetime import datetime, timedelta -from threading import Lock -from urllib.parse import urlsplit, urlunsplit -import zmq -from posttroll import get_context -from posttroll import _set_tcp_keepalive from posttroll.message import Message from posttroll.message_broadcaster import sendaddressservice from posttroll import config @@ -93,58 +87,35 @@ class Publisher: def __init__(self, address, name="", min_port=None, max_port=None): """Bind the publisher class to a port.""" - self.name = name - self.destination = address - self.publish_socket = None # Limit port range or use the defaults when no port is defined # by the user - self.min_port = min_port or int(config.get('pub_min_port', 49152)) - self.max_port = max_port or int(config.get('pub_max_port', 65536)) - self.port_number = None - + min_port = min_port or int(config.get('pub_min_port', 49152)) + max_port = max_port or int(config.get('pub_max_port', 65536)) # Initialize no heartbeat self._heartbeat = None - self._pub_lock = Lock() - + backend = config.get("backend", "unsecure_zmq") + if backend == "unsecure_zmq": + from posttroll.backends.zmq.publisher import UnsecureZMQPublisher + self._publisher = UnsecureZMQPublisher(address, name, min_port, max_port) + elif backend == "secure_zmq": + from posttroll.backends.zmq.publisher import UnsecureZMQPublisher + self._publisher = UnsecureZMQPublisher(address, name, min_port, max_port) + else: + raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") def start(self): - """Start the publisher. - """ - self.publish_socket = get_context().socket(zmq.PUB) - _set_tcp_keepalive(self.publish_socket) - - self.bind() - LOGGER.info("publisher started on port %s", str(self.port_number)) + """Start the publisher.""" + self._publisher.start() return self - def bind(self): - # Check for port 0 (random port) - u__ = urlsplit(self.destination) - port = u__.port - if port == 0: - dest = urlunsplit((u__.scheme, u__.hostname, - u__.path, u__.query, u__.fragment)) - self.port_number = self.publish_socket.bind_to_random_port( - dest, - min_port=self.min_port, - max_port=self.max_port) - netloc = u__.hostname + ":" + str(self.port_number) - self.destination = urlunsplit((u__.scheme, netloc, u__.path, - u__.query, u__.fragment)) - else: - self.publish_socket.bind(self.destination) - self.port_number = port - def send(self, msg): """Send the given message.""" - with self._pub_lock: - self.publish_socket.send_string(msg) + return self._publisher.send(msg) def stop(self): """Stop the publisher.""" - self.publish_socket.setsockopt(zmq.LINGER, 1) - self.publish_socket.close() + return self._publisher.stop() def close(self): """Alias for stop.""" @@ -156,6 +127,16 @@ def heartbeat(self, min_interval=0): self._heartbeat = _PublisherHeartbeat(self) self._heartbeat(min_interval) + @property + def name(self): + """Get the name of the publisher.""" + return self._publisher.name + + @property + def port_number(self): + """Get the port number from the actual publisher.""" + return self._publisher.port_number + class _PublisherHeartbeat: """Publisher for heartbeat.""" @@ -214,17 +195,18 @@ def __init__(self, name, port=0, aliases=None, broadcast_interval=2, def start(self): """Start the publisher.""" - pub_addr = _get_publish_address(self._port) + pub_addr = _create_tcp_publish_address(self._port) self._publisher = self._publisher_class(pub_addr, self._name, min_port=self.min_port, - max_port=self.max_port).start() - LOGGER.debug("entering publish %s", str(self._publisher.destination)) - addr = _get_publish_address(self._publisher.port_number, str(get_own_ip())) + max_port=self.max_port) + self._publisher.start() + addr = _create_tcp_publish_address(self._publisher.port_number, str(get_own_ip())) self._broadcaster = sendaddressservice(self._name, addr, self._aliases, self._broadcast_interval, - self._nameservers).start() - return self._publisher + self._nameservers) + self._broadcaster.start() + return self def send(self, msg): """Send a *msg*.""" @@ -244,8 +226,12 @@ def close(self): """Alias for stop.""" self.stop() + @property + def port_number(self): + return self._publisher.port_number + -def _get_publish_address(port, ip_address="*"): +def _create_tcp_publish_address(port, ip_address="*"): return "tcp://" + ip_address + ":" + str(port) @@ -320,7 +306,7 @@ def create_publisher_from_dict_config(settings): def _get_publisher_instance(settings): - publisher_address = _get_publish_address(settings['port']) + publisher_address = _create_tcp_publish_address(settings['port']) publisher_name = settings.get("name", "") min_port = settings.get("min_port") max_port = settings.get("max_port") diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 7ea5d80..a34fad1 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -24,20 +24,14 @@ """Simple library to subscribe to messages.""" -from time import sleep + import logging import time from datetime import datetime, timedelta -from threading import Lock -from urllib.parse import urlsplit - -# pylint: disable=E0611 -from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError -# pylint: enable=E0611 -from posttroll import get_context -from posttroll import _set_tcp_keepalive -from posttroll.message import _MAGICK, Message +from posttroll import config +from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber +from posttroll.message import _MAGICK from posttroll.ns import get_pub_address LOGGER = logging.getLogger(__name__) @@ -69,79 +63,28 @@ class Subscriber: def __init__(self, addresses, topics='', message_filter=None, translate=False): """Initialize the subscriber.""" - self._topics = self._magickfy_topics(topics) - self._filter = message_filter - self._translate = translate - - self.sub_addr = {} - self.addr_sub = {} - - self._hooks = [] - self._hooks_cb = {} - - self.poller = Poller() - self._lock = Lock() - - self.update(addresses) - - self._loop = None + topics = self._magickfy_topics(topics) + backend = config.get("backend", "unsecure_zmq") + if backend == "unsecure_zmq": + self._subscriber = UnsecureZMQSubscriber(addresses, topics=topics, + message_filter=message_filter, translate=translate) + else: + raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. - It topics is None we will subscibe to already specified topics. + It topics is None we will subscribe to already specified topics. """ - with self._lock: - if address in self.addresses: - return - - topics = self._magickfy_topics(topics) or self._topics - LOGGER.info("Subscriber adding address %s with topics %s", - str(address), str(topics)) - subscriber = self._add_sub_socket(address, topics) - self.sub_addr[subscriber] = address - self.addr_sub[address] = subscriber - - def _add_sub_socket(self, address, topics): - subscriber = get_context().socket(SUB) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber + return self._subscriber.add(address, self._magickfy_topics(topics)) def remove(self, address): """Remove *address* from the subscribing list for *topics*.""" - with self._lock: - try: - subscriber = self.addr_sub[address] - except KeyError: - return - LOGGER.info("Subscriber removing address %s", str(address)) - del self.addr_sub[address] - del self.sub_addr[subscriber] - self._remove_sub_socket(subscriber) - - def _remove_sub_socket(self, subscriber): - if self.poller: - self.poller.unregister(subscriber) - subscriber.close() + return self._subscriber.remove(address) def update(self, addresses): """Update with a set of addresses.""" - if isinstance(addresses, str): - addresses = [addresses, ] - current_addresses, new_addresses = set(self.addresses), set(addresses) - addresses_to_remove = current_addresses.difference(new_addresses) - addresses_to_add = new_addresses.difference(current_addresses) - for addr in addresses_to_remove: - self.remove(addr) - for addr in addresses_to_add: - self.add(addr) - return bool(addresses_to_remove or addresses_to_add) + return self._subscriber.update(addresses) def add_hook_sub(self, address, topics, callback): """Specify a SUB *callback* in the same stream (thread) as the main receive loop. @@ -152,11 +95,7 @@ def add_hook_sub(self, address, topics, callback): Good for operations, which is required to be done in the same thread as the main recieve loop (e.q operations on the underlying sockets). """ - topics = self._magickfy_topics(topics) - LOGGER.info("Subscriber adding SUB hook %s for topics %s", - str(address), str(topics)) - socket = self._add_sub_socket(address, topics) - self._add_hook(socket, callback) + return self._subscriber.add_hook_sub(address, self._magickfy_topics(topics), callback) def add_hook_pull(self, address, callback): """Specify a PULL *callback* in the same stream (thread) as the main receive loop. @@ -164,85 +103,33 @@ def add_hook_pull(self, address, callback): The callback will be called with the received messages from the specified subscription. Good for pushed 'inproc' messages from another thread. """ - LOGGER.info("Subscriber adding PULL hook %s", str(address)) - socket = get_context().socket(PULL) - socket.connect(address) - if self.poller: - self.poller.register(socket, POLLIN) - self._add_hook(socket, callback) - - def _add_hook(self, socket, callback): - """Add a generic hook. The passed socket has to be "receive only".""" - self._hooks.append(socket) - self._hooks_cb[socket] = callback - + return self._subscriber.add_hook_pull(address, callback) @property def addresses(self): """Get the addresses.""" - return self.sub_addr.values() + return self._subscriber.addresses @property def subscribers(self): """Get the subscribers.""" - return self.sub_addr.keys() + return self._subscriber.subscribers def recv(self, timeout=None): """Receive, optionally with *timeout* in seconds.""" - if timeout: - timeout *= 1000. - - for sub in list(self.subscribers) + self._hooks: - self.poller.register(sub, POLLIN) - self._loop = True - try: - while self._loop: - sleep(0) - try: - socks = dict(self.poller.poll(timeout=timeout)) - if socks: - for sub in self.subscribers: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - if not self._filter or self._filter(m__): - if self._translate: - url = urlsplit(self.sub_addr[sub]) - host = url[1].split(":")[0] - m__.sender = (m__.sender.split("@")[0] - + "@" + host) - yield m__ - - for sub in self._hooks: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - self._hooks_cb[sub](m__) - else: - # timeout - yield None - except ZMQError as err: - if self._loop: - LOGGER.exception("Receive failed: %s", str(err)) - finally: - for sub in list(self.subscribers) + self._hooks: - self.poller.unregister(sub) + return self._subscriber.recv(timeout) def __call__(self, **kwargs): """Handle calls with class instance.""" - return self.recv(**kwargs) + return self._subscriber(**kwargs) def stop(self): """Stop the subscriber.""" - self._loop = False + return self._subscriber.stop() def close(self): """Close the subscriber: stop it and close the local subscribers.""" - self.stop() - for sub in list(self.subscribers) + self._hooks: - try: - sub.setsockopt(LINGER, 1) - sub.close() - except ZMQError: - pass + return self._subscriber.close() @staticmethod def _magickfy_topics(topics): @@ -263,15 +150,6 @@ def _magickfy_topics(topics): ts_.append(t__) return ts_ - def __del__(self): - """Clean up after the instance is deleted.""" - for sub in list(self.subscribers) + self._hooks: - try: - sub.close() - except Exception: # noqa: E722 - pass - - class NSSubscriber: """Automatically subscribe to *services*. diff --git a/posttroll/testing.py b/posttroll/testing.py index 501aa53..2563d11 100644 --- a/posttroll/testing.py +++ b/posttroll/testing.py @@ -10,7 +10,7 @@ def patched_subscriber_recv(messages): @contextmanager def patched_publisher(): - """Patch the Subscriber object to return given messages.""" + """Patch the Publisher object to return given messages.""" from unittest import mock published = [] diff --git a/posttroll/tests/__init__.py b/posttroll/tests/__init__.py index 88e1689..5e53b2b 100644 --- a/posttroll/tests/__init__.py +++ b/posttroll/tests/__init__.py @@ -20,25 +20,4 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Tests package. -""" - -from posttroll.tests import test_bbmcast, test_message, test_pubsub -import unittest - - -def suite(): - """The global test suite. - """ - mysuite = unittest.TestSuite() - # Test the documentation strings - # Use the unittests also - mysuite.addTests(test_bbmcast.suite()) - mysuite.addTests(test_message.suite()) - mysuite.addTests(test_pubsub.suite()) - - return mysuite - - -def load_tests(loader, tests, pattern): - return suite() +"""Tests package.""" diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index a7bea8d..45b19b6 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -107,13 +107,3 @@ def test_mcast_receiver(self): str(random.randint(0, 255)) + "." + str(random.randint(0, 255))) self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup) - - -def suite(): - """The suite for test_bbmcast. - """ - loader = unittest.TestLoader() - mysuite = unittest.TestSuite() - mysuite.addTest(loader.loadTestsFromTestCase(TestBB)) - - return mysuite diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 2da79b9..ec1642c 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -31,192 +31,193 @@ from contextlib import contextmanager import posttroll +from posttroll.ns import NameServer +from posttroll.publisher import create_publisher_from_dict_config +from posttroll.subscriber import Subscribe, Subscriber, create_subscriber_from_dict_config import pytest from donfig import Config test_lock = Lock() -class TestNS(unittest.TestCase): - """Test the nameserver.""" - - def setUp(self): - """Set up the testing class.""" - from posttroll.ns import NameServer - test_lock.acquire() - self.ns = NameServer(max_age=timedelta(seconds=3)) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() - - def test_pub_addresses(self): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish - - with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): - time.sleep(.3) - res = get_pub_addresses(["this_data"], timeout=.5) - assert len(res) == 1 - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses([str("data_provider")]) - assert len(res) == 1 - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - - def test_pub_sub_ctx(self): - """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - with Publish("data_provider", 0, ["this_data"]) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(1) - msg = next(sub.recv(2)) - if msg is not None: - assert str(msg) == str(message) - tested = True - sub.close() - assert tested - - def test_pub_sub_add_rm(self): - """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - time.sleep(4) - with Subscribe("this_data", "counter", True) as sub: - assert len(sub.sub_addr) == 0 - with Publish("data_provider", 0, ["this_data"]): - time.sleep(4) - next(sub.recv(2)) - assert len(sub.sub_addr) == 1 - time.sleep(3) - for msg in sub.recv(2): - if msg is None: - break - time.sleep(3) - assert len(sub.sub_addr) == 0 - with Publish("data_provider_2", 0, ["another_data"]): - time.sleep(4) - next(sub.recv(2)) - assert len(sub.sub_addr) == 0 - sub.close() - - -class TestNSWithoutMulticasting(unittest.TestCase): - """Test the nameserver.""" - - def setUp(self): - """Set up the testing class.""" - from posttroll.ns import NameServer - test_lock.acquire() - self.nameservers = ['localhost'] - self.ns = NameServer(max_age=timedelta(seconds=3), - multicast_enabled=False) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() - - def test_pub_addresses(self): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish - - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): - time.sleep(3) - res = get_pub_addresses(["this_data"]) - self.assertEqual(len(res), 1) - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) - res = get_pub_addresses(["data_provider"]) - self.assertEqual(len(res), 1) - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) - - def test_pub_sub_ctx(self): - """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(1) - msg = next(sub.recv(2)) - if msg is not None: - self.assertEqual(str(msg), str(message)) - tested = True - sub.close() - self.assertTrue(tested) - - def test_pub_sub_add_rm(self): - """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - time.sleep(4) - with Subscribe("this_data", "counter", True) as sub: - self.assertEqual(len(sub.sub_addr), 0) - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): - time.sleep(4) - next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 1) - time.sleep(3) - for msg in sub.recv(2): - if msg is None: - break - - time.sleep(3) - self.assertEqual(len(sub.sub_addr), 0) - with Publish("data_provider_2", 0, ["another_data"], - nameservers=self.nameservers): - time.sleep(4) - next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 0) +# class TestNS(unittest.TestCase): +# """Test the nameserver.""" + +# def setUp(self): +# """Set up the testing class.""" +# test_lock.acquire() +# self.ns = NameServer(max_age=timedelta(seconds=3)) +# self.thr = Thread(target=self.ns.run) +# self.thr.start() + +# def tearDown(self): +# """Clean up after the tests have run.""" +# self.ns.stop() +# self.thr.join() +# time.sleep(2) +# test_lock.release() + +# def test_pub_addresses(self): +# """Test retrieving addresses.""" +# from posttroll.ns import get_pub_addresses +# from posttroll.publisher import Publish + +# with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): +# time.sleep(.3) +# res = get_pub_addresses(["this_data"], timeout=.5) +# assert len(res) == 1 +# expected = {u'status': True, +# u'service': [u'data_provider', u'this_data'], +# u'name': u'address'} +# for key, val in expected.items(): +# assert res[0][key] == val +# assert "receive_time" in res[0] +# assert "URI" in res[0] +# res = get_pub_addresses([str("data_provider")]) +# assert len(res) == 1 +# expected = {u'status': True, +# u'service': [u'data_provider', u'this_data'], +# u'name': u'address'} +# for key, val in expected.items(): +# assert res[0][key] == val +# assert "receive_time" in res[0] +# assert "URI" in res[0] + +# def test_pub_sub_ctx(self): +# """Test publish and subscribe.""" +# from posttroll.message import Message +# from posttroll.publisher import Publish +# from posttroll.subscriber import Subscribe + +# with Publish("data_provider", 0, ["this_data"]) as pub: +# with Subscribe("this_data", "counter") as sub: +# for counter in range(5): +# message = Message("/counter", "info", str(counter)) +# pub.send(str(message)) +# time.sleep(1) +# msg = next(sub.recv(2)) +# if msg is not None: +# assert str(msg) == str(message) +# tested = True +# sub.close() +# assert tested + +# def test_pub_sub_add_rm(self): +# """Test adding and removing publishers.""" +# from posttroll.publisher import Publish +# from posttroll.subscriber import Subscribe + +# time.sleep(4) +# with Subscribe("this_data", "counter", True) as sub: +# assert len(sub.sub_addr) == 0 +# with Publish("data_provider", 0, ["this_data"]): +# time.sleep(4) +# next(sub.recv(2)) +# assert len(sub.sub_addr) == 1 +# time.sleep(3) +# for msg in sub.recv(2): +# if msg is None: +# break +# time.sleep(3) +# assert len(sub.sub_addr) == 0 +# with Publish("data_provider_2", 0, ["another_data"]): +# time.sleep(4) +# next(sub.recv(2)) +# assert len(sub.sub_addr) == 0 +# sub.close() + + +# class TestNSWithoutMulticasting(unittest.TestCase): +# """Test the nameserver.""" + +# def setUp(self): +# """Set up the testing class.""" +# test_lock.acquire() +# self.nameservers = ['localhost'] +# self.ns = NameServer(max_age=timedelta(seconds=3), +# multicast_enabled=False) +# self.thr = Thread(target=self.ns.run) +# self.thr.start() + +# def tearDown(self): +# """Clean up after the tests have run.""" +# self.ns.stop() +# self.thr.join() +# time.sleep(2) +# test_lock.release() + +# def test_pub_addresses(self): +# """Test retrieving addresses.""" +# from posttroll.ns import get_pub_addresses +# from posttroll.publisher import Publish + +# with Publish("data_provider", 0, ["this_data"], +# nameservers=self.nameservers): +# time.sleep(3) +# res = get_pub_addresses(["this_data"]) +# self.assertEqual(len(res), 1) +# expected = {u'status': True, +# u'service': [u'data_provider', u'this_data'], +# u'name': u'address'} +# for key, val in expected.items(): +# self.assertEqual(res[0][key], val) +# self.assertTrue("receive_time" in res[0]) +# self.assertTrue("URI" in res[0]) +# res = get_pub_addresses(["data_provider"]) +# self.assertEqual(len(res), 1) +# expected = {u'status': True, +# u'service': [u'data_provider', u'this_data'], +# u'name': u'address'} +# for key, val in expected.items(): +# self.assertEqual(res[0][key], val) +# self.assertTrue("receive_time" in res[0]) +# self.assertTrue("URI" in res[0]) + +# def test_pub_sub_ctx(self): +# """Test publish and subscribe.""" +# from posttroll.message import Message +# from posttroll.publisher import Publish +# from posttroll.subscriber import Subscribe + +# with Publish("data_provider", 0, ["this_data"], +# nameservers=self.nameservers) as pub: +# with Subscribe("this_data", "counter") as sub: +# for counter in range(5): +# message = Message("/counter", "info", str(counter)) +# pub.send(str(message)) +# time.sleep(1) +# msg = next(sub.recv(2)) +# if msg is not None: +# self.assertEqual(str(msg), str(message)) +# tested = True +# sub.close() +# self.assertTrue(tested) + +# def test_pub_sub_add_rm(self): +# """Test adding and removing publishers.""" +# from posttroll.publisher import Publish +# from posttroll.subscriber import Subscribe + +# time.sleep(4) +# with Subscribe("this_data", "counter", True) as sub: +# self.assertEqual(len(sub.sub_addr), 0) +# with Publish("data_provider", 0, ["this_data"], +# nameservers=self.nameservers): +# time.sleep(4) +# next(sub.recv(2)) +# self.assertEqual(len(sub.sub_addr), 1) +# time.sleep(3) +# for msg in sub.recv(2): +# if msg is None: +# break + +# time.sleep(3) +# self.assertEqual(len(sub.sub_addr), 0) +# with Publish("data_provider_2", 0, ["another_data"], +# nameservers=self.nameservers): +# time.sleep(4) +# next(sub.recv(2)) +# self.assertEqual(len(sub.sub_addr), 0) class TestPubSub(unittest.TestCase): @@ -234,9 +235,8 @@ def test_pub_address_timeout(self): """Test timeout in offline nameserver.""" from posttroll.ns import get_pub_address from posttroll.ns import TimeoutError - - self.assertRaises(TimeoutError, - get_pub_address, ["this_data", 0.5]) + with pytest.raises(TimeoutError): + get_pub_address("this_data", 0.05) def test_pub_suber(self): """Test publisher and subscriber.""" @@ -249,11 +249,14 @@ def test_pub_suber(self): pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) sub = Subscriber([addr], '/counter') + # wait a bit before sending the first message so that the subscriber is ready + time.sleep(.002) + tested = False for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) - time.sleep(1) + time.sleep(.05) msg = next(sub.recv(2)) if msg is not None: @@ -266,20 +269,22 @@ def test_pub_sub_ctx_no_nameserver(self): """Test publish and subscribe.""" from posttroll.message import Message from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe with Publish("data_provider", 40000, nameservers=False) as pub: with Subscribe(topics="counter", nameserver=False, addresses=["tcp://127.0.0.1:40000"]) as sub: + assert isinstance(sub, Subscriber) + # wait a bit before sending the first message so that the subscriber is ready + time.sleep(.002) for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) - time.sleep(1) + time.sleep(.05) msg = next(sub.recv(2)) if msg is not None: self.assertEqual(str(msg), str(message)) tested = True sub.close() - self.assertTrue(tested) + assert tested class TestPub(unittest.TestCase): @@ -293,7 +298,7 @@ def tearDown(self): """Clean up after the tests have run.""" test_lock.release() - def test_pub_unicode(self): + def test_pub_supports_unicode(self): """Test publishing messages in Unicode.""" from posttroll.message import Message from posttroll.publisher import Publish @@ -305,16 +310,14 @@ def test_pub_unicode(self): except UnicodeDecodeError: self.fail("Sending raised UnicodeDecodeError unexpectedly!") - def test_pub_minmax_port(self): - """Test user defined port range.""" - import os - + def test_pub_minmax_port_from_config(self): + """Test config defined port range.""" # Using environment variables to set port range # Try over a range of ports just in case the single port is reserved for port in range(40000, 50000): # Set the port range to config with posttroll.config.set(pub_min_port=str(port), pub_max_port=str(port + 1)): - res = _get_port(min_port=None, max_port=None) + res = _get_port_from_publish_instance(min_port=None, max_port=None) if res is False: # The port wasn't free, try another one continue @@ -322,10 +325,12 @@ def test_pub_minmax_port(self): self.assertEqual(res, port) break + def test_pub_minmax_port_from_instanciation(self): + """Test port range defined at instanciation.""" # Using range of ports defined at instantation time, this # should override environment variables for port in range(50000, 60000): - res = _get_port(min_port=port, max_port=port+1) + res = _get_port_from_publish_instance(min_port=port, max_port=port+1) if res is False: # The port wasn't free, try again continue @@ -334,7 +339,7 @@ def test_pub_minmax_port(self): break -def _get_port(min_port=None, max_port=None): +def _get_port_from_publish_instance(min_port=None, max_port=None): from zmq.error import ZMQError from posttroll.publisher import Publish @@ -364,7 +369,6 @@ def tearDown(self): """Clean up after the tests have run.""" self.ns.stop() self.thr.join() - time.sleep(2) test_lock.release() def test_listener_container(self): @@ -373,10 +377,10 @@ def test_listener_container(self): from posttroll.publisher import NoisyPublisher from posttroll.listener import ListenerContainer - pub = NoisyPublisher("test") + pub = NoisyPublisher("test", broadcast_interval=0.1) pub.start() sub = ListenerContainer(topics=["/counter"]) - time.sleep(2) + time.sleep(.1) for counter in range(5): tested = False msg_out = Message("/counter", "info", str(counter)) @@ -384,7 +388,7 @@ def test_listener_container(self): msg_in = sub.output_queue.get(True, 1) if msg_in is not None: - self.assertEqual(str(msg_in), str(msg_out)) + assert str(msg_in) == str(msg_out) tested = True self.assertTrue(tested) pub.stop() @@ -450,22 +454,19 @@ def test_localhost_restriction(self, mcrec, pub, msg): class TestPublisherDictConfig(unittest.TestCase): """Test configuring publishers with a dictionary.""" - @mock.patch('posttroll.publisher.Publisher') - def test_publisher_is_selected(self, Publisher): + def test_publisher_is_selected(self): """Test that Publisher is selected as publisher class.""" - from posttroll.publisher import create_publisher_from_dict_config + from posttroll.publisher import Publisher settings = {'port': 12345, 'nameservers': False} pub = create_publisher_from_dict_config(settings) - Publisher.assert_called_once() + assert isinstance(pub, Publisher) assert pub is not None @mock.patch('posttroll.publisher.Publisher') def test_publisher_all_arguments(self, Publisher): """Test that only valid arguments are passed to Publisher.""" - from posttroll.publisher import create_publisher_from_dict_config - settings = {'port': 12345, 'nameservers': False, 'name': 'foo', 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar'} _ = create_publisher_from_dict_config(settings) @@ -475,31 +476,26 @@ def test_publisher_all_arguments(self, Publisher): def test_no_name_raises_keyerror(self): """Trying to create a NoisyPublisher without a given name will raise KeyError.""" - from posttroll.publisher import create_publisher_from_dict_config - - with self.assertRaises(KeyError): + with pytest.raises(KeyError): _ = create_publisher_from_dict_config(dict()) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_noisypublisher_is_selected_only_name(self, NoisyPublisher): + def test_noisypublisher_is_selected_only_name(self): """Test that NoisyPublisher is selected as publisher class.""" - from posttroll.publisher import create_publisher_from_dict_config + from posttroll.publisher import NoisyPublisher settings = {'name': 'publisher_name'} pub = create_publisher_from_dict_config(settings) - NoisyPublisher.assert_called_once() - assert pub is not None + assert isinstance(pub, NoisyPublisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_noisypublisher_is_selected_name_and_port(self, NoisyPublisher): + def test_noisypublisher_is_selected_name_and_port(self): """Test that NoisyPublisher is selected as publisher class.""" - from posttroll.publisher import create_publisher_from_dict_config + from posttroll.publisher import NoisyPublisher settings = {'name': 'publisher_name', 'port': 40000} - _ = create_publisher_from_dict_config(settings) - NoisyPublisher.assert_called_once() + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, NoisyPublisher) @mock.patch('posttroll.publisher.NoisyPublisher') def test_noisypublisher_all_arguments(self, NoisyPublisher): @@ -513,37 +509,33 @@ def test_noisypublisher_all_arguments(self, NoisyPublisher): _check_valid_settings_in_call(settings, NoisyPublisher, ignore=['name']) assert NoisyPublisher.call_args[0][0] == settings["name"] - @mock.patch('posttroll.publisher.Publisher') - def test_publish_is_not_noisy(self, Publisher): + def test_publish_is_not_noisy(self): """Test that Publisher is selected with the context manager when it should be.""" - from posttroll.publisher import Publish + from posttroll.publisher import Publish, Publisher - with Publish("service_name", port=40000, nameservers=False): - Publisher.assert_called_once() + with Publish("service_name", port=40000, nameservers=False) as pub: + assert isinstance(pub, Publisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_publish_is_noisy_only_name(self, NoisyPublisher): + def test_publish_is_noisy_only_name(self): """Test that NoisyPublisher is selected with the context manager when only name is given.""" - from posttroll.publisher import Publish + from posttroll.publisher import Publish, NoisyPublisher - with Publish("service_name"): - NoisyPublisher.assert_called_once() + with Publish("service_name") as pub: + assert isinstance(pub, NoisyPublisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_publish_is_noisy_with_port(self, NoisyPublisher): + def test_publish_is_noisy_with_port(self): """Test that NoisyPublisher is selected with the context manager when port is given.""" - from posttroll.publisher import Publish + from posttroll.publisher import Publish, NoisyPublisher - with Publish("service_name", port=40000): - NoisyPublisher.assert_called_once() + with Publish("service_name", port=40001) as pub: + assert isinstance(pub, NoisyPublisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_publish_is_noisy_with_nameservers(self, NoisyPublisher): + def test_publish_is_noisy_with_nameservers(self): """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" - from posttroll.publisher import Publish + from posttroll.publisher import Publish, NoisyPublisher - with Publish("service_name", nameservers=['a', 'b']): - NoisyPublisher.assert_called_once() + with Publish("service_name", nameservers=['a', 'b']) as pub: + assert isinstance(pub, NoisyPublisher) def _check_valid_settings_in_call(settings, pub_class, ignore=None): @@ -632,7 +624,7 @@ def test_dict_config_full_subscriber(Subscriber_update): @pytest.fixture -def tcp_keepalive_settings(monkeypatch): +def oldtcp_keepalive_settings(monkeypatch): """Set TCP Keepalive settings.""" monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE", "1") monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_CNT", "10") @@ -641,6 +633,12 @@ def tcp_keepalive_settings(monkeypatch): with reset_config_for_tests(): yield +@pytest.fixture +def tcp_keepalive_settings(monkeypatch): + """Set TCP Keepalive settings.""" + from posttroll import config + with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1): + yield @contextmanager def reset_config_for_tests(): @@ -652,71 +650,124 @@ def reset_config_for_tests(): @pytest.fixture -def tcp_keepalive_no_settings(monkeypatch): +def tcp_keepalive_no_settings(): """Set TCP Keepalive settings.""" - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE", raising=False) - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_CNT", raising=False) - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_IDLE", raising=False) - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_INTVL", raising=False) - with reset_config_for_tests(): + from posttroll import config + with config.set(tcp_keepalive=None, tcp_keepalive_cnt=None, tcp_keepalive_idle=None, tcp_keepalive_intvl=None): yield def test_publisher_tcp_keepalive(tcp_keepalive_settings): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" - socket = mock.MagicMock() - with mock.patch('posttroll.publisher.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.publisher import Publisher - - _ = Publisher("tcp://127.0.0.1:9000").start() - - _assert_tcp_keepalive(socket) + from posttroll.backends.zmq.publisher import UnsecureZMQPublisher + pub = UnsecureZMQPublisher("tcp://127.0.0.1:9000").start() + _assert_tcp_keepalive(pub.publish_socket) def test_publisher_tcp_keepalive_not_set(tcp_keepalive_no_settings): """Test that TCP Keepalive is not set on by default.""" - socket = mock.MagicMock() - with mock.patch('posttroll.publisher.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.publisher import Publisher - - _ = Publisher("tcp://127.0.0.1:9000").start() - _assert_no_tcp_keepalive(socket) + from posttroll.backends.zmq.publisher import UnsecureZMQPublisher + pub = UnsecureZMQPublisher("tcp://127.0.0.1:9000").start() + _assert_no_tcp_keepalive(pub.publish_socket) def test_subscriber_tcp_keepalive(tcp_keepalive_settings): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" - socket = mock.MagicMock() - with mock.patch('posttroll.subscriber.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.subscriber import Subscriber + from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber + + sub = UnsecureZMQSubscriber("tcp://127.0.0.1:9000") - _ = Subscriber("tcp://127.0.0.1:9000") + assert len(sub.addr_sub.values()) == 1 - _assert_tcp_keepalive(socket) + _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) def test_subscriber_tcp_keepalive_not_set(tcp_keepalive_no_settings): """Test that TCP Keepalive is not set on by default.""" - socket = mock.MagicMock() - with mock.patch('posttroll.subscriber.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.subscriber import Subscriber + from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - _ = Subscriber("tcp://127.0.0.1:9000") + sub = UnsecureZMQSubscriber("tcp://127.0.0.1:9000") - _assert_no_tcp_keepalive(socket) + assert len(sub.addr_sub.values()) == 1 + + _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0]) def _assert_tcp_keepalive(socket): import zmq - assert mock.call(zmq.TCP_KEEPALIVE, 1) in socket.setsockopt.mock_calls - assert mock.call(zmq.TCP_KEEPALIVE_CNT, 10) in socket.setsockopt.mock_calls - assert mock.call(zmq.TCP_KEEPALIVE_IDLE, 1) in socket.setsockopt.mock_calls - assert mock.call(zmq.TCP_KEEPALIVE_INTVL, 1) in socket.setsockopt.mock_calls + assert socket.getsockopt(zmq.TCP_KEEPALIVE) == 1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == 10 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_IDLE) == 1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == 1 def _assert_no_tcp_keepalive(socket): - assert "TCP_KEEPALIVE" not in str(socket.setsockopt.mock_calls) + import zmq + + assert socket.getsockopt(zmq.TCP_KEEPALIVE) == -1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == -1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_IDLE) == -1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1 + + +def test_ipc_pubsub(): + from posttroll import config + with config.set(backend="unsecure_zmq"): + subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) + sub = create_subscriber_from_dict_config(subscriber_settings) + from posttroll.publisher import Publisher + pub = Publisher("ipc://bla.ipc") + pub.start() + def delayed_send(msg): + time.sleep(.2) + from posttroll.message import Message + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + pub.stop() + from threading import Thread + Thread(target=delayed_send, args=["hi"]).start() + for msg in sub.recv(): + assert msg.data == "hi" + break + sub.stop() + +def test_ipc_pubsub_with_sec(): + from posttroll import config + with config.set(backend="secure_zmq"): + subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) + sub = create_subscriber_from_dict_config(subscriber_settings) + from posttroll.publisher import Publisher + pub = Publisher("ipc://bla.ipc", secure=True) + pub.start() + def delayed_send(msg): + time.sleep(.2) + from posttroll.message import Message + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + pub.stop() + from threading import Thread + Thread(target=delayed_send, args=["hi"]).start() + for msg in sub.recv(): + assert msg.data == "hi" + break + sub.stop() + +def test_switch_to_unknown_backend(): + from posttroll import config + from posttroll.publisher import Publisher + from posttroll.subscriber import Subscriber + with config.set(backend="unsecure_and_deprecated"): + with pytest.raises(NotImplementedError): + Publisher("ipc://bla.ipc") + with pytest.raises(NotImplementedError): + Subscriber("ipc://bla.ipc") + +def test_switch_to_secure_zmq_backend(): + from posttroll import config + from posttroll.publisher import Publisher + from posttroll.subscriber import Subscriber + + with config.set(backend="secure_zmq"): + Publisher("ipc://bla.ipc") + Subscriber("ipc://bla.ipc") From 49b74d3ae07f00b96eb7ae35fa2ac7b8ef700796 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 29 Nov 2023 21:04:07 +0100 Subject: [PATCH 02/45] Fix style --- posttroll/backends/zmq/subscriber.py | 2 +- posttroll/message_broadcaster.py | 1 - posttroll/tests/test_pubsub.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 980a2b0..46f3a3d 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -1,6 +1,6 @@ from threading import Lock from urllib.parse import urlsplit -from posttroll.message import _MAGICK, Message +from posttroll.message import Message from zmq import Poller, SUB, SUBSCRIBE, POLLIN, PULL, ZMQError, NOBLOCK, LINGER from time import sleep import logging diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 1912c3b..55eca43 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -21,7 +21,6 @@ # You should have received a copy of the GNU General Public License along with # pytroll. If not, see . -import time import threading import logging import errno diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index ec1642c..48e3528 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -359,7 +359,6 @@ class TestListenerContainer(unittest.TestCase): def setUp(self): """Set up the testing class.""" - from posttroll.ns import NameServer test_lock.acquire() self.ns = NameServer(max_age=timedelta(seconds=3)) self.thr = Thread(target=self.ns.run) From f15cda3acba635ae278f5a98f4741056621203e0 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 29 Nov 2023 21:04:34 +0100 Subject: [PATCH 03/45] Use python >= 3.10 for ci --- .github/workflows/ci.yaml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4033aca..2fbaede 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: true matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12"] experimental: [false] steps: - name: Checkout source diff --git a/setup.py b/setup.py index e80d8d6..750dca2 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,6 @@ 'Topic :: Scientific/Engineering', 'Topic :: Communications' ], - python_requires='>=3.7', + python_requires='>=3.10', test_suite='posttroll.tests.suite', ) From ae3c70bc1bbebe6f88f63776ab8d7fcbe90294ae Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 29 Nov 2023 21:14:28 +0100 Subject: [PATCH 04/45] Fix test --- posttroll/tests/test_pubsub.py | 369 ++++++++++++++++----------------- 1 file changed, 179 insertions(+), 190 deletions(-) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 48e3528..72bb357 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -40,184 +40,184 @@ test_lock = Lock() -# class TestNS(unittest.TestCase): -# """Test the nameserver.""" - -# def setUp(self): -# """Set up the testing class.""" -# test_lock.acquire() -# self.ns = NameServer(max_age=timedelta(seconds=3)) -# self.thr = Thread(target=self.ns.run) -# self.thr.start() - -# def tearDown(self): -# """Clean up after the tests have run.""" -# self.ns.stop() -# self.thr.join() -# time.sleep(2) -# test_lock.release() - -# def test_pub_addresses(self): -# """Test retrieving addresses.""" -# from posttroll.ns import get_pub_addresses -# from posttroll.publisher import Publish - -# with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): -# time.sleep(.3) -# res = get_pub_addresses(["this_data"], timeout=.5) -# assert len(res) == 1 -# expected = {u'status': True, -# u'service': [u'data_provider', u'this_data'], -# u'name': u'address'} -# for key, val in expected.items(): -# assert res[0][key] == val -# assert "receive_time" in res[0] -# assert "URI" in res[0] -# res = get_pub_addresses([str("data_provider")]) -# assert len(res) == 1 -# expected = {u'status': True, -# u'service': [u'data_provider', u'this_data'], -# u'name': u'address'} -# for key, val in expected.items(): -# assert res[0][key] == val -# assert "receive_time" in res[0] -# assert "URI" in res[0] - -# def test_pub_sub_ctx(self): -# """Test publish and subscribe.""" -# from posttroll.message import Message -# from posttroll.publisher import Publish -# from posttroll.subscriber import Subscribe - -# with Publish("data_provider", 0, ["this_data"]) as pub: -# with Subscribe("this_data", "counter") as sub: -# for counter in range(5): -# message = Message("/counter", "info", str(counter)) -# pub.send(str(message)) -# time.sleep(1) -# msg = next(sub.recv(2)) -# if msg is not None: -# assert str(msg) == str(message) -# tested = True -# sub.close() -# assert tested - -# def test_pub_sub_add_rm(self): -# """Test adding and removing publishers.""" -# from posttroll.publisher import Publish -# from posttroll.subscriber import Subscribe - -# time.sleep(4) -# with Subscribe("this_data", "counter", True) as sub: -# assert len(sub.sub_addr) == 0 -# with Publish("data_provider", 0, ["this_data"]): -# time.sleep(4) -# next(sub.recv(2)) -# assert len(sub.sub_addr) == 1 -# time.sleep(3) -# for msg in sub.recv(2): -# if msg is None: -# break -# time.sleep(3) -# assert len(sub.sub_addr) == 0 -# with Publish("data_provider_2", 0, ["another_data"]): -# time.sleep(4) -# next(sub.recv(2)) -# assert len(sub.sub_addr) == 0 -# sub.close() - - -# class TestNSWithoutMulticasting(unittest.TestCase): -# """Test the nameserver.""" - -# def setUp(self): -# """Set up the testing class.""" -# test_lock.acquire() -# self.nameservers = ['localhost'] -# self.ns = NameServer(max_age=timedelta(seconds=3), -# multicast_enabled=False) -# self.thr = Thread(target=self.ns.run) -# self.thr.start() - -# def tearDown(self): -# """Clean up after the tests have run.""" -# self.ns.stop() -# self.thr.join() -# time.sleep(2) -# test_lock.release() - -# def test_pub_addresses(self): -# """Test retrieving addresses.""" -# from posttroll.ns import get_pub_addresses -# from posttroll.publisher import Publish - -# with Publish("data_provider", 0, ["this_data"], -# nameservers=self.nameservers): -# time.sleep(3) -# res = get_pub_addresses(["this_data"]) -# self.assertEqual(len(res), 1) -# expected = {u'status': True, -# u'service': [u'data_provider', u'this_data'], -# u'name': u'address'} -# for key, val in expected.items(): -# self.assertEqual(res[0][key], val) -# self.assertTrue("receive_time" in res[0]) -# self.assertTrue("URI" in res[0]) -# res = get_pub_addresses(["data_provider"]) -# self.assertEqual(len(res), 1) -# expected = {u'status': True, -# u'service': [u'data_provider', u'this_data'], -# u'name': u'address'} -# for key, val in expected.items(): -# self.assertEqual(res[0][key], val) -# self.assertTrue("receive_time" in res[0]) -# self.assertTrue("URI" in res[0]) - -# def test_pub_sub_ctx(self): -# """Test publish and subscribe.""" -# from posttroll.message import Message -# from posttroll.publisher import Publish -# from posttroll.subscriber import Subscribe - -# with Publish("data_provider", 0, ["this_data"], -# nameservers=self.nameservers) as pub: -# with Subscribe("this_data", "counter") as sub: -# for counter in range(5): -# message = Message("/counter", "info", str(counter)) -# pub.send(str(message)) -# time.sleep(1) -# msg = next(sub.recv(2)) -# if msg is not None: -# self.assertEqual(str(msg), str(message)) -# tested = True -# sub.close() -# self.assertTrue(tested) - -# def test_pub_sub_add_rm(self): -# """Test adding and removing publishers.""" -# from posttroll.publisher import Publish -# from posttroll.subscriber import Subscribe - -# time.sleep(4) -# with Subscribe("this_data", "counter", True) as sub: -# self.assertEqual(len(sub.sub_addr), 0) -# with Publish("data_provider", 0, ["this_data"], -# nameservers=self.nameservers): -# time.sleep(4) -# next(sub.recv(2)) -# self.assertEqual(len(sub.sub_addr), 1) -# time.sleep(3) -# for msg in sub.recv(2): -# if msg is None: -# break - -# time.sleep(3) -# self.assertEqual(len(sub.sub_addr), 0) -# with Publish("data_provider_2", 0, ["another_data"], -# nameservers=self.nameservers): -# time.sleep(4) -# next(sub.recv(2)) -# self.assertEqual(len(sub.sub_addr), 0) +class TestNS(unittest.TestCase): + """Test the nameserver.""" + + def setUp(self): + """Set up the testing class.""" + test_lock.acquire() + self.ns = NameServer(max_age=timedelta(seconds=3)) + self.thr = Thread(target=self.ns.run) + self.thr.start() + + def tearDown(self): + """Clean up after the tests have run.""" + self.ns.stop() + self.thr.join() + time.sleep(2) + test_lock.release() + + def test_pub_addresses(self): + """Test retrieving addresses.""" + from posttroll.ns import get_pub_addresses + from posttroll.publisher import Publish + + with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): + time.sleep(.3) + res = get_pub_addresses(["this_data"], timeout=.5) + assert len(res) == 1 + expected = {u'status': True, + u'service': [u'data_provider', u'this_data'], + u'name': u'address'} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + res = get_pub_addresses([str("data_provider")]) + assert len(res) == 1 + expected = {u'status': True, + u'service': [u'data_provider', u'this_data'], + u'name': u'address'} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + + def test_pub_sub_ctx(self): + """Test publish and subscribe.""" + from posttroll.message import Message + from posttroll.publisher import Publish + from posttroll.subscriber import Subscribe + + with Publish("data_provider", 0, ["this_data"]) as pub: + with Subscribe("this_data", "counter") as sub: + for counter in range(5): + message = Message("/counter", "info", str(counter)) + pub.send(str(message)) + time.sleep(1) + msg = next(sub.recv(2)) + if msg is not None: + assert str(msg) == str(message) + tested = True + sub.close() + assert tested + + def test_pub_sub_add_rm(self): + """Test adding and removing publishers.""" + from posttroll.publisher import Publish + from posttroll.subscriber import Subscribe + + time.sleep(4) + with Subscribe("this_data", "counter", True) as sub: + assert len(sub.sub_addr) == 0 + with Publish("data_provider", 0, ["this_data"]): + time.sleep(4) + next(sub.recv(2)) + assert len(sub.sub_addr) == 1 + time.sleep(3) + for msg in sub.recv(2): + if msg is None: + break + time.sleep(3) + assert len(sub.sub_addr) == 0 + with Publish("data_provider_2", 0, ["another_data"]): + time.sleep(4) + next(sub.recv(2)) + assert len(sub.sub_addr) == 0 + sub.close() + + +class TestNSWithoutMulticasting(unittest.TestCase): + """Test the nameserver.""" + + def setUp(self): + """Set up the testing class.""" + test_lock.acquire() + self.nameservers = ['localhost'] + self.ns = NameServer(max_age=timedelta(seconds=3), + multicast_enabled=False) + self.thr = Thread(target=self.ns.run) + self.thr.start() + + def tearDown(self): + """Clean up after the tests have run.""" + self.ns.stop() + self.thr.join() + time.sleep(2) + test_lock.release() + + def test_pub_addresses(self): + """Test retrieving addresses.""" + from posttroll.ns import get_pub_addresses + from posttroll.publisher import Publish + + with Publish("data_provider", 0, ["this_data"], + nameservers=self.nameservers): + time.sleep(3) + res = get_pub_addresses(["this_data"]) + self.assertEqual(len(res), 1) + expected = {u'status': True, + u'service': [u'data_provider', u'this_data'], + u'name': u'address'} + for key, val in expected.items(): + self.assertEqual(res[0][key], val) + self.assertTrue("receive_time" in res[0]) + self.assertTrue("URI" in res[0]) + res = get_pub_addresses(["data_provider"]) + self.assertEqual(len(res), 1) + expected = {u'status': True, + u'service': [u'data_provider', u'this_data'], + u'name': u'address'} + for key, val in expected.items(): + self.assertEqual(res[0][key], val) + self.assertTrue("receive_time" in res[0]) + self.assertTrue("URI" in res[0]) + + def test_pub_sub_ctx(self): + """Test publish and subscribe.""" + from posttroll.message import Message + from posttroll.publisher import Publish + from posttroll.subscriber import Subscribe + + with Publish("data_provider", 0, ["this_data"], + nameservers=self.nameservers) as pub: + with Subscribe("this_data", "counter") as sub: + for counter in range(5): + message = Message("/counter", "info", str(counter)) + pub.send(str(message)) + time.sleep(1) + msg = next(sub.recv(2)) + if msg is not None: + self.assertEqual(str(msg), str(message)) + tested = True + sub.close() + self.assertTrue(tested) + + def test_pub_sub_add_rm(self): + """Test adding and removing publishers.""" + from posttroll.publisher import Publish + from posttroll.subscriber import Subscribe + + time.sleep(4) + with Subscribe("this_data", "counter", True) as sub: + self.assertEqual(len(sub.sub_addr), 0) + with Publish("data_provider", 0, ["this_data"], + nameservers=self.nameservers): + time.sleep(4) + next(sub.recv(2)) + self.assertEqual(len(sub.sub_addr), 1) + time.sleep(3) + for msg in sub.recv(2): + if msg is None: + break + + time.sleep(3) + self.assertEqual(len(sub.sub_addr), 0) + with Publish("data_provider_2", 0, ["another_data"], + nameservers=self.nameservers): + time.sleep(4) + next(sub.recv(2)) + self.assertEqual(len(sub.sub_addr), 0) class TestPubSub(unittest.TestCase): @@ -604,7 +604,7 @@ def test_dict_config_full_nssubscriber(NSSubscriber_start): NSSubscriber_start.assert_called_once() -@mock.patch('posttroll.subscriber.Subscriber.update') +@mock.patch('posttroll.subscriber.UnsecureZMQSubscriber.update') def test_dict_config_full_subscriber(Subscriber_update): """Test that all Subscriber options are passed.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -621,17 +621,6 @@ def test_dict_config_full_subscriber(Subscriber_update): } _ = create_subscriber_from_dict_config(settings) - -@pytest.fixture -def oldtcp_keepalive_settings(monkeypatch): - """Set TCP Keepalive settings.""" - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE", "1") - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_CNT", "10") - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_IDLE", "1") - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_INTVL", "1") - with reset_config_for_tests(): - yield - @pytest.fixture def tcp_keepalive_settings(monkeypatch): """Set TCP Keepalive settings.""" From 66bd3e96cb78a928605624835df1e39f55bd02d9 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 29 Nov 2023 21:45:57 +0100 Subject: [PATCH 05/45] Fix NS tests --- posttroll/tests/test_pubsub.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 72bb357..1cb339a 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -109,21 +109,21 @@ def test_pub_sub_add_rm(self): time.sleep(4) with Subscribe("this_data", "counter", True) as sub: - assert len(sub.sub_addr) == 0 + assert len(sub.addresses) == 0 with Publish("data_provider", 0, ["this_data"]): time.sleep(4) next(sub.recv(2)) - assert len(sub.sub_addr) == 1 + assert len(sub.addresses) == 1 time.sleep(3) for msg in sub.recv(2): if msg is None: break time.sleep(3) - assert len(sub.sub_addr) == 0 + assert len(sub.addresses) == 0 with Publish("data_provider_2", 0, ["another_data"]): time.sleep(4) next(sub.recv(2)) - assert len(sub.sub_addr) == 0 + assert len(sub.addresses) == 0 sub.close() @@ -200,24 +200,24 @@ def test_pub_sub_add_rm(self): time.sleep(4) with Subscribe("this_data", "counter", True) as sub: - self.assertEqual(len(sub.sub_addr), 0) + assert len(sub.addresses) == 0 with Publish("data_provider", 0, ["this_data"], nameservers=self.nameservers): time.sleep(4) next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 1) + assert len(sub.addresses) == 1 time.sleep(3) for msg in sub.recv(2): if msg is None: break time.sleep(3) - self.assertEqual(len(sub.sub_addr), 0) + assert len(sub.sub_addr) == 0 with Publish("data_provider_2", 0, ["another_data"], nameservers=self.nameservers): time.sleep(4) next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 0) + assert len(sub.sub_addr) == 0 class TestPubSub(unittest.TestCase): From 268b78d1165fb488d1b04d5a1b62aea348ef1029 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Mon, 4 Dec 2023 11:19:19 +0100 Subject: [PATCH 06/45] Fix style --- posttroll/tests/test_pubsub.py | 50 +++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 1cb339a..ff31f60 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -233,18 +233,15 @@ def tearDown(self): def test_pub_address_timeout(self): """Test timeout in offline nameserver.""" - from posttroll.ns import get_pub_address - from posttroll.ns import TimeoutError + from posttroll.ns import TimeoutError, get_pub_address with pytest.raises(TimeoutError): get_pub_address("this_data", 0.05) def test_pub_suber(self): """Test publisher and subscriber.""" from posttroll.message import Message - from posttroll.publisher import Publisher - from posttroll.publisher import get_own_ip + from posttroll.publisher import Publisher, get_own_ip from posttroll.subscriber import Subscriber - pub_address = "tcp://" + str(get_own_ip()) + ":0" pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) @@ -281,7 +278,7 @@ def test_pub_sub_ctx_no_nameserver(self): time.sleep(.05) msg = next(sub.recv(2)) if msg is not None: - self.assertEqual(str(msg), str(message)) + assert str(msg) == str(message) tested = True sub.close() assert tested @@ -341,6 +338,7 @@ def test_pub_minmax_port_from_instanciation(self): def _get_port_from_publish_instance(min_port=None, max_port=None): from zmq.error import ZMQError + from posttroll.publisher import Publish try: @@ -372,9 +370,9 @@ def tearDown(self): def test_listener_container(self): """Test listener container.""" + from posttroll.listener import ListenerContainer from posttroll.message import Message from posttroll.publisher import NoisyPublisher - from posttroll.listener import ListenerContainer pub = NoisyPublisher("test", broadcast_interval=0.1) pub.start() @@ -389,7 +387,7 @@ def test_listener_container(self): if msg_in is not None: assert str(msg_in) == str(msg_out) tested = True - self.assertTrue(tested) + assert tested pub.stop() sub.stop() @@ -407,9 +405,9 @@ def tearDown(self): def test_listener_container(self): """Test listener container.""" + from posttroll.listener import ListenerContainer from posttroll.message import Message from posttroll.publisher import Publisher - from posttroll.listener import ListenerContainer pub_addr = "tcp://127.0.0.1:55000" pub = Publisher(pub_addr, name="test") @@ -424,9 +422,9 @@ def test_listener_container(self): msg_in = sub.output_queue.get(True, 1) if msg_in is not None: - self.assertEqual(str(msg_in), str(msg_out)) + assert str(msg_in) == str(msg_out) tested = True - self.assertTrue(tested) + assert tested pub.stop() sub.stop() @@ -517,21 +515,21 @@ def test_publish_is_not_noisy(self): def test_publish_is_noisy_only_name(self): """Test that NoisyPublisher is selected with the context manager when only name is given.""" - from posttroll.publisher import Publish, NoisyPublisher + from posttroll.publisher import NoisyPublisher, Publish with Publish("service_name") as pub: assert isinstance(pub, NoisyPublisher) def test_publish_is_noisy_with_port(self): """Test that NoisyPublisher is selected with the context manager when port is given.""" - from posttroll.publisher import Publish, NoisyPublisher + from posttroll.publisher import NoisyPublisher, Publish with Publish("service_name", port=40001) as pub: assert isinstance(pub, NoisyPublisher) def test_publish_is_noisy_with_nameservers(self): """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" - from posttroll.publisher import Publish, NoisyPublisher + from posttroll.publisher import NoisyPublisher, Publish with Publish("service_name", nameservers=['a', 'b']) as pub: assert isinstance(pub, NoisyPublisher) @@ -621,8 +619,8 @@ def test_dict_config_full_subscriber(Subscriber_update): } _ = create_subscriber_from_dict_config(settings) -@pytest.fixture -def tcp_keepalive_settings(monkeypatch): +@pytest.fixture() +def _tcp_keepalive_settings(monkeypatch): """Set TCP Keepalive settings.""" from posttroll import config with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1): @@ -637,29 +635,32 @@ def reset_config_for_tests(): posttroll.config = old_config -@pytest.fixture -def tcp_keepalive_no_settings(): +@pytest.fixture() +def _tcp_keepalive_no_settings(): """Set TCP Keepalive settings.""" from posttroll import config with config.set(tcp_keepalive=None, tcp_keepalive_cnt=None, tcp_keepalive_idle=None, tcp_keepalive_intvl=None): yield -def test_publisher_tcp_keepalive(tcp_keepalive_settings): +@pytest.mark.usefixtures("_tcp_keepalive_settings") +def test_publisher_tcp_keepalive(): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" from posttroll.backends.zmq.publisher import UnsecureZMQPublisher pub = UnsecureZMQPublisher("tcp://127.0.0.1:9000").start() _assert_tcp_keepalive(pub.publish_socket) -def test_publisher_tcp_keepalive_not_set(tcp_keepalive_no_settings): +@pytest.mark.usefixtures("_tcp_keepalive_no_settings") +def test_publisher_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" from posttroll.backends.zmq.publisher import UnsecureZMQPublisher pub = UnsecureZMQPublisher("tcp://127.0.0.1:9000").start() _assert_no_tcp_keepalive(pub.publish_socket) -def test_subscriber_tcp_keepalive(tcp_keepalive_settings): +@pytest.mark.usefixtures("_tcp_keepalive_settings") +def test_subscriber_tcp_keepalive(): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber @@ -670,7 +671,8 @@ def test_subscriber_tcp_keepalive(tcp_keepalive_settings): _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) -def test_subscriber_tcp_keepalive_not_set(tcp_keepalive_no_settings): +@pytest.mark.usefixtures("_tcp_keepalive_no_settings") +def test_subscriber_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber @@ -700,6 +702,7 @@ def _assert_no_tcp_keepalive(socket): def test_ipc_pubsub(): + """Test pub-sub on an ipc socket.""" from posttroll import config with config.set(backend="unsecure_zmq"): subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) @@ -721,6 +724,7 @@ def delayed_send(msg): sub.stop() def test_ipc_pubsub_with_sec(): + """Test pub-sub on a secure ipc socket.""" from posttroll import config with config.set(backend="secure_zmq"): subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) @@ -742,6 +746,7 @@ def delayed_send(msg): sub.stop() def test_switch_to_unknown_backend(): + """Test switching to unknown backend.""" from posttroll import config from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber @@ -752,6 +757,7 @@ def test_switch_to_unknown_backend(): Subscriber("ipc://bla.ipc") def test_switch_to_secure_zmq_backend(): + """Test switching to the secure_zmq backend.""" from posttroll import config from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber From b2e3df1178cd64cd723ea6287c6119f40d48d5c8 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 6 Dec 2023 17:24:50 +0100 Subject: [PATCH 07/45] Fix style --- bin/nameserver | 16 +++++++--------- posttroll/__init__.py | 4 ++-- posttroll/publisher.py | 19 ++++++++++--------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/bin/nameserver b/bin/nameserver index d2d3370..8aa89e6 100755 --- a/bin/nameserver +++ b/bin/nameserver @@ -20,19 +20,17 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""The nameserver. Port 5555 (hardcoded) is used for communications. -""" +"""The nameserver. Port 5555 (hardcoded) is used for communications.""" # TODO: make port configurable. -from posttroll.ns import NameServer - import logging -import _strptime + +from posttroll.ns import NameServer logger = logging.getLogger(__name__) -if __name__ == '__main__': +if __name__ == "__main__": import argparse @@ -58,14 +56,14 @@ if __name__ == '__main__': handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("[%(levelname)s: %(asctime)s :" " %(name)s] %(message)s", - '%Y-%m-%d %H:%M:%S')) + "%Y-%m-%d %H:%M:%S")) if opts.verbose: loglevel = logging.DEBUG else: loglevel = logging.INFO handler.setLevel(loglevel) - logging.getLogger('').setLevel(loglevel) - logging.getLogger('').addHandler(handler) + logging.getLogger("").setLevel(loglevel) + logging.getLogger("").addHandler(handler) logger = logging.getLogger("nameserver") multicast_enabled = (opts.no_multicast == False) diff --git a/posttroll/__init__.py b/posttroll/__init__.py index 30f40d9..fe14295 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -34,7 +34,7 @@ from .version import get_versions -config = Config('posttroll') +config = Config("posttroll") context = {} logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ def get_context(): pid = os.getpid() if pid not in context: context[pid] = zmq.Context() - logger.debug('renewed context for PID %d', pid) + logger.debug("renewed context for PID %d", pid) return context[pid] diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 036d387..317bfdf 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -27,9 +27,9 @@ import socket from datetime import datetime, timedelta +from posttroll import config from posttroll.message import Message from posttroll.message_broadcaster import sendaddressservice -from posttroll import config LOGGER = logging.getLogger(__name__) @@ -89,8 +89,8 @@ def __init__(self, address, name="", min_port=None, max_port=None): """Bind the publisher class to a port.""" # Limit port range or use the defaults when no port is defined # by the user - min_port = min_port or int(config.get('pub_min_port', 49152)) - max_port = max_port or int(config.get('pub_max_port', 65536)) + min_port = min_port or int(config.get("pub_min_port", 49152)) + max_port = max_port or int(config.get("pub_max_port", 65536)) # Initialize no heartbeat self._heartbeat = None @@ -143,7 +143,7 @@ class _PublisherHeartbeat: def __init__(self, publisher): self.publisher = publisher - self.subject = '/heartbeat/' + publisher.name + self.subject = "/heartbeat/" + publisher.name self.lastbeat = datetime(1900, 1, 1) def __call__(self, min_interval=0): @@ -228,6 +228,7 @@ def close(self): @property def port_number(self): + """Get the port number.""" return self._publisher.port_number @@ -266,9 +267,9 @@ class Publish: def __init__(self, name, port=0, aliases=None, broadcast_interval=2, nameservers=None, min_port=None, max_port=None): """Initialize the class.""" - settings = {'name': name, 'port': port, 'min_port': min_port, 'max_port': max_port, - 'aliases': aliases, 'broadcast_interval': broadcast_interval, - 'nameservers': nameservers} + settings = {"name": name, "port": port, "min_port": min_port, "max_port": max_port, + "aliases": aliases, "broadcast_interval": broadcast_interval, + "nameservers": nameservers} self.publisher = create_publisher_from_dict_config(settings) def __enter__(self): @@ -300,13 +301,13 @@ def create_publisher_from_dict_config(settings): described in the docstrings of the respective classes, namely :class:`~posttroll.publisher.Publisher` and :class:`~posttroll.publisher.NoisyPublisher`. """ - if settings.get('port') and settings.get('nameservers') is False: + if settings.get("port") and settings.get("nameservers") is False: return _get_publisher_instance(settings) return _get_noisypublisher_instance(settings) def _get_publisher_instance(settings): - publisher_address = _create_tcp_publish_address(settings['port']) + publisher_address = _create_tcp_publish_address(settings["port"]) publisher_name = settings.get("name", "") min_port = settings.get("min_port") max_port = settings.get("max_port") From 6a56d2db16e4cbde6d02475bded50217b782734c Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 6 Dec 2023 19:33:07 +0100 Subject: [PATCH 08/45] Move everything zmq related to it's own backend --- posttroll/__init__.py | 22 ++--- posttroll/address_receiver.py | 37 ++----- posttroll/backends/zmq/__init__.py | 18 ++++ posttroll/backends/zmq/address_receiver.py | 22 +++++ posttroll/backends/zmq/message_broadcaster.py | 51 ++++++++++ posttroll/backends/zmq/ns.py | 90 +++++++++++++++++ posttroll/backends/zmq/publisher.py | 8 +- posttroll/backends/zmq/subscriber.py | 16 ++-- posttroll/message_broadcaster.py | 58 ++++------- posttroll/ns.py | 96 +++++-------------- posttroll/tests/test_pubsub.py | 77 +++++++-------- 11 files changed, 290 insertions(+), 205 deletions(-) create mode 100644 posttroll/backends/zmq/address_receiver.py create mode 100644 posttroll/backends/zmq/message_broadcaster.py create mode 100644 posttroll/backends/zmq/ns.py diff --git a/posttroll/__init__.py b/posttroll/__init__.py index fe14295..469d02e 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -25,30 +25,28 @@ """Posttroll packages.""" import logging -import os import sys from datetime import datetime -import zmq from donfig import Config from .version import get_versions config = Config("posttroll") -context = {} +# context = {} logger = logging.getLogger(__name__) -def get_context(): - """Provide the context to use. +# def get_context(): +# """Provide the context to use. - This function takes care of creating new contexts in case of forks. - """ - pid = os.getpid() - if pid not in context: - context[pid] = zmq.Context() - logger.debug("renewed context for PID %d", pid) - return context[pid] +# This function takes care of creating new contexts in case of forks. +# """ +# pid = os.getpid() +# if pid not in context: +# context[pid] = zmq.Context() +# logger.debug("renewed context for PID %d", pid) +# return context[pid] def strp_isoformat(strg): diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 3cedf79..c5a7661 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -27,28 +27,25 @@ //address info ... host:port """ import copy +import errno import logging import os import threading -import errno import time - from datetime import datetime, timedelta import netifaces -from zmq import REP, LINGER +from posttroll import config from posttroll.bbmcast import MulticastReceiver, SocketTimeout from posttroll.message import Message from posttroll.publisher import Publish -from posttroll import get_context - -__all__ = ('AddressReceiver', 'getaddress') +__all__ = ("AddressReceiver", "getaddress") LOGGER = logging.getLogger(__name__) -debug = os.environ.get('DEBUG', False) +debug = os.environ.get("DEBUG", False) broadcast_port = 21200 default_publish_port = 16543 @@ -64,7 +61,7 @@ def get_local_ips(): for addr in inet_addrs: if addr is not None: for add in addr: - ips.append(add['addr']) + ips.append(add["addr"]) return ips # ----------------------------------------------------------------------------- @@ -169,7 +166,9 @@ def _run(self): break else: - recv = _SimpleReceiver(port) + if config.get("backend", "unsecure_zmq") == "unsecure_zmq": + from posttroll.backends.zmq.address_receiver import SimpleReceiver + recv = SimpleReceiver(port) nameservers = ["localhost"] self._is_running = True @@ -221,26 +220,6 @@ def _add(self, adr, metadata): self._addresses[adr] = metadata -class _SimpleReceiver(object): - - """ Simple listing on port for address messages.""" - - def __init__(self, port=None): - self._port = port or default_publish_port - self._socket = get_context().socket(REP) - self._socket.bind("tcp://*:" + str(port)) - - def __call__(self): - message = self._socket.recv_string() - self._socket.send_string("ok") - return message, None - - def close(self): - """Close the receiver.""" - self._socket.setsockopt(LINGER, 1) - self._socket.close() - - # ----------------------------------------------------------------------------- # default getaddress = AddressReceiver diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index 5596169..f086f98 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -1,7 +1,25 @@ +import logging +import os + import zmq from posttroll import config +logger = logging.getLogger(__name__) +context = {} + + +def get_context(): + """Provide the context to use. + + This function takes care of creating new contexts in case of forks. + """ + pid = os.getpid() + if pid not in context: + context[pid] = zmq.Context() + logger.debug("renewed context for PID %d", pid) + return context[pid] + def _set_tcp_keepalive(socket): _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py new file mode 100644 index 0000000..858d6be --- /dev/null +++ b/posttroll/backends/zmq/address_receiver.py @@ -0,0 +1,22 @@ +from posttroll.address_receiver import default_publish_port +from posttroll.backends.zmq import get_context +from zmq import REP, LINGER + +class SimpleReceiver(object): + + """ Simple listing on port for address messages.""" + + def __init__(self, port=None): + self._port = port or default_publish_port + self._socket = get_context().socket(REP) + self._socket.bind("tcp://*:" + str(port)) + + def __call__(self): + message = self._socket.recv_string() + self._socket.send_string("ok") + return message, None + + def close(self): + """Close the receiver.""" + self._socket.setsockopt(LINGER, 1) + self._socket.close() diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py new file mode 100644 index 0000000..8d3310f --- /dev/null +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -0,0 +1,51 @@ +import threading +from posttroll.backends.zmq import get_context + +from zmq import LINGER, NOBLOCK, REQ, ZMQError + +import logging + +logger = logging.getLogger(__name__) + + +class UnsecureZMQDesignatedReceiversSender: + """Sends message to multiple *receivers* on *port*.""" + + def __init__(self, default_port, receivers): + self.default_port = default_port + + self.receivers = receivers + self._shutdown_event = threading.Event() + + def __call__(self, data): + """Send data.""" + for receiver in self.receivers: + self._send_to_address(receiver, data) + + def _send_to_address(self, address, data, timeout=10): + """Send data to *address* and *port* without verification of response.""" + # Socket to talk to server + socket = get_context().socket(REQ) + try: + socket.setsockopt(LINGER, timeout * 1000) + if address.find(":") == -1: + socket.connect("tcp://%s:%d" % (address, self.default_port)) + else: + socket.connect("tcp://%s" % address) + socket.send_string(data) + while not self._shutdown_event.is_set(): + try: + message = socket.recv_string(NOBLOCK) + except ZMQError: + self._shutdown_event.wait(.1) + continue + if message != "ok": + logger.warn("invalid acknowledge received: %s" % message) + break + + finally: + socket.close() + + def close(self): + """Close the sender.""" + self._shutdown_event.set() diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py new file mode 100644 index 0000000..7255e05 --- /dev/null +++ b/posttroll/backends/zmq/ns.py @@ -0,0 +1,90 @@ +"""ZMQ implexentation of ns.""" + +import logging +from threading import Lock + +from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller + +from posttroll.backends.zmq import get_context +from posttroll.message import Message +from posttroll.ns import PORT, get_active_address + +logger = logging.getLogger("__name__") + +nslock = Lock() + + +def unsecure_zmq_get_pub_address(name, timeout=10, nameserver="localhost"): + """Get the address of the publisher. + + For a given publisher *name* from the nameserver on *nameserver* (localhost by default). + """ + # Socket to talk to server + socket = get_context().socket(REQ) + try: + socket.setsockopt(LINGER, int(timeout * 1000)) + socket.connect("tcp://" + nameserver + ":" + str(PORT)) + logger.debug("Connecting to %s", + "tcp://" + nameserver + ":" + str(PORT)) + poller = Poller() + poller.register(socket, POLLIN) + + message = Message("/oper/ns", "request", {"service": name}) + socket.send_string(str(message)) + + # Get the reply. + sock = poller.poll(timeout=timeout * 1000) + if sock: + if sock[0][0] == socket: + message = Message.decode(socket.recv_string(NOBLOCK)) + return message.data + else: + raise TimeoutError("Didn't get an address after %d seconds." + % timeout) + finally: + socket.close() + + +class UnsecureZMQNameServer: + """The name server.""" + + def __init__(self): + """Set up the nameserver.""" + self.loop = True + self.listener = None + + def run(self, arec): + """Run the listener and answer to requests.""" + port = PORT + + try: + with nslock: + self.listener = get_context().socket(REP) + self.listener.bind("tcp://*:" + str(port)) + logger.debug("Listening on port %s", str(port)) + poller = Poller() + poller.register(self.listener, POLLIN) + while self.loop: + with nslock: + socks = dict(poller.poll(1000)) + if socks: + if socks.get(self.listener) == POLLIN: + msg = self.listener.recv_string() + else: + continue + logger.debug("Replying to request: " + str(msg)) + msg = Message.decode(msg) + active_address = get_active_address(msg.data["service"], arec) + self.listener.send_unicode(str(active_address)) + except KeyboardInterrupt: + # Needed to stop the nameserver. + pass + finally: + self.stop() + + def stop(self): + """Stop the name server.""" + self.listener.setsockopt(LINGER, 1) + self.loop = False + with nslock: + self.listener.close() diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index 5ba7fbe..6c88f1c 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -1,10 +1,12 @@ +"""ZMQ implementation of the publisher.""" + +import logging from threading import Lock from urllib.parse import urlsplit, urlunsplit + import zmq -import logging -from posttroll import get_context -from posttroll.backends.zmq import _set_tcp_keepalive +from posttroll.backends.zmq import _set_tcp_keepalive, get_context LOGGER = logging.getLogger(__name__) diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 46f3a3d..caf267b 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -1,21 +1,21 @@ +"""ZMQ implementation of the subscriber.""" + +import logging from threading import Lock -from urllib.parse import urlsplit -from posttroll.message import Message -from zmq import Poller, SUB, SUBSCRIBE, POLLIN, PULL, ZMQError, NOBLOCK, LINGER from time import sleep -import logging - -from posttroll import get_context -from posttroll.backends.zmq import _set_tcp_keepalive +from urllib.parse import urlsplit +from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError +from posttroll.backends.zmq import _set_tcp_keepalive, get_context +from posttroll.message import Message LOGGER = logging.getLogger(__name__) class UnsecureZMQSubscriber: """Unsecure ZMQ implementation of the subscriber.""" - def __init__(self, addresses, topics='', message_filter=None, translate=False): + def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" self._topics = topics self._filter = message_filter diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 55eca43..98ffdf8 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -21,62 +21,36 @@ # You should have received a copy of the GNU General Public License along with # pytroll. If not, see . -import threading -import logging import errno +import logging +import threading -from posttroll import message -from posttroll.bbmcast import MulticastSender, MC_GROUP -from posttroll import get_context -from zmq import REQ, LINGER, NOBLOCK, ZMQError +from posttroll import config, message +from posttroll.bbmcast import MC_GROUP, MulticastSender -__all__ = ('MessageBroadcaster', 'AddressBroadcaster', 'sendaddress') +__all__ = ("MessageBroadcaster", "AddressBroadcaster", "sendaddress") LOGGER = logging.getLogger(__name__) broadcast_port = 21200 -class DesignatedReceiversSender(object): +class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" - def __init__(self, default_port, receivers): - self.default_port = default_port - - self.receivers = receivers - self._shutdown_event = threading.Event() + backend = config.get("backend", "unsecure_zmq") + if backend == "unsecure_zmq": + from posttroll.backends.zmq.message_broadcaster import UnsecureZMQDesignatedReceiversSender + self._sender = UnsecureZMQDesignatedReceiversSender(default_port, receivers) def __call__(self, data): - for receiver in self.receivers: - self._send_to_address(receiver, data) - - def _send_to_address(self, address, data, timeout=10): - """Send data to *address* and *port* without verification of response.""" - # Socket to talk to server - socket = get_context().socket(REQ) - try: - socket.setsockopt(LINGER, timeout * 1000) - if address.find(":") == -1: - socket.connect("tcp://%s:%d" % (address, self.default_port)) - else: - socket.connect("tcp://%s" % address) - socket.send_string(data) - while not self._shutdown_event.is_set(): - try: - message = socket.recv_string(NOBLOCK) - except ZMQError: - self._shutdown_event.wait(.1) - continue - if message != "ok": - LOGGER.warn("invalid acknowledge received: %s" % message) - break - - finally: - socket.close() + """Send data.""" + return self._sender(data) def close(self): """Close the sender.""" - self._shutdown_event.set() + return self._sender.close() + #----------------------------------------------------------------------------- # # General thread to broadcast messages. @@ -91,6 +65,7 @@ class MessageBroadcaster(object): """ def __init__(self, msg, port, interval, designated_receivers=None): + """Set up the message broadcaster.""" if designated_receivers: self._sender = DesignatedReceiversSender(port, designated_receivers) @@ -152,9 +127,10 @@ def _run(self): class AddressBroadcaster(MessageBroadcaster): - """Class to broadcast stuff.""" + """Class to broadcast addresses.""" def __init__(self, name, address, interval, nameservers): + """Set up the Address broadcaster.""" msg = message.Message("/address/%s" % name, "info", {"URI": "%s:%d" % address}).encode() MessageBroadcaster.__init__(self, msg, broadcast_port, interval, diff --git a/posttroll/ns.py b/posttroll/ns.py index 12a1bd0..ae49bc7 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -30,11 +30,7 @@ import time from datetime import datetime, timedelta -from threading import Lock -# pylint: disable=E0611 -from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller - -from posttroll import get_context +from posttroll import config from posttroll.address_receiver import AddressReceiver from posttroll.message import Message @@ -45,13 +41,6 @@ logger = logging.getLogger(__name__) -nslock = Lock() - -class TimeoutError(BaseException): - - """A timeout.""" - pass - # Client functions. @@ -59,6 +48,7 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): """Get the address of the publisher for a given list of publisher *names* from the nameserver on *nameserver* (localhost by default). """ + addrs = [] if names is None: names = ["", ] @@ -72,40 +62,23 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): return addrs + + def get_pub_address(name, timeout=10, nameserver="localhost"): - """Get the address of the publisher for a given publisher *name* from the - nameserver on *nameserver* (localhost by default).""" - # Socket to talk to server - socket = get_context().socket(REQ) - try: - socket.setsockopt(LINGER, int(timeout * 1000)) - socket.connect("tcp://" + nameserver + ":" + str(PORT)) - logger.debug('Connecting to %s', - "tcp://" + nameserver + ":" + str(PORT)) - poller = Poller() - poller.register(socket, POLLIN) - - message = Message("/oper/ns", "request", {"service": name}) - socket.send_string(str(message)) - - # Get the reply. - sock = poller.poll(timeout=timeout * 1000) - if sock: - if sock[0][0] == socket: - message = Message.decode(socket.recv_string(NOBLOCK)) - return message.data - else: - raise TimeoutError("Didn't get an address after %d seconds." - % timeout) - finally: - socket.close() + """Get the address of the publisher. + + For a given publisher *name* from the nameserver on *nameserver* (localhost by default). + """ + backend = config.get("backend", "unsecure_zmq") + if backend == "unsecure_zmq": + from posttroll.backends.zmq.ns import unsecure_zmq_get_pub_address + return unsecure_zmq_get_pub_address(name, timeout, nameserver) # Server part. def get_active_address(name, arec): - """Get the addresses of the active modules for a given publisher *name*. - """ + """Get the addresses of the active modules for a given publisher *name*.""" addrs = arec.get(name) if addrs: return Message("/oper/ns", "info", addrs) @@ -113,57 +86,32 @@ def get_active_address(name, arec): return Message("/oper/ns", "info", "") + class NameServer: """The name server.""" def __init__(self, max_age=timedelta(minutes=10), multicast_enabled=True, restrict_to_localhost=False): - self.loop = True - self.listener = None self._max_age = max_age self._multicast_enabled = multicast_enabled self._restrict_to_localhost = restrict_to_localhost + backend = config.get("backend", "unsecure_zmq") + if backend == "unsecure_zmq": + from posttroll.backends.zmq.ns import UnsecureZMQNameServer + self._ns = UnsecureZMQNameServer() def run(self, *args): - """Run the listener and answer to requests. - """ + """Run the nameserver.""" del args arec = AddressReceiver(max_age=self._max_age, multicast_enabled=self._multicast_enabled, restrict_to_localhost=self._restrict_to_localhost) arec.start() - port = PORT - try: - with nslock: - self.listener = get_context().socket(REP) - self.listener.bind("tcp://*:" + str(port)) - logger.debug('Listening on port %s', str(port)) - poller = Poller() - poller.register(self.listener, POLLIN) - while self.loop: - with nslock: - socks = dict(poller.poll(1000)) - if socks: - if socks.get(self.listener) == POLLIN: - msg = self.listener.recv_string() - else: - continue - logger.debug("Replying to request: " + str(msg)) - msg = Message.decode(msg) - active_address = get_active_address(msg.data["service"], arec) - self.listener.send_unicode(str(active_address)) - except KeyboardInterrupt: - # Needed to stop the nameserver. - pass + return self._ns.run(arec) finally: arec.stop() - self.stop() def stop(self): - """Stop the name server. - """ - self.listener.setsockopt(LINGER, 1) - self.loop = False - with nslock: - self.listener.close() + """Stop the nameserver.""" + return self._ns.stop() diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index ff31f60..607aeb2 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -23,19 +23,20 @@ """Test the publishing and subscribing facilities.""" -import unittest -from unittest import mock -from datetime import timedelta -from threading import Thread, Lock import time +import unittest from contextlib import contextmanager +from datetime import timedelta +from threading import Lock, Thread +from unittest import mock + +import pytest +from donfig import Config import posttroll from posttroll.ns import NameServer from posttroll.publisher import create_publisher_from_dict_config from posttroll.subscriber import Subscribe, Subscriber, create_subscriber_from_dict_config -import pytest -from donfig import Config test_lock = Lock() @@ -66,18 +67,18 @@ def test_pub_addresses(self): time.sleep(.3) res = get_pub_addresses(["this_data"], timeout=.5) assert len(res) == 1 - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} for key, val in expected.items(): assert res[0][key] == val assert "receive_time" in res[0] assert "URI" in res[0] res = get_pub_addresses([str("data_provider")]) assert len(res) == 1 - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} for key, val in expected.items(): assert res[0][key] == val assert "receive_time" in res[0] @@ -133,7 +134,7 @@ class TestNSWithoutMulticasting(unittest.TestCase): def setUp(self): """Set up the testing class.""" test_lock.acquire() - self.nameservers = ['localhost'] + self.nameservers = ["localhost"] self.ns = NameServer(max_age=timedelta(seconds=3), multicast_enabled=False) self.thr = Thread(target=self.ns.run) @@ -155,23 +156,23 @@ def test_pub_addresses(self): nameservers=self.nameservers): time.sleep(3) res = get_pub_addresses(["this_data"]) - self.assertEqual(len(res), 1) - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] res = get_pub_addresses(["data_provider"]) - self.assertEqual(len(res), 1) - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] def test_pub_sub_ctx(self): """Test publish and subscribe.""" @@ -188,10 +189,10 @@ def test_pub_sub_ctx(self): time.sleep(1) msg = next(sub.recv(2)) if msg is not None: - self.assertEqual(str(msg), str(message)) + assert str(msg) == str(message) tested = True sub.close() - self.assertTrue(tested) + assert tested def test_pub_sub_add_rm(self): """Test adding and removing publishers.""" @@ -233,7 +234,7 @@ def tearDown(self): def test_pub_address_timeout(self): """Test timeout in offline nameserver.""" - from posttroll.ns import TimeoutError, get_pub_address + from posttroll.ns import get_pub_address with pytest.raises(TimeoutError): get_pub_address("this_data", 0.05) @@ -245,7 +246,7 @@ def test_pub_suber(self): pub_address = "tcp://" + str(get_own_ip()) + ":0" pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) - sub = Subscriber([addr], '/counter') + sub = Subscriber([addr], "/counter") # wait a bit before sending the first message so that the subscriber is ready time.sleep(.002) @@ -300,7 +301,7 @@ def test_pub_supports_unicode(self): from posttroll.message import Message from posttroll.publisher import Publish - message = Message("/pџтяöll", "info", 'hej') + message = Message("/pџтяöll", "info", "hej") with Publish("a_service", 9000) as pub: try: pub.send(message.encode()) @@ -319,7 +320,7 @@ def test_pub_minmax_port_from_config(self): # The port wasn't free, try another one continue # Port was selected, make sure it's within the "range" of one - self.assertEqual(res, port) + assert res == port break def test_pub_minmax_port_from_instanciation(self): @@ -332,7 +333,7 @@ def test_pub_minmax_port_from_instanciation(self): # The port wasn't free, try again continue # Port was selected, make sure it's within the "range" of one - self.assertEqual(res, port) + assert res == port break @@ -439,7 +440,7 @@ def test_localhost_restriction(self, mcrec, pub, msg): """Test address receiver restricted only to localhost.""" mcr_instance = mock.Mock() mcrec.return_value = mcr_instance - mcr_instance.return_value = 'blabla', ('255.255.255.255', 12) + mcr_instance.return_value = "blabla", ("255.255.255.255", 12) from posttroll.address_receiver import AddressReceiver adr = AddressReceiver(restrict_to_localhost=True) adr.start() @@ -455,16 +456,16 @@ def test_publisher_is_selected(self): """Test that Publisher is selected as publisher class.""" from posttroll.publisher import Publisher - settings = {'port': 12345, 'nameservers': False} + settings = {"port": 12345, "nameservers": False} pub = create_publisher_from_dict_config(settings) assert isinstance(pub, Publisher) assert pub is not None - @mock.patch('posttroll.publisher.Publisher') + @mock.patch("posttroll.publisher.Publisher") def test_publisher_all_arguments(self, Publisher): """Test that only valid arguments are passed to Publisher.""" - settings = {'port': 12345, 'nameservers': False, 'name': 'foo', + settings = {"port": 12345, 'nameservers': False, 'name': 'foo', 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar'} _ = create_publisher_from_dict_config(settings) _check_valid_settings_in_call(settings, Publisher, ignore=['port', 'nameservers']) From a49f615a3ea2270f5b3d6d1a7ea95904e576c0ba Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 6 Dec 2023 19:42:05 +0100 Subject: [PATCH 09/45] Fix style --- doc/source/conf.py | 45 +++--- posttroll/__init__.py | 10 +- posttroll/address_receiver.py | 20 +-- posttroll/backends/zmq/address_receiver.py | 7 +- posttroll/backends/zmq/message_broadcaster.py | 4 +- posttroll/backends/zmq/publisher.py | 3 +- posttroll/bbmcast.py | 60 ++++---- posttroll/listener.py | 17 +-- posttroll/logger.py | 50 +++---- posttroll/message.py | 124 +++++++---------- posttroll/ns.py | 1 - posttroll/subscriber.py | 50 +++---- posttroll/tests/test_bbmcast.py | 36 +++-- posttroll/tests/test_message.py | 130 ++++++++---------- posttroll/tests/test_pubsub.py | 48 +++---- setup.py | 34 ++--- 16 files changed, 296 insertions(+), 343 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index 5593311..fd519b5 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -11,21 +11,22 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os +import sys + from posttroll import __version__ + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. #sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../../')) -sys.path.insert(0, os.path.abspath('../../posttroll')) +sys.path.insert(0, os.path.abspath("../../")) +sys.path.insert(0, os.path.abspath("../../posttroll")) class Mock(object): - """A mocking class - """ + """A mocking class.""" def __init__(self, *args, **kwargs): pass @@ -34,8 +35,8 @@ def __call__(self, *args, **kwargs): @classmethod def __getattr__(cls, name): - if name in ('__file__', '__path__'): - return '/dev/null' + if name in ("__file__", "__path__"): + return "/dev/null" elif name[0] == name[0].upper(): mock_type = type(name, (), {}) mock_type.__module__ = __name__ @@ -43,7 +44,7 @@ def __getattr__(cls, name): else: return Mock() -MOCK_MODULES = ['zmq'] +MOCK_MODULES = ["zmq"] for mod_name in MOCK_MODULES: sys.modules[mod_name] = Mock() @@ -54,23 +55,23 @@ def __getattr__(cls, name): # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.doctest"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['sphinx_templates'] +templates_path = ["sphinx_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'PostTroll' -copyright = u'2012-2014, Pytroll crew' +project = u"PostTroll" +copyright = u"2012-2014, Pytroll crew" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -112,7 +113,7 @@ def __getattr__(cls, name): #show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] @@ -122,7 +123,7 @@ def __getattr__(cls, name): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -151,7 +152,7 @@ def __getattr__(cls, name): # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['sphinx_static'] +html_static_path = ["sphinx_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. @@ -195,7 +196,7 @@ def __getattr__(cls, name): #html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'PostTrolldoc' +htmlhelp_basename = "PostTrolldoc" # -- Options for LaTeX output -------------------------------------------------- @@ -209,8 +210,8 @@ def __getattr__(cls, name): # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'PostTroll.tex', u'PostTroll Documentation', - u'Pytroll crew', 'manual'), + ("index", "PostTroll.tex", u"PostTroll Documentation", + u"Pytroll crew", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -242,6 +243,6 @@ def __getattr__(cls, name): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'posttroll', u'PostTroll Documentation', - [u'Pytroll crew'], 1) + ("index", "posttroll", u"PostTroll Documentation", + [u"Pytroll crew"], 1) ] diff --git a/posttroll/__init__.py b/posttroll/__init__.py index 469d02e..3ed72f9 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -60,18 +60,18 @@ def strp_isoformat(strg): return strg if len(strg) < 19 or len(strg) > 26: if len(strg) > 30: - strg = strg[:30] + '...' + strg = strg[:30] + "..." raise ValueError("Invalid ISO formatted time string '%s'" % strg) if strg.find(".") == -1: - strg += '.000000' - if sys.version[0:3] >= '2.6': + strg += ".000000" + if sys.version[0:3] >= "2.6": return datetime.strptime(strg, "%Y-%m-%dT%H:%M:%S.%f") else: dat, mis = strg.split(".") dat = datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S") - mis = int(float('.' + mis)*1000000) + mis = int(float("." + mis)*1000000) return dat.replace(microsecond=mis) -__version__ = get_versions()['version'] +__version__ = get_versions()["version"] del get_versions diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index c5a7661..fb36eaa 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -80,7 +80,7 @@ def __init__(self, max_age=ten_minutes, port=None, self._port = port or default_publish_port self._address_lock = threading.Lock() self._addresses = {} - self._subject = '/address' + self._subject = "/address" self._do_heartbeat = do_heartbeat self._multicast_enabled = multicast_enabled self._last_age_check = datetime(1900, 1, 1) @@ -117,7 +117,7 @@ def get(self, name=""): mda = copy.copy(metadata) mda["receive_time"] = mda["receive_time"].isoformat() addrs.append(mda) - LOGGER.debug('return address %s', str(addrs)) + LOGGER.debug("return address %s", str(addrs)) return addrs def _check_age(self, pub, min_interval=zero_seconds): @@ -133,10 +133,10 @@ def _check_age(self, pub, min_interval=zero_seconds): for addr, metadata in self._addresses.items(): atime = metadata["receive_time"] if now - atime > self._max_age: - mda = {'status': False, - 'URI': addr, - 'service': metadata['service']} - msg = Message('/address/' + metadata['name'], 'info', mda) + mda = {"status": False, + "URI": addr, + "service": metadata["service"]} + msg = Message("/address/" + metadata["name"], "info", mda) to_del.append(addr) LOGGER.info("publish remove '%s'", str(msg)) pub.send(msg.encode()) @@ -182,7 +182,7 @@ def _run(self): ip_, port = fromaddr if self._restrict_to_localhost and ip_ not in self._local_ips: # discard external message - LOGGER.debug('Discard external message') + LOGGER.debug("Discard external message") continue LOGGER.debug("data %s", data) except SocketTimeout: @@ -195,14 +195,14 @@ def _run(self): pub.heartbeat(min_interval=29) msg = Message.decode(data) name = msg.subject.split("/")[1] - if(msg.type == 'info' and + if(msg.type == "info" and msg.subject.lower().startswith(self._subject)): addr = msg.data["URI"] - msg.data['status'] = True + msg.data["status"] = True metadata = copy.copy(msg.data) metadata["name"] = name - LOGGER.debug('receiving address %s %s %s', str(addr), + LOGGER.debug("receiving address %s %s %s", str(addr), str(name), str(metadata)) if addr not in self._addresses: LOGGER.info("nameserver: publish add '%s'", diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index 858d6be..e6db4e6 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -1,10 +1,11 @@ +from zmq import LINGER, REP + from posttroll.address_receiver import default_publish_port from posttroll.backends.zmq import get_context -from zmq import REP, LINGER -class SimpleReceiver(object): - """ Simple listing on port for address messages.""" +class SimpleReceiver(object): + """Simple listing on port for address messages.""" def __init__(self, port=None): self._port = port or default_publish_port diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py index 8d3310f..5fbff8d 100644 --- a/posttroll/backends/zmq/message_broadcaster.py +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -1,9 +1,9 @@ +import logging import threading -from posttroll.backends.zmq import get_context from zmq import LINGER, NOBLOCK, REQ, ZMQError -import logging +from posttroll.backends.zmq import get_context logger = logging.getLogger(__name__) diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index 6c88f1c..f4796a8 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -25,8 +25,7 @@ def __init__(self, address, name="", min_port=None, max_port=None): self._pub_lock = Lock() def start(self): - """Start the publisher. - """ + """Start the publisher.""" self.publish_socket = get_context().socket(zmq.PUB) _set_tcp_keepalive(self.publish_socket) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index 358eff5..bb91d65 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -31,19 +31,32 @@ import logging import os import struct -from socket import (AF_INET, INADDR_ANY, IP_ADD_MEMBERSHIP, IP_MULTICAST_LOOP, - IP_MULTICAST_TTL, IPPROTO_IP, SO_BROADCAST, SO_REUSEADDR, - SOCK_DGRAM, SOL_IP, SOL_SOCKET, gethostbyname, socket, - timeout, SO_LINGER) - -__all__ = ('MulticastSender', 'MulticastReceiver', 'mcast_sender', - 'mcast_receiver', 'SocketTimeout') +from socket import ( + AF_INET, + INADDR_ANY, + IP_ADD_MEMBERSHIP, + IP_MULTICAST_LOOP, + IP_MULTICAST_TTL, + IPPROTO_IP, + SO_BROADCAST, + SO_LINGER, + SO_REUSEADDR, + SOCK_DGRAM, + SOL_IP, + SOL_SOCKET, + gethostbyname, + socket, + timeout, +) + +__all__ = ("MulticastSender", "MulticastReceiver", "mcast_sender", + "mcast_receiver", "SocketTimeout") # 224.0.0.0 through 224.0.0.255 is reserved administrative tasks -MC_GROUP = os.environ.get('PYTROLL_MC_GROUP', '225.0.0.212') +MC_GROUP = os.environ.get("PYTROLL_MC_GROUP", "225.0.0.212") # local network multicast (<32) -TTL_LOCALNET = int(os.environ.get('PYTROLL_MC_TTL', 31)) +TTL_LOCALNET = int(os.environ.get("PYTROLL_MC_TTL", 31)) logger = logging.getLogger(__name__) @@ -63,7 +76,7 @@ def __init__(self, port, mcgroup=MC_GROUP): self.port = port self.group = mcgroup self.socket, self.group = mcast_sender(mcgroup) - logger.debug('Started multicast group %s', mcgroup) + logger.debug("Started multicast group %s", mcgroup) def __call__(self, data): self.socket.sendto(data.encode(), (self.group, self.port)) @@ -76,20 +89,19 @@ def close(self): def mcast_sender(mcgroup=MC_GROUP): - """Non-object interface for sending multicast messages. - """ + """Non-object interface for sending multicast messages.""" sock = socket(AF_INET, SOCK_DGRAM) try: sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) if _is_broadcast_group(mcgroup): - group = '' + group = "" sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1) elif((int(mcgroup.split(".")[0]) > 239) or (int(mcgroup.split(".")[0]) < 224)): raise IOError("Invalid multicast address.") else: group = mcgroup - ttl = struct.pack('b', TTL_LOCALNET) # Time-to-live + ttl = struct.pack("b", TTL_LOCALNET) # Time-to-live sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl) except Exception: sock.close() @@ -113,8 +125,7 @@ def __init__(self, port, mcgroup=MC_GROUP): self.socket, self.group = mcast_receiver(port, mcgroup) def settimeout(self, tout=None): - """A timeout will throw a 'socket.timeout'. - """ + """A timeout will throw a 'socket.timeout'.""" self.socket.settimeout(tout) return self @@ -124,16 +135,14 @@ def __call__(self): def close(self): """Close the receiver.""" - self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack('ii', 1, 1)) + self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 1)) self.socket.close() # Allow non-object interface def mcast_receiver(port, mcgroup=MC_GROUP): - """Open a UDP socket, bind it to a port and select a multicast group. - """ - + """Open a UDP socket, bind it to a port and select a multicast group.""" if _is_broadcast_group(mcgroup): group = None else: @@ -151,7 +160,7 @@ def mcast_receiver(port, mcgroup=MC_GROUP): sock.setsockopt(SOL_IP, IP_MULTICAST_LOOP, 1) # default # Bind it to the port - sock.bind(('', port)) + sock.bind(("", port)) # Look up multicast group address in name server # (doesn't hurt if it is already in ddd.ddd.ddd.ddd format) @@ -166,7 +175,7 @@ def mcast_receiver(port, mcgroup=MC_GROUP): # Construct struct mreq from grpaddr and ifaddr ifaddr = INADDR_ANY - mreq = struct.pack('!LL', grpaddr, ifaddr) + mreq = struct.pack("!LL", grpaddr, ifaddr) # Add group membership sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq) @@ -174,7 +183,7 @@ def mcast_receiver(port, mcgroup=MC_GROUP): sock.close() raise - return sock, group or '' + return sock, group or "" # ----------------------------------------------------------------------------- # @@ -184,8 +193,7 @@ def mcast_receiver(port, mcgroup=MC_GROUP): def _is_broadcast_group(group): - """Check if *group* is a valid multicasting group. - """ - if not group or gethostbyname(group) in ('0.0.0.0', '255.255.255.255'): + """Check if *group* is a valid multicasting group.""" + if not group or gethostbyname(group) in ("0.0.0.0", "255.255.255.255"): return True return False diff --git a/posttroll/listener.py b/posttroll/listener.py index 608d134..e89c486 100644 --- a/posttroll/listener.py +++ b/posttroll/listener.py @@ -23,11 +23,12 @@ """Listener module.""" -from posttroll.subscriber import create_subscriber_from_dict_config +import logging +import time from queue import Queue from threading import Thread -import time -import logging + +from posttroll.subscriber import create_subscriber_from_dict_config class ListenerContainer: @@ -106,11 +107,11 @@ def create_subscriber(self): def _get_subscriber_config(self): config = { - 'services': self.services, - 'topics': self.topics, - 'addr_listener': True, - 'addresses': self.addresses, - 'nameserver': self.nameserver, + "services": self.services, + "topics": self.topics, + "addr_listener": True, + "addresses": self.addresses, + "nameserver": self.nameserver, } return config diff --git a/posttroll/logger.py b/posttroll/logger.py index 22db60f..e4755b9 100644 --- a/posttroll/logger.py +++ b/posttroll/logger.py @@ -20,28 +20,25 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Logger for pytroll system. -""" +"""Logger for pytroll system.""" # TODO: remove old hanging subscriptions -from posttroll.subscriber import Subscribe -from posttroll.publisher import NoisyPublisher -from posttroll.message import Message -from threading import Thread - import copy import logging import logging.handlers +from threading import Thread + +from posttroll.message import Message +from posttroll.publisher import NoisyPublisher +from posttroll.subscriber import Subscribe LOGGER = logging.getLogger(__name__) class PytrollFormatter(logging.Formatter): - - """Formats a pytroll message inside a log record. - """ + """Formats a pytroll message inside a log record.""" def __init__(self, fmt, datefmt): logging.Formatter.__init__(self, fmt, datefmt) @@ -54,9 +51,7 @@ def format(self, record): class PytrollHandler(logging.Handler): - - """Sends the record through a pytroll publisher. - """ + """Sends the record through a pytroll publisher.""" def __init__(self, name, port=0): logging.Handler.__init__(self) @@ -75,11 +70,11 @@ def close(self): BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) COLORS = { - 'WARNING': YELLOW, - 'INFO': GREEN, - 'DEBUG': BLUE, - 'CRITICAL': MAGENTA, - 'ERROR': RED + "WARNING": YELLOW, + "INFO": GREEN, + "DEBUG": BLUE, + "CRITICAL": MAGENTA, + "ERROR": RED } COLOR_SEQ = "\033[1;%dm" @@ -87,9 +82,7 @@ def close(self): class ColoredFormatter(logging.Formatter): - - """Adds a color for the levelname. - """ + """Adds a color for the levelname.""" def __init__(self, msg, use_color=True): logging.Formatter.__init__(self, msg) @@ -110,7 +103,6 @@ def format(self, record): class Logger(object): - """The logging machine. Contains a thread listening to incomming messages, and a thread logging. @@ -122,13 +114,11 @@ def __init__(self, nameserver_address="localhost", nameserver_port=16543): self.loop = True def start(self): - """Starts the logging. - """ + """Starts the logging.""" self.log_thread.start() def log(self): - """Log stuff. - """ + """Log stuff.""" with Subscribe(services=[""], addr_listener=True) as sub: for msg in sub.recv(1): if msg: @@ -156,14 +146,12 @@ def log(self): break def stop(self): - """Stop the machine. - """ + """Stop the machine.""" self.loop = False def run(): - """Main function - """ + """Main function.""" import argparse global LOGGER @@ -212,5 +200,5 @@ def run(): print("Thanks for using pytroll/logger. " "See you soon on www.pytroll.org!") -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/posttroll/message.py b/posttroll/message.py index 9ff6958..b56eba2 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -24,7 +24,7 @@ # pytroll. If not, see . """A Message goes like: - [mime-type data] + [mime-type data]. :: @@ -39,7 +39,7 @@ import re from datetime import datetime -import _strptime + try: import json except ImportError: @@ -47,14 +47,12 @@ from posttroll import strp_isoformat -_MAGICK = 'pytroll:/' -_VERSION = 'v1.01' +_MAGICK = "pytroll:/" +_VERSION = "v1.01" class MessageError(Exception): - - """This modules exceptions. - """ + """This modules exceptions.""" pass # ----------------------------------------------------------------------------- @@ -65,26 +63,22 @@ class MessageError(Exception): def is_valid_subject(obj): - """Currently we only check for empty strings. - """ + """Currently we only check for empty strings.""" return isinstance(obj, str) and bool(obj) def is_valid_type(obj): - """Currently we only check for empty strings. - """ + """Currently we only check for empty strings.""" return isinstance(obj, str) and bool(obj) def is_valid_sender(obj): - """Currently we only check for empty strings. - """ + """Currently we only check for empty strings.""" return isinstance(obj, str) and bool(obj) def is_valid_data(obj): - """Check if data is JSON serializable. - """ + """Check if data is JSON serializable.""" if obj: try: tmp = json.dumps(obj, default=datetime_encoder) @@ -101,7 +95,6 @@ def is_valid_data(obj): class Message(object): - """A Message. - Has to be initialized with a *rawstr* (encoded message to decode) OR @@ -112,7 +105,7 @@ class Message(object): - It will make a Message pickleable. """ - def __init__(self, subject='', atype='', data='', binary=False, rawstr=None): + def __init__(self, subject="", atype="", data="", binary=False, rawstr=None): """Initialize a Message from a subject, type and data... or from a raw string. """ @@ -120,11 +113,11 @@ def __init__(self, subject='', atype='', data='', binary=False, rawstr=None): self.__dict__ = _decode(rawstr) else: try: - self.subject = subject.decode('utf-8') + self.subject = subject.decode("utf-8") except AttributeError: self.subject = subject try: - self.type = atype.decode('utf-8') + self.type = atype.decode("utf-8") except AttributeError: self.type = atype self.type = atype @@ -137,38 +130,33 @@ def __init__(self, subject='', atype='', data='', binary=False, rawstr=None): @property def user(self): - """Try to return a user from a sender. - """ + """Try to return a user from a sender.""" try: - return self.sender[:self.sender.index('@')] + return self.sender[:self.sender.index("@")] except ValueError: - return '' + return "" @property def host(self): - """Try to return a host from a sender. - """ + """Try to return a host from a sender.""" try: - return self.sender[self.sender.index('@') + 1:] + return self.sender[self.sender.index("@") + 1:] except ValueError: - return '' + return "" @property def head(self): - """Return header of a message (a message without the data part). - """ + """Return header of a message (a message without the data part).""" self._validate() return _encode(self, head=True) @staticmethod def decode(rawstr): - """Decode a raw string into a Message. - """ + """Decode a raw string into a Message.""" return Message(rawstr=rawstr) def encode(self): - """Encode a Message to a raw string. - """ + """Encode a Message to a raw string.""" self._validate() return _encode(self, binary=self.binary) @@ -180,14 +168,13 @@ def __unicode__(self): def __str__(self): try: - return unicode(self).encode('utf-8') + return unicode(self).encode("utf-8") except NameError: return self.encode() def _validate(self): - """Validate a messages attributes. - """ + """Validate a messages attributes.""" if not is_valid_subject(self.subject): raise MessageError("Invalid subject: '%s'" % self.subject) if not is_valid_type(self.type): @@ -216,14 +203,12 @@ def __setstate__(self, state): def _is_valid_version(version): - """Check version. - """ + """Check version.""" return version == _VERSION def datetime_decoder(dct): - """Decode datetimes to python objects. - """ + """Decode datetimes to python objects.""" if isinstance(dct, list): pairs = enumerate(dct) elif isinstance(dct, dict): @@ -245,18 +230,17 @@ def datetime_decoder(dct): def _decode(rawstr): - """Convert a raw string to a Message. - """ + """Convert a raw string to a Message.""" # Check for the magick word. try: - rawstr = rawstr.decode('utf-8') + rawstr = rawstr.decode("utf-8") except (AttributeError, UnicodeEncodeError): pass except (UnicodeDecodeError): try: - rawstr = rawstr.decode('iso-8859-1') + rawstr = rawstr.decode("iso-8859-1") except (UnicodeDecodeError): - rawstr = rawstr.decode('utf-8', 'ignore') + rawstr = rawstr.decode("utf-8", "ignore") if not rawstr.startswith(_MAGICK): raise MessageError("This is not a '%s' message (wrong magick word)" % _MAGICK) @@ -272,11 +256,11 @@ def _decode(rawstr): raise MessageError("Invalid Message version: '%s'" % str(version)) # Start to build message - msg = dict((('subject', raw[0].strip()), - ('type', raw[1].strip()), - ('sender', raw[2].strip()), - ('time', strp_isoformat(raw[3].strip())), - ('version', version))) + msg = dict((("subject", raw[0].strip()), + ("type", raw[1].strip()), + ("sender", raw[2].strip()), + ("time", strp_isoformat(raw[3].strip())), + ("version", version))) # Data part try: @@ -286,20 +270,20 @@ def _decode(rawstr): mimetype = None if mimetype is None: - msg['data'] = '' - msg['binary'] = False - elif mimetype == 'application/json': + msg["data"] = "" + msg["binary"] = False + elif mimetype == "application/json": try: - msg['data'] = json.loads(raw[6], object_hook=datetime_decoder) - msg['binary'] = False + msg["data"] = json.loads(raw[6], object_hook=datetime_decoder) + msg["binary"] = False except ValueError: raise MessageError("JSON decode failed on '%s ...'" % raw[6][:36]) - elif mimetype == 'text/ascii': - msg['data'] = str(data) - msg['binary'] = False - elif mimetype == 'binary/octet-stream': - msg['data'] = data - msg['binary'] = True + elif mimetype == "text/ascii": + msg["data"] = str(data) + msg["binary"] = False + elif mimetype == "binary/octet-stream": + msg["data"] = data + msg["binary"] = True else: raise MessageError("Unknown mime-type '%s'" % mimetype) @@ -307,8 +291,7 @@ def _decode(rawstr): def datetime_encoder(obj): - """Encodes datetimes into iso format. - """ + """Encodes datetimes into iso format.""" try: return obj.isoformat() except AttributeError: @@ -316,22 +299,21 @@ def datetime_encoder(obj): def _encode(msg, head=False, binary=False): - """Convert a Message to a raw string. - """ + """Convert a Message to a raw string.""" rawstr = str(_MAGICK) + u"{0:s} {1:s} {2:s} {3:s} {4:s}".format( msg.subject, msg.type, msg.sender, msg.time.isoformat(), msg.version) if not head and msg.data: if not binary and isinstance(msg.data, str): - return (rawstr + ' ' + - 'text/ascii' + ' ' + msg.data) + return (rawstr + " " + + "text/ascii" + " " + msg.data) elif not binary: - return (rawstr + ' ' + - 'application/json' + ' ' + + return (rawstr + " " + + "application/json" + " " + json.dumps(msg.data, default=datetime_encoder)) else: - return (rawstr + ' ' + - 'binary/octet-stream' + ' ' + msg.data) + return (rawstr + " " + + "binary/octet-stream" + " " + msg.data) return rawstr # ----------------------------------------------------------------------------- diff --git a/posttroll/ns.py b/posttroll/ns.py index ae49bc7..e00e313 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -48,7 +48,6 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): """Get the address of the publisher for a given list of publisher *names* from the nameserver on *nameserver* (localhost by default). """ - addrs = [] if names is None: names = ["", ] diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index a34fad1..ca4d993 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -61,7 +61,7 @@ class Subscriber: """ - def __init__(self, addresses, topics='', message_filter=None, translate=False): + def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") @@ -143,10 +143,10 @@ def _magickfy_topics(topics): ts_ = [] for t__ in topics: if not t__.startswith(_MAGICK): - if t__ and t__[0] == '/': + if t__ and t__[0] == "/": t__ = _MAGICK + t__ else: - t__ = _MAGICK + '/' + t__ + t__ = _MAGICK + "/" + t__ ts_.append(t__) return ts_ @@ -257,14 +257,14 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False, message_filter=None): """Initialize the class.""" settings = { - 'services': services, - 'topics': topics, - 'message_filter': message_filter, - 'translate': translate, - 'addr_listener': addr_listener, - 'addresses': addresses, - 'timeout': timeout, - 'nameserver': nameserver, + "services": services, + "topics": topics, + "message_filter": message_filter, + "translate": translate, + "addr_listener": addr_listener, + "addresses": addresses, + "timeout": timeout, + "nameserver": nameserver, } self.subscriber = create_subscriber_from_dict_config(settings) @@ -302,9 +302,9 @@ def __init__(self, subscriber, services="", nameserver="localhost"): def handle_msg(self, msg): """Handle the message *msg*.""" addr_ = msg.data["URI"] - status = msg.data.get('status', True) + status = msg.data.get("status", True) if status: - msg_services = msg.data.get('service') + msg_services = msg.data.get("service") for service in self.services: if not service or service in msg_services: LOGGER.debug("Adding address %s %s", str(addr_), @@ -333,28 +333,28 @@ def create_subscriber_from_dict_config(settings): :class:`~posttroll.subscriber.Subscriber` and :class:`~posttroll.subscriber.NSSubscriber`. """ - if settings.get('addresses') and settings.get('nameserver') is False: + if settings.get("addresses") and settings.get("nameserver") is False: return _get_subscriber_instance(settings) return _get_nssubscriber_instance(settings).start() def _get_subscriber_instance(settings): - addresses = settings['addresses'] - topics = settings.get('topics', '') - message_filter = settings.get('message_filter', None) - translate = settings.get('translate', False) + addresses = settings["addresses"] + topics = settings.get("topics", "") + message_filter = settings.get("message_filter", None) + translate = settings.get("translate", False) return Subscriber(addresses, topics=topics, message_filter=message_filter, translate=translate) def _get_nssubscriber_instance(settings): - services = settings.get('services', '') - topics = settings.get('topics', _MAGICK) - addr_listener = settings.get('addr_listener', False) - addresses = settings.get('addresses', None) - timeout = settings.get('timeout', 10) - translate = settings.get('translate', False) - nameserver = settings.get('nameserver', 'localhost') or 'localhost' + services = settings.get("services", "") + topics = settings.get("topics", _MAGICK) + addr_listener = settings.get("addr_listener", False) + addresses = settings.get("addresses", None) + timeout = settings.get("timeout", 10) + translate = settings.get("translate", False) + nameserver = settings.get("nameserver", "localhost") or "localhost" return NSSubscriber( services=services, diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index 45b19b6..f99c202 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -20,45 +20,42 @@ # You should have received a copy of the GNU General Public License along with # pytroll. If not, see . -import unittest import random -from socket import SOL_SOCKET, SO_BROADCAST, error +import unittest +from socket import SO_BROADCAST, SOL_SOCKET, error from posttroll import bbmcast class TestBB(unittest.TestCase): - - """Test class. - """ + """Test class.""" def test_mcast_sender(self): - """Unit test for mcast_sender. - """ + """Unit test for mcast_sender.""" mcgroup = (str(random.randint(224, 239)) + "." + str(random.randint(0, 255)) + "." + str(random.randint(0, 255)) + "." + str(random.randint(0, 255))) socket, group = bbmcast.mcast_sender(mcgroup) if mcgroup in ("0.0.0.0", "255.255.255.255"): - self.assertEqual(group, "") - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 else: - self.assertEqual(group, mcgroup) - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 0) + assert group == mcgroup + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 0 socket.close() mcgroup = "0.0.0.0" socket, group = bbmcast.mcast_sender(mcgroup) - self.assertEqual(group, "") - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() mcgroup = "255.255.255.255" socket, group = bbmcast.mcast_sender(mcgroup) - self.assertEqual(group, "") - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() mcgroup = (str(random.randint(0, 223)) + "." + @@ -74,17 +71,16 @@ def test_mcast_sender(self): self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup) def test_mcast_receiver(self): - """Unit test for mcast_receiver. - """ + """Unit test for mcast_receiver.""" mcport = random.randint(1025, 65535) mcgroup = "0.0.0.0" socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - self.assertEqual(group, "") + assert group == "" socket.close() mcgroup = "255.255.255.255" socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - self.assertEqual(group, "") + assert group == "" socket.close() # Valid multicast range is 224.0.0.0 to 239.255.255.255 @@ -93,7 +89,7 @@ def test_mcast_receiver(self): str(random.randint(0, 255)) + "." + str(random.randint(0, 255))) socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - self.assertEqual(group, mcgroup) + assert group == mcgroup socket.close() mcgroup = (str(random.randint(0, 223)) + "." + diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py index dbe088c..54e9063 100644 --- a/posttroll/tests/test_message.py +++ b/posttroll/tests/test_message.py @@ -21,76 +21,60 @@ # You should have received a copy of the GNU General Public License along with # pytroll. If not, see . -"""Test module for the message class. -""" +"""Test module for the message class.""" +import copy import os import sys import unittest -import copy from datetime import datetime -from posttroll.message import Message, _MAGICK +from posttroll.message import _MAGICK, Message +HOME = os.path.dirname(__file__) or "." +sys.path = [os.path.abspath(HOME + "/../.."), ] + sys.path -HOME = os.path.dirname(__file__) or '.' -sys.path = [os.path.abspath(HOME + '/../..'), ] + sys.path - -DATADIR = HOME + '/data' -SOME_METADATA = {'timestamp': datetime(2010, 12, 3, 16, 28, 39), - 'satellite': 'metop2', - 'uri': 'file://data/my/path/to/hrpt/files/myfile', - 'orbit': 1222, - 'format': 'hrpt', - 'afloat': 1.2345} +DATADIR = HOME + "/data" +SOME_METADATA = {"timestamp": datetime(2010, 12, 3, 16, 28, 39), + "satellite": "metop2", + "uri": "file://data/my/path/to/hrpt/files/myfile", + "orbit": 1222, + "format": "hrpt", + "afloat": 1.2345} class Test(unittest.TestCase): - - """Test class. - """ + """Test class.""" def test_encode_decode(self): - """Test the encoding/decoding of the message class. - """ - msg1 = Message('/test/whatup/doc', 'info', data='not much to say') + """Test the encoding/decoding of the message class.""" + msg1 = Message("/test/whatup/doc", "info", data="not much to say") - sender = '%s@%s' % (msg1.user, msg1.host) - self.assertTrue(sender == msg1.sender, - msg='Messaging, decoding user, host from sender failed') + sender = "%s@%s" % (msg1.user, msg1.host) + assert sender == msg1.sender, "Messaging, decoding user, host from sender failed" msg2 = Message.decode(msg1.encode()) - self.assertTrue(str(msg2) == str(msg1), - msg='Messaging, encoding, decoding failed') + assert str(msg2) == str(msg1), "Messaging, encoding, decoding failed" def test_decode(self): - """Test the decoding of a message. - """ + """Test the decoding of a message.""" rawstr = (_MAGICK + - r'/test/1/2/3 info ras@hawaii 2008-04-11T22:13:22.123000 v1.01' + + r"/test/1/2/3 info ras@hawaii 2008-04-11T22:13:22.123000 v1.01" + r' text/ascii "what' + r"'" + r's up doc"') msg = Message.decode(rawstr) - self.assertTrue(str(msg) == rawstr, - msg='Messaging, decoding of message failed') + assert str(msg) == rawstr, "Messaging, decoding of message failed" def test_encode(self): - """Test the encoding of a message. - """ - subject = '/test/whatup/doc' + """Test the encoding of a message.""" + subject = "/test/whatup/doc" atype = "info" - data = 'not much to say' + data = "not much to say" msg1 = Message(subject, atype, data=data) - sender = '%s@%s' % (msg1.user, msg1.host) - self.assertEqual(_MAGICK + - subject + " " + - atype + " " + - sender + " " + - str(msg1.time.isoformat()) + " " + - msg1.version + " " + - 'text/ascii' + " " + - data, - msg1.encode()) + sender = "%s@%s" % (msg1.user, msg1.host) + full_message = (_MAGICK + subject + " " + atype + " " + sender + " " + + str(msg1.time.isoformat()) + " " + msg1.version + " " + "text/ascii" + " " + data) + assert full_message == msg1.encode() def test_unicode(self): """Test handling of unicode.""" @@ -98,74 +82,69 @@ def test_unicode(self): msg = ('pytroll://PPS-monitorplot/3/norrköping/utv/polar/direct_readout/ file ' 'safusr.u@lxserv1096.smhi.se 2018-11-16T12:19:29.934025 v1.01 application/json' ' {"start_time": "2018-11-16T12:02:43.700000"}') - self.assertEqual(msg, str(Message(rawstr=msg))) + assert msg == str(Message(rawstr=msg)) except UnicodeDecodeError: - self.fail('Unexpected unicode decoding error') + self.fail("Unexpected unicode decoding error") try: msg = (u'pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN 2019-01-07T12:52:19.872171' r' v1.01 application/json {"station": "norrk\u00f6ping"}') try: - self.assertEqual(msg, str(Message(rawstr=msg)).decode('utf-8')) + assert msg == str(Message(rawstr=msg)).decode("utf-8") except AttributeError: - self.assertEqual(msg, str(Message(rawstr=msg))) + assert msg == str(Message(rawstr=msg)) except UnicodeDecodeError: - self.fail('Unexpected unicode decoding error') + self.fail("Unexpected unicode decoding error") def test_iso(self): """Test handling of iso-8859-1.""" - msg = 'pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN 2019-01-07T12:52:19.872171 v1.01 application/json {"station": "norrköping"}' + msg = 'pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN 2019-01-07T12:52:19.872171 v1.01 application/json {"station": "norrköping"}' # noqa: E501 try: - iso_msg = msg.decode('utf-8').encode('iso-8859-1') + iso_msg = msg.decode("utf-8").encode("iso-8859-1") except AttributeError: - iso_msg = msg.encode('iso-8859-1') + iso_msg = msg.encode("iso-8859-1") try: Message(rawstr=iso_msg) except UnicodeDecodeError: - self.fail('Unexpected iso decoding error') + self.fail("Unexpected iso decoding error") def test_pickle(self): - """Test pickling. - """ + """Test pickling.""" import pickle - msg1 = Message('/test/whatup/doc', 'info', data='not much to say') + msg1 = Message("/test/whatup/doc", "info", data="not much to say") try: - fp_ = open("pickle.message", 'wb') + fp_ = open("pickle.message", "wb") pickle.dump(msg1, fp_) fp_.close() - fp_ = open("pickle.message", 'rb') + fp_ = open("pickle.message", "rb") msg2 = pickle.load(fp_) fp_.close() - self.assertTrue(str(msg1) == str(msg2), - msg='Messaging, pickle failed') + assert str(msg1) == str(msg2), "Messaging, pickle failed" finally: try: - os.remove('pickle.message') + os.remove("pickle.message") except OSError: pass def test_metadata(self): - """Test metadata encoding/decoding. - """ + """Test metadata encoding/decoding.""" metadata = copy.copy(SOME_METADATA) - msg = Message.decode(Message('/sat/polar/smb/level1', 'file', + msg = Message.decode(Message("/sat/polar/smb/level1", "file", data=metadata).encode()) - self.assertTrue(msg.data == metadata, - msg='Messaging, metadata decoding / encoding failed') + assert msg.data == metadata, "Messaging, metadata decoding / encoding failed" def test_serialization(self): - """Test json serialization. - """ - compare_file = '/message_metadata.dumps' + """Test json serialization.""" + compare_file = "/message_metadata.dumps" try: import json except ImportError: import simplejson as json - compare_file += '.simplejson' + compare_file += ".simplejson" metadata = copy.copy(SOME_METADATA) - metadata['timestamp'] = metadata['timestamp'].isoformat() + metadata["timestamp"] = metadata["timestamp"].isoformat() fp_ = open(DATADIR + compare_file) dump = fp_.read() fp_.close() @@ -173,21 +152,20 @@ def test_serialization(self): msg = json.loads(dump) for key, val in msg.items(): - self.assertEqual(val, metadata.get(key)) + assert val == metadata.get(key) msg = json.loads(local_dump) for key, val in msg.items(): - self.assertEqual(val, metadata.get(key)) + assert val == metadata.get(key) def suite(): - """The suite for test_message. - """ + """The suite for test_message.""" loader = unittest.TestLoader() mysuite = unittest.TestSuite() mysuite.addTest(loader.loadTestsFromTestCase(Test)) return mysuite -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 607aeb2..1445de2 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -465,12 +465,12 @@ def test_publisher_is_selected(self): @mock.patch("posttroll.publisher.Publisher") def test_publisher_all_arguments(self, Publisher): """Test that only valid arguments are passed to Publisher.""" - settings = {"port": 12345, 'nameservers': False, 'name': 'foo', - 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar'} + settings = {"port": 12345, "nameservers": False, "name": "foo", + "min_port": 40000, "max_port": 41000, "invalid_arg": "bar"} _ = create_publisher_from_dict_config(settings) - _check_valid_settings_in_call(settings, Publisher, ignore=['port', 'nameservers']) + _check_valid_settings_in_call(settings, Publisher, ignore=["port", "nameservers"]) assert Publisher.call_args[0][0].startswith("tcp://*:") - assert Publisher.call_args[0][0].endswith(str(settings['port'])) + assert Publisher.call_args[0][0].endswith(str(settings["port"])) def test_no_name_raises_keyerror(self): """Trying to create a NoisyPublisher without a given name will raise KeyError.""" @@ -481,7 +481,7 @@ def test_noisypublisher_is_selected_only_name(self): """Test that NoisyPublisher is selected as publisher class.""" from posttroll.publisher import NoisyPublisher - settings = {'name': 'publisher_name'} + settings = {"name": "publisher_name"} pub = create_publisher_from_dict_config(settings) assert isinstance(pub, NoisyPublisher) @@ -490,21 +490,21 @@ def test_noisypublisher_is_selected_name_and_port(self): """Test that NoisyPublisher is selected as publisher class.""" from posttroll.publisher import NoisyPublisher - settings = {'name': 'publisher_name', 'port': 40000} + settings = {"name": "publisher_name", "port": 40000} pub = create_publisher_from_dict_config(settings) assert isinstance(pub, NoisyPublisher) - @mock.patch('posttroll.publisher.NoisyPublisher') + @mock.patch("posttroll.publisher.NoisyPublisher") def test_noisypublisher_all_arguments(self, NoisyPublisher): """Test that only valid arguments are passed to NoisyPublisher.""" from posttroll.publisher import create_publisher_from_dict_config - settings = {'port': 12345, 'nameservers': ['foo'], 'name': 'foo', - 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar', - 'aliases': ['alias1', 'alias2'], 'broadcast_interval': 42} + settings = {"port": 12345, "nameservers": ["foo"], "name": "foo", + "min_port": 40000, "max_port": 41000, "invalid_arg": "bar", + "aliases": ["alias1", "alias2"], "broadcast_interval": 42} _ = create_publisher_from_dict_config(settings) - _check_valid_settings_in_call(settings, NoisyPublisher, ignore=['name']) + _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"]) assert NoisyPublisher.call_args[0][0] == settings["name"] def test_publish_is_not_noisy(self): @@ -532,23 +532,23 @@ def test_publish_is_noisy_with_nameservers(self): """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" from posttroll.publisher import NoisyPublisher, Publish - with Publish("service_name", nameservers=['a', 'b']) as pub: + with Publish("service_name", nameservers=["a", "b"]) as pub: assert isinstance(pub, NoisyPublisher) def _check_valid_settings_in_call(settings, pub_class, ignore=None): ignore = ignore or [] for key in settings: - if key == 'invalid_arg': - assert 'invalid_arg' not in pub_class.call_args[1] + if key == "invalid_arg": + assert "invalid_arg" not in pub_class.call_args[1] continue if key in ignore: continue assert pub_class.call_args[1][key] == settings[key] -@mock.patch('posttroll.subscriber.Subscriber') -@mock.patch('posttroll.subscriber.NSSubscriber') +@mock.patch("posttroll.subscriber.Subscriber") +@mock.patch("posttroll.subscriber.NSSubscriber") def test_dict_config_minimal(NSSubscriber, Subscriber): """Test that without any settings NSSubscriber is created.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -559,31 +559,31 @@ def test_dict_config_minimal(NSSubscriber, Subscriber): Subscriber.assert_not_called() -@mock.patch('posttroll.subscriber.Subscriber') -@mock.patch('posttroll.subscriber.NSSubscriber') +@mock.patch("posttroll.subscriber.Subscriber") +@mock.patch("posttroll.subscriber.NSSubscriber") def test_dict_config_nameserver_false(NSSubscriber, Subscriber): """Test that NSSubscriber is created with 'localhost' nameserver when no addresses are given.""" from posttroll.subscriber import create_subscriber_from_dict_config - subscriber = create_subscriber_from_dict_config({'nameserver': False}) + subscriber = create_subscriber_from_dict_config({"nameserver": False}) NSSubscriber.assert_called_once() assert subscriber == NSSubscriber().start() Subscriber.assert_not_called() -@mock.patch('posttroll.subscriber.Subscriber') -@mock.patch('posttroll.subscriber.NSSubscriber') +@mock.patch("posttroll.subscriber.Subscriber") +@mock.patch("posttroll.subscriber.NSSubscriber") def test_dict_config_subscriber(NSSubscriber, Subscriber): """Test that Subscriber is created when nameserver is False and addresses are given.""" from posttroll.subscriber import create_subscriber_from_dict_config - subscriber = create_subscriber_from_dict_config({'nameserver': False, 'addresses': ['addr1']}) + subscriber = create_subscriber_from_dict_config({"nameserver": False, "addresses": ["addr1"]}) assert subscriber == Subscriber.return_value Subscriber.assert_called_once() NSSubscriber.assert_not_called() -@mock.patch('posttroll.subscriber.NSSubscriber.start') +@mock.patch("posttroll.subscriber.NSSubscriber.start") def test_dict_config_full_nssubscriber(NSSubscriber_start): """Test that all NSSubscriber options are passed.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -603,7 +603,7 @@ def test_dict_config_full_nssubscriber(NSSubscriber_start): NSSubscriber_start.assert_called_once() -@mock.patch('posttroll.subscriber.UnsecureZMQSubscriber.update') +@mock.patch("posttroll.subscriber.UnsecureZMQSubscriber.update") def test_dict_config_full_subscriber(Subscriber_update): """Test that all Subscriber options are passed.""" from posttroll.subscriber import create_subscriber_from_dict_config diff --git a/setup.py b/setup.py index 750dca2..9809c2e 100644 --- a/setup.py +++ b/setup.py @@ -23,35 +23,35 @@ # this program. If not, see . from setuptools import setup -import versioneer +import versioneer -requirements = ['pyzmq', 'netifaces', "donfig"] +requirements = ["pyzmq", "netifaces", "donfig"] setup(name="posttroll", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), - description='Messaging system for pytroll', - author='The pytroll team', - author_email='pytroll@googlegroups.com', + description="Messaging system for pytroll", + author="The pytroll team", + author_email="pytroll@googlegroups.com", url="http://github.com/pytroll/posttroll", - packages=['posttroll'], + packages=["posttroll"], entry_points={ - 'console_scripts': ['pytroll-logger = posttroll.logger:run', ]}, - scripts=['bin/nameserver'], + "console_scripts": ["pytroll-logger = posttroll.logger:run", ]}, + scripts=["bin/nameserver"], zip_safe=False, license="GPLv3", install_requires=requirements, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', - 'Programming Language :: Python', - 'Operating System :: OS Independent', - 'Intended Audience :: Science/Research', - 'Topic :: Scientific/Engineering', - 'Topic :: Communications' + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Communications" ], - python_requires='>=3.10', - test_suite='posttroll.tests.suite', + python_requires=">=3.10", + test_suite="posttroll.tests.suite", ) From 3af571930be5b869134c528468b67777dfa9311c Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 6 Dec 2023 19:44:57 +0100 Subject: [PATCH 10/45] Remove import of python 2 library --- posttroll/tests/test_message.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py index 54e9063..c6ab5e4 100644 --- a/posttroll/tests/test_message.py +++ b/posttroll/tests/test_message.py @@ -138,11 +138,7 @@ def test_metadata(self): def test_serialization(self): """Test json serialization.""" compare_file = "/message_metadata.dumps" - try: - import json - except ImportError: - import simplejson as json - compare_file += ".simplejson" + import json metadata = copy.copy(SOME_METADATA) metadata["timestamp"] = metadata["timestamp"].isoformat() fp_ = open(DATADIR + compare_file) From d90341b582f9dc7579b6cefb571af66701a0f339 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 7 Dec 2023 09:48:43 +0100 Subject: [PATCH 11/45] Allow passing interface to multicast on --- .github/workflows/ci.yaml | 2 +- posttroll/bbmcast.py | 54 ++++++--- posttroll/message_broadcaster.py | 6 +- posttroll/tests/test_bbmcast.py | 195 +++++++++++++++++++------------ setup.py | 2 + 5 files changed, 165 insertions(+), 94 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2fbaede..045c467 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install -U pytest pytest-cov pyzmq netifaces donfig + pip install -U pytest pytest-cov pyzmq netifaces donfig pytest-reraise - name: Install posttroll run: | pip install --no-deps -e . diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index bb91d65..ebfcd53 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -31,10 +31,12 @@ import logging import os import struct +import warnings from socket import ( AF_INET, INADDR_ANY, IP_ADD_MEMBERSHIP, + IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IPPROTO_IP, @@ -45,15 +47,20 @@ SOL_IP, SOL_SOCKET, gethostbyname, + inet_aton, socket, timeout, ) +from posttroll import config + __all__ = ("MulticastSender", "MulticastReceiver", "mcast_sender", "mcast_receiver", "SocketTimeout") # 224.0.0.0 through 224.0.0.255 is reserved administrative tasks -MC_GROUP = os.environ.get("PYTROLL_MC_GROUP", "225.0.0.212") +DEFAULT_MC_GROUP = "225.0.0.212" + +MULTICAST_INTERFACE = config.get("multicast_interface", "0.0.0.0") # local network multicast (<32) TTL_LOCALNET = int(os.environ.get("PYTROLL_MC_TTL", 31)) @@ -69,14 +76,14 @@ # ----------------------------------------------------------------------------- -class MulticastSender(object): +class MulticastSender: """Multicast sender on *port* and *mcgroup*.""" - def __init__(self, port, mcgroup=MC_GROUP): + def __init__(self, port, mcgroup=None): self.port = port self.group = mcgroup self.socket, self.group = mcast_sender(mcgroup) - logger.debug("Started multicast group %s", mcgroup) + logger.debug("Started multicast group %s", self.group) def __call__(self, data): self.socket.sendto(data.encode(), (self.group, self.port)) @@ -88,8 +95,10 @@ def close(self): # Allow non-object interface -def mcast_sender(mcgroup=MC_GROUP): +def mcast_sender(mcgroup=None): """Non-object interface for sending multicast messages.""" + if mcgroup is None: + mcgroup = get_mc_group() sock = socket(AF_INET, SOCK_DGRAM) try: sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) @@ -98,16 +107,28 @@ def mcast_sender(mcgroup=MC_GROUP): sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1) elif((int(mcgroup.split(".")[0]) > 239) or (int(mcgroup.split(".")[0]) < 224)): - raise IOError("Invalid multicast address.") + raise IOError(f"Invalid multicast address {mcgroup}") else: group = mcgroup ttl = struct.pack("b", TTL_LOCALNET) # Time-to-live sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl) + if MULTICAST_INTERFACE != "0.0.0.0": + sock.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, inet_aton(MULTICAST_INTERFACE)) except Exception: sock.close() raise return sock, group +def get_mc_group(): + try: + mcgroup = os.environ["PYTROLL_MC_GROUP"] + warnings.warn("PYTROLL_MC_GROUP is pending deprecation, please use POSTTROLL_MC_GROUP instead.", + PendingDeprecationWarning) + except KeyError: + mcgroup = DEFAULT_MC_GROUP + mcgroup = config.get("mc_group", mcgroup) + return mcgroup + # ----------------------------------------------------------------------------- # # Receiver. @@ -119,7 +140,7 @@ class MulticastReceiver(object): """Multicast receiver on *port* for an *mcgroup*.""" BUFSIZE = 1024 - def __init__(self, port, mcgroup=MC_GROUP): + def __init__(self, port, mcgroup=None): # Note: a multicast receiver will also receive broadcast on same port. self.port = port self.socket, self.group = mcast_receiver(port, mcgroup) @@ -141,8 +162,10 @@ def close(self): # Allow non-object interface -def mcast_receiver(port, mcgroup=MC_GROUP): +def mcast_receiver(port, mcgroup=None): """Open a UDP socket, bind it to a port and select a multicast group.""" + if mcgroup is None: + mcgroup = get_mc_group() if _is_broadcast_group(mcgroup): group = None else: @@ -167,15 +190,14 @@ def mcast_receiver(port, mcgroup=MC_GROUP): if group: group = gethostbyname(group) - # Construct binary group address - bytes_ = [int(b) for b in group.split(".")] - grpaddr = 0 - for byte in bytes_: - grpaddr = (grpaddr << 8) | byte + # Construct struct mreq + if MULTICAST_INTERFACE == "0.0.0.0": + ifaddr = INADDR_ANY + mreq = struct.pack("=4sl", inet_aton(group), ifaddr) - # Construct struct mreq from grpaddr and ifaddr - ifaddr = INADDR_ANY - mreq = struct.pack("!LL", grpaddr, ifaddr) + else: + ifaddr = inet_aton(MULTICAST_INTERFACE) + mreq = struct.pack("=4s4s", inet_aton(group), ifaddr) # Add group membership sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq) diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 98ffdf8..6be9d29 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -26,7 +26,7 @@ import threading from posttroll import config, message -from posttroll.bbmcast import MC_GROUP, MulticastSender +from posttroll.bbmcast import MulticastSender __all__ = ("MessageBroadcaster", "AddressBroadcaster", "sendaddress") @@ -70,9 +70,7 @@ def __init__(self, msg, port, interval, designated_receivers=None): self._sender = DesignatedReceiversSender(port, designated_receivers) else: - # mcgroup = None or '' is broadcast - # mcgroup = MC_GROUP is default multicast group - self._sender = MulticastSender(port, mcgroup=MC_GROUP) + self._sender = MulticastSender(port) self._interval = interval self._message = msg diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index f99c202..ee7146b 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -20,86 +20,135 @@ # You should have received a copy of the GNU General Public License along with # pytroll. If not, see . +"""Test multicasting and broadcasting.""" + import random -import unittest from socket import SO_BROADCAST, SOL_SOCKET, error +from threading import Thread -from posttroll import bbmcast - - -class TestBB(unittest.TestCase): - """Test class.""" - - def test_mcast_sender(self): - """Unit test for mcast_sender.""" - mcgroup = (str(random.randint(224, 239)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - socket, group = bbmcast.mcast_sender(mcgroup) - if mcgroup in ("0.0.0.0", "255.255.255.255"): - assert group == "" - assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 - else: - assert group == mcgroup - assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 0 +import pytest - socket.close() +from posttroll import bbmcast - mcgroup = "0.0.0.0" - socket, group = bbmcast.mcast_sender(mcgroup) - assert group == "" - assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 - socket.close() - mcgroup = "255.255.255.255" - socket, group = bbmcast.mcast_sender(mcgroup) +def test_mcast_sender_works_with_valid_addresses(): + """Unit test for mcast_sender.""" + mcgroup = (str(random.randint(224, 239)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + socket, group = bbmcast.mcast_sender(mcgroup) + if mcgroup in ("0.0.0.0", "255.255.255.255"): assert group == "" assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 - socket.close() - - mcgroup = (str(random.randint(0, 223)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup) - - mcgroup = (str(random.randint(240, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup) - - def test_mcast_receiver(self): - """Unit test for mcast_receiver.""" - mcport = random.randint(1025, 65535) - mcgroup = "0.0.0.0" - socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - assert group == "" - socket.close() - - mcgroup = "255.255.255.255" - socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - assert group == "" - socket.close() - - # Valid multicast range is 224.0.0.0 to 239.255.255.255 - mcgroup = (str(random.randint(224, 239)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + else: assert group == mcgroup + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 0 + + socket.close() + + mcgroup = "0.0.0.0" + socket, group = bbmcast.mcast_sender(mcgroup) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 + socket.close() + + mcgroup = "255.255.255.255" + socket, group = bbmcast.mcast_sender(mcgroup) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 + socket.close() + + mcgroup = (str(random.randint(0, 223)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(OSError, match="Invalid multicast address .*"): + bbmcast.mcast_sender(mcgroup) + + mcgroup = (str(random.randint(240, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(OSError, match="Invalid multicast address .*"): + bbmcast.mcast_sender(mcgroup) + + +def test_mcast_receiver_works_with_valid_addresses(): + """Unit test for mcast_receiver.""" + mcport = random.randint(1025, 65535) + mcgroup = "0.0.0.0" + socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + assert group == "" + socket.close() + + mcgroup = "255.255.255.255" + socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + assert group == "" + socket.close() + + # Valid multicast range is 224.0.0.0 to 239.255.255.255 + mcgroup = (str(random.randint(224, 239)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + assert group == mcgroup + socket.close() + + mcgroup = (str(random.randint(0, 223)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(error, match=".*Invalid argument.*"): + bbmcast.mcast_receiver(mcport, mcgroup) + + mcgroup = (str(random.randint(240, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(error, match=".*Invalid argument.*"): + bbmcast.mcast_receiver(mcport, mcgroup) + + +def test_mcast_send_recv(reraise): + """Test sending and receiving a multicast message.""" + mcgroup = bbmcast.DEFAULT_MC_GROUP + mcport = 5555 + rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + + message = "Ho Ho Ho!" + + def check_message(sock, message): + data, sender = sock.recvfrom(1024) + with reraise: + assert data.decode() == message + + snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) + + thr = Thread(target=check_message, args=(rec_socket, message)) + thr.start() + + snd_socket.sendto(message.encode(), (mcgroup, mcport)) + + thr.join() + rec_socket.close() + snd_socket.close() + +def test_posttroll_mc_group_is_used(): + """Test that configured mc_group is used.""" + from posttroll import config + other_group = "226.0.0.13" + with config.set(mc_group=other_group): + socket, group = bbmcast.mcast_sender() socket.close() - - mcgroup = (str(random.randint(0, 223)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup) - - mcgroup = (str(random.randint(240, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup) + assert group == "226.0.0.13" + +def test_pytroll_mc_group_is_deprecated(monkeypatch): + """Test that PYTROLL_MC_GROUP is used but pending deprecation.""" + other_group = "226.0.0.13" + monkeypatch.setenv("PYTROLL_MC_GROUP", other_group) + with pytest.deprecated_call(): + socket, group = bbmcast.mcast_sender() + socket.close() + assert group == "226.0.0.13" diff --git a/setup.py b/setup.py index 9809c2e..ad154e8 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,8 @@ # You should have received a copy of the GNU General Public License along with # this program. If not, see . +"""Set up the package.""" + from setuptools import setup import versioneer From 0c8141c743ad2b87f89f7e4feb665c585028c7b1 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 7 Dec 2023 09:57:49 +0100 Subject: [PATCH 12/45] Add some docstrings --- posttroll/bbmcast.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index ebfcd53..cba9e34 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -80,12 +80,14 @@ class MulticastSender: """Multicast sender on *port* and *mcgroup*.""" def __init__(self, port, mcgroup=None): + """Set up the multicast sender.""" self.port = port self.group = mcgroup self.socket, self.group = mcast_sender(mcgroup) logger.debug("Started multicast group %s", self.group) def __call__(self, data): + """Send data.""" self.socket.sendto(data.encode(), (self.group, self.port)) def close(self): @@ -141,6 +143,7 @@ class MulticastReceiver(object): BUFSIZE = 1024 def __init__(self, port, mcgroup=None): + """Set up the multicast receiver.""" # Note: a multicast receiver will also receive broadcast on same port. self.port = port self.socket, self.group = mcast_receiver(port, mcgroup) @@ -151,6 +154,7 @@ def settimeout(self, tout=None): return self def __call__(self): + """Receive data.""" data, sender = self.socket.recvfrom(self.BUFSIZE) return data.decode(), sender From 38ed263b19dbf99f9259a98697eefa58b5ddf106 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 7 Dec 2023 10:58:43 +0100 Subject: [PATCH 13/45] Improve documentation and style --- doc/source/index.rst | 11 +++++++++++ posttroll/backends/zmq/address_receiver.py | 4 ++++ posttroll/tests/test_bbmcast.py | 2 +- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/doc/source/index.rst b/doc/source/index.rst index e66bba3..57aa0f0 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -115,6 +115,17 @@ to specify the nameserver(s) explicitly in the publishing code:: .. seealso:: :class:`posttroll.publisher.Publish` and :class:`posttroll.subscriber.Subscribe` +Configuration parameters +------------------------ + +Global configuration variables that are available through a Donfig configuration object: +- tcp_keepalive +- tcp_keepalive_cnt +- tcp_keepalive_idle +- tcp_keepalive_intvl +- multicast_interface +- mc_group + Setting TCP keep-alive ---------------------- diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index e6db4e6..0052b3e 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -1,3 +1,5 @@ +"""ZMQ implementation of the the simple receiver.""" + from zmq import LINGER, REP from posttroll.address_receiver import default_publish_port @@ -8,11 +10,13 @@ class SimpleReceiver(object): """Simple listing on port for address messages.""" def __init__(self, port=None): + """Set up the receiver.""" self._port = port or default_publish_port self._socket = get_context().socket(REP) self._socket.bind("tcp://*:" + str(port)) def __call__(self): + """Receive a message.""" message = self._socket.recv_string() self._socket.send_string("ok") return message, None diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index ee7146b..0baf65e 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -120,7 +120,7 @@ def test_mcast_send_recv(reraise): message = "Ho Ho Ho!" def check_message(sock, message): - data, sender = sock.recvfrom(1024) + data, _ = sock.recvfrom(1024) with reraise: assert data.decode() == message From 558166b68f4cd3e922c348921b2d69d4fd861ad3 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 7 Dec 2023 10:59:10 +0100 Subject: [PATCH 14/45] Refactor bbmcast a tad --- posttroll/bbmcast.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index cba9e34..9d0f856 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -32,6 +32,7 @@ import os import struct import warnings +from contextlib import suppress from socket import ( AF_INET, INADDR_ANY, @@ -60,8 +61,6 @@ # 224.0.0.0 through 224.0.0.255 is reserved administrative tasks DEFAULT_MC_GROUP = "225.0.0.212" -MULTICAST_INTERFACE = config.get("multicast_interface", "0.0.0.0") - # local network multicast (<32) TTL_LOCALNET = int(os.environ.get("PYTROLL_MC_TTL", 31)) @@ -114,8 +113,10 @@ def mcast_sender(mcgroup=None): group = mcgroup ttl = struct.pack("b", TTL_LOCALNET) # Time-to-live sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl) - if MULTICAST_INTERFACE != "0.0.0.0": - sock.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, inet_aton(MULTICAST_INTERFACE)) + + with suppress(KeyError): + multicast_interface = config.get("multicast_interface") + sock.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, inet_aton(multicast_interface)) except Exception: sock.close() raise @@ -195,14 +196,14 @@ def mcast_receiver(port, mcgroup=None): group = gethostbyname(group) # Construct struct mreq - if MULTICAST_INTERFACE == "0.0.0.0": + try: + multicast_interface = config.get("multicast_interface") + ifaddr = inet_aton(multicast_interface) + mreq = struct.pack("=4s4s", inet_aton(group), ifaddr) + except KeyError: ifaddr = INADDR_ANY mreq = struct.pack("=4sl", inet_aton(group), ifaddr) - else: - ifaddr = inet_aton(MULTICAST_INTERFACE) - mreq = struct.pack("=4s4s", inet_aton(group), ifaddr) - # Add group membership sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq) except Exception: From 927277aa84aa49cc42c28e7d07094402027acece Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 7 Dec 2023 10:59:38 +0100 Subject: [PATCH 15/45] Switch one class to pytest --- posttroll/tests/test_pubsub.py | 51 +++++++++++++++++----------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 1445de2..d12e942 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -128,10 +128,10 @@ def test_pub_sub_add_rm(self): sub.close() -class TestNSWithoutMulticasting(unittest.TestCase): +class TestNSWithoutMulticasting: """Test the nameserver.""" - def setUp(self): + def setup_method(self): """Set up the testing class.""" test_lock.acquire() self.nameservers = ["localhost"] @@ -140,39 +140,40 @@ def setUp(self): self.thr = Thread(target=self.ns.run) self.thr.start() - def tearDown(self): + def teardown_method(self): """Clean up after the tests have run.""" self.ns.stop() self.thr.join() time.sleep(2) test_lock.release() - def test_pub_addresses(self): + def test_pub_addresses(self, reraise): """Test retrieving addresses.""" from posttroll.ns import get_pub_addresses from posttroll.publisher import Publish - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): - time.sleep(3) - res = get_pub_addresses(["this_data"]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses(["data_provider"]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] + with reraise: + with Publish("data_provider", 0, ["this_data"], + nameservers=self.nameservers): + time.sleep(3) + res = get_pub_addresses(["this_data"]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + res = get_pub_addresses(["data_provider"]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] def test_pub_sub_ctx(self): """Test publish and subscribe.""" From 2d0eb0dbe2315ea14c4356cd9ede8c1f03686b33 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 7 Dec 2023 12:58:09 +0100 Subject: [PATCH 16/45] Speed up some tests --- posttroll/address_receiver.py | 3 +- posttroll/backends/zmq/ns.py | 2 +- posttroll/ns.py | 5 +- posttroll/publisher.py | 4 ++ posttroll/tests/test_pubsub.py | 85 +++++++++++++++++----------------- 5 files changed, 53 insertions(+), 46 deletions(-) diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index fb36eaa..5c4cb5e 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -76,6 +76,7 @@ class AddressReceiver(object): def __init__(self, max_age=ten_minutes, port=None, do_heartbeat=True, multicast_enabled=True, restrict_to_localhost=False): + """Set up the address receiver.""" self._max_age = max_age self._port = port or default_publish_port self._address_lock = threading.Lock() @@ -138,7 +139,7 @@ def _check_age(self, pub, min_interval=zero_seconds): "service": metadata["service"]} msg = Message("/address/" + metadata["name"], "info", mda) to_del.append(addr) - LOGGER.info("publish remove '%s'", str(msg)) + LOGGER.info(f"publish remove '{msg}'") pub.send(msg.encode()) for addr in to_del: del self._addresses[addr] diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 7255e05..bb4d0b0 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -61,7 +61,7 @@ def run(self, arec): with nslock: self.listener = get_context().socket(REP) self.listener.bind("tcp://*:" + str(port)) - logger.debug("Listening on port %s", str(port)) + logger.debug(f"Nameserver listening on port {port}") poller = Poller() poller.register(self.listener, POLLIN) while self.loop: diff --git a/posttroll/ns.py b/posttroll/ns.py index e00e313..cc41665 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -45,7 +45,9 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): - """Get the address of the publisher for a given list of publisher *names* + """Get the address of the publisher. + + For a given list of publisher *names* from the nameserver on *nameserver* (localhost by default). """ addrs = [] @@ -90,6 +92,7 @@ class NameServer: """The name server.""" def __init__(self, max_age=timedelta(minutes=10), multicast_enabled=True, restrict_to_localhost=False): + """Set up the nameserver.""" self._max_age = max_age self._multicast_enabled = multicast_enabled self._restrict_to_localhost = restrict_to_localhost diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 317bfdf..2110b13 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -231,6 +231,10 @@ def port_number(self): """Get the port number.""" return self._publisher.port_number + def heartbeat(self, min_interval=0): + """Send a heartbeat ... but only if *min_interval* seconds has passed since last beat.""" + self._publisher.heartbeat(min_interval) + def _create_tcp_publish_address(port, ip_address="*"): return "tcp://" + ip_address + ":" + str(port) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index d12e942..d04fe07 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -135,7 +135,8 @@ def setup_method(self): """Set up the testing class.""" test_lock.acquire() self.nameservers = ["localhost"] - self.ns = NameServer(max_age=timedelta(seconds=3), + self.max_age = .3 + self.ns = NameServer(max_age=timedelta(seconds=self.max_age), multicast_enabled=False) self.thr = Thread(target=self.ns.run) self.thr.start() @@ -147,33 +148,32 @@ def teardown_method(self): time.sleep(2) test_lock.release() - def test_pub_addresses(self, reraise): + def test_pub_addresses(self): """Test retrieving addresses.""" from posttroll.ns import get_pub_addresses from posttroll.publisher import Publish - with reraise: - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): - time.sleep(3) - res = get_pub_addresses(["this_data"]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses(["data_provider"]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] + with Publish("data_provider", 0, ["this_data"], + nameservers=self.nameservers, broadcast_interval=.1): + time.sleep(.2) + res = get_pub_addresses(["this_data"]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + res = get_pub_addresses(["data_provider"]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] def test_pub_sub_ctx(self): """Test publish and subscribe.""" @@ -182,12 +182,12 @@ def test_pub_sub_ctx(self): from posttroll.subscriber import Subscribe with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers) as pub: + nameservers=self.nameservers, broadcast_interval=.1) as pub: with Subscribe("this_data", "counter") as sub: for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) - time.sleep(1) + time.sleep(.1) msg = next(sub.recv(2)) if msg is not None: assert str(msg) == str(message) @@ -200,26 +200,27 @@ def test_pub_sub_add_rm(self): from posttroll.publisher import Publish from posttroll.subscriber import Subscribe - time.sleep(4) - with Subscribe("this_data", "counter", True) as sub: + with Subscribe("this_data", "counter", True, timeout=.1) as sub: assert len(sub.addresses) == 0 with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): + nameservers=self.nameservers, broadcast_interval=.1): time.sleep(4) - next(sub.recv(2)) + next(sub.recv(.2)) assert len(sub.addresses) == 1 + time.sleep(3) - for msg in sub.recv(2): + + for msg in sub.recv(.2): if msg is None: break time.sleep(3) - assert len(sub.sub_addr) == 0 + assert len(sub.addresses) == 0 with Publish("data_provider_2", 0, ["another_data"], - nameservers=self.nameservers): + nameservers=self.nameservers, broadcast_interval=.1): time.sleep(4) - next(sub.recv(2)) - assert len(sub.sub_addr) == 0 + next(sub.recv(.2)) + assert len(sub.addresses) == 0 class TestPubSub(unittest.TestCase): @@ -649,40 +650,38 @@ def _tcp_keepalive_no_settings(): def test_publisher_tcp_keepalive(): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher("tcp://127.0.0.1:9000").start() + pub = UnsecureZMQPublisher("tcp://127.0.0.1:9001").start() _assert_tcp_keepalive(pub.publish_socket) + pub.stop() @pytest.mark.usefixtures("_tcp_keepalive_no_settings") def test_publisher_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher("tcp://127.0.0.1:9000").start() + pub = UnsecureZMQPublisher("tcp://127.0.0.1:9002").start() _assert_no_tcp_keepalive(pub.publish_socket) + pub.stop() @pytest.mark.usefixtures("_tcp_keepalive_settings") def test_subscriber_tcp_keepalive(): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber("tcp://127.0.0.1:9000") - assert len(sub.addr_sub.values()) == 1 - _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) + sub.stop() @pytest.mark.usefixtures("_tcp_keepalive_no_settings") def test_subscriber_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber("tcp://127.0.0.1:9000") - assert len(sub.addr_sub.values()) == 1 - _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0]) + sub.close() def _assert_tcp_keepalive(socket): From 2afc51fe1256c2913a6b7ea19454939fb418a257 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 2 Feb 2024 16:06:19 +0100 Subject: [PATCH 17/45] Refactor for adding secure implementations of pub and sub --- posttroll/backends/zmq/publisher.py | 50 +++++++ posttroll/backends/zmq/subscriber.py | 190 +++++++++++++++++++++++++++ posttroll/publisher.py | 4 +- posttroll/subscriber.py | 6 +- posttroll/tests/test_pubsub.py | 24 +++- 5 files changed, 265 insertions(+), 9 deletions(-) diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index f4796a8..8c5489e 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -60,3 +60,53 @@ def stop(self): """Stop the publisher.""" self.publish_socket.setsockopt(zmq.LINGER, 1) self.publish_socket.close() + +class SecureZMQPublisher: + """Unsecure ZMQ implementation of the publisher class.""" + + def __init__(self, address, name="", min_port=None, max_port=None): + """Bind the publisher class to a port.""" + self.name = name + self.destination = address + self.publish_socket = None + self.min_port = min_port + self.max_port = max_port + self.port_number = None + self._pub_lock = Lock() + + def start(self): + """Start the publisher.""" + self.publish_socket = get_context().socket(zmq.PUB) + _set_tcp_keepalive(self.publish_socket) + + self._bind() + LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") + return self + + def _bind(self): + # Check for port 0 (random port) + u__ = urlsplit(self.destination) + port = u__.port + if port == 0: + dest = urlunsplit((u__.scheme, u__.hostname, + u__.path, u__.query, u__.fragment)) + self.port_number = self.publish_socket.bind_to_random_port( + dest, + min_port=self.min_port, + max_port=self.max_port) + netloc = u__.hostname + ":" + str(self.port_number) + self.destination = urlunsplit((u__.scheme, netloc, u__.path, + u__.query, u__.fragment)) + else: + self.publish_socket.bind(self.destination) + self.port_number = port + + def send(self, msg): + """Send the given message.""" + with self._pub_lock: + self.publish_socket.send_string(msg) + + def stop(self): + """Stop the publisher.""" + self.publish_socket.setsockopt(zmq.LINGER, 1) + self.publish_socket.close() diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index caf267b..b0f3b63 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -200,3 +200,193 @@ def __del__(self): sub.close() except Exception: # noqa: E722 pass + + +class SecureZMQSubscriber: + """Unsecure ZMQ implementation of the subscriber.""" + + def __init__(self, addresses, topics="", message_filter=None, translate=False): + """Initialize the subscriber.""" + self._topics = topics + self._filter = message_filter + self._translate = translate + + self.sub_addr = {} + self.addr_sub = {} + + self._hooks = [] + self._hooks_cb = {} + + self.poller = Poller() + self._lock = Lock() + + self.update(addresses) + + self._loop = None + + def add(self, address, topics=None): + """Add *address* to the subscribing list for *topics*. + + It topics is None we will subscribe to already specified topics. + """ + with self._lock: + if address in self.addresses: + return + + topics = topics or self._topics + LOGGER.info("Subscriber adding address %s with topics %s", + str(address), str(topics)) + subscriber = self._add_sub_socket(address, topics) + self.sub_addr[subscriber] = address + self.addr_sub[address] = subscriber + + def _add_sub_socket(self, address, topics): + subscriber = get_context().socket(SUB) + _set_tcp_keepalive(subscriber) + for t__ in topics: + subscriber.setsockopt_string(SUBSCRIBE, str(t__)) + subscriber.connect(address) + + if self.poller: + self.poller.register(subscriber, POLLIN) + return subscriber + + def remove(self, address): + """Remove *address* from the subscribing list for *topics*.""" + with self._lock: + try: + subscriber = self.addr_sub[address] + except KeyError: + return + LOGGER.info("Subscriber removing address %s", str(address)) + del self.addr_sub[address] + del self.sub_addr[subscriber] + self._remove_sub_socket(subscriber) + + def _remove_sub_socket(self, subscriber): + if self.poller: + self.poller.unregister(subscriber) + subscriber.close() + + def update(self, addresses): + """Update with a set of addresses.""" + if isinstance(addresses, str): + addresses = [addresses, ] + current_addresses, new_addresses = set(self.addresses), set(addresses) + addresses_to_remove = current_addresses.difference(new_addresses) + addresses_to_add = new_addresses.difference(current_addresses) + for addr in addresses_to_remove: + self.remove(addr) + for addr in addresses_to_add: + self.add(addr) + return bool(addresses_to_remove or addresses_to_add) + + def add_hook_sub(self, address, topics, callback): + """Specify a SUB *callback* in the same stream (thread) as the main receive loop. + + The callback will be called with the received messages from the + specified subscription. + + Good for operations, which is required to be done in the same thread as + the main recieve loop (e.q operations on the underlying sockets). + """ + topics = topics + LOGGER.info("Subscriber adding SUB hook %s for topics %s", + str(address), str(topics)) + socket = self._add_sub_socket(address, topics) + self._add_hook(socket, callback) + + def add_hook_pull(self, address, callback): + """Specify a PULL *callback* in the same stream (thread) as the main receive loop. + + The callback will be called with the received messages from the + specified subscription. Good for pushed 'inproc' messages from another thread. + """ + LOGGER.info("Subscriber adding PULL hook %s", str(address)) + socket = get_context().socket(PULL) + socket.connect(address) + if self.poller: + self.poller.register(socket, POLLIN) + self._add_hook(socket, callback) + + def _add_hook(self, socket, callback): + """Add a generic hook. The passed socket has to be "receive only".""" + self._hooks.append(socket) + self._hooks_cb[socket] = callback + + + @property + def addresses(self): + """Get the addresses.""" + return self.sub_addr.values() + + @property + def subscribers(self): + """Get the subscribers.""" + return self.sub_addr.keys() + + def recv(self, timeout=None): + """Receive, optionally with *timeout* in seconds.""" + if timeout: + timeout *= 1000. + + for sub in list(self.subscribers) + self._hooks: + self.poller.register(sub, POLLIN) + self._loop = True + try: + while self._loop: + sleep(0) + try: + socks = dict(self.poller.poll(timeout=timeout)) + if socks: + for sub in self.subscribers: + if sub in socks and socks[sub] == POLLIN: + received = sub.recv_string(NOBLOCK) + m__ = Message.decode(received) + if not self._filter or self._filter(m__): + if self._translate: + url = urlsplit(self.sub_addr[sub]) + host = url[1].split(":")[0] + m__.sender = (m__.sender.split("@")[0] + + "@" + host) + yield m__ + + for sub in self._hooks: + if sub in socks and socks[sub] == POLLIN: + m__ = Message.decode(sub.recv_string(NOBLOCK)) + self._hooks_cb[sub](m__) + else: + # timeout + yield None + except ZMQError as err: + if self._loop: + LOGGER.exception("Receive failed: %s", str(err)) + finally: + for sub in list(self.subscribers) + self._hooks: + self.poller.unregister(sub) + + def __call__(self, **kwargs): + """Handle calls with class instance.""" + return self.recv(**kwargs) + + def stop(self): + """Stop the subscriber.""" + self._loop = False + + def close(self): + """Close the subscriber: stop it and close the local subscribers.""" + self.stop() + for sub in list(self.subscribers) + self._hooks: + try: + sub.setsockopt(LINGER, 1) + sub.close() + except ZMQError: + pass + + def __del__(self): + """Clean up after the instance is deleted.""" + for sub in list(self.subscribers) + self._hooks: + try: + sub.close() + except Exception: # noqa: E722 + pass diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 2110b13..a16d8d2 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -99,8 +99,8 @@ def __init__(self, address, name="", min_port=None, max_port=None): from posttroll.backends.zmq.publisher import UnsecureZMQPublisher self._publisher = UnsecureZMQPublisher(address, name, min_port, max_port) elif backend == "secure_zmq": - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - self._publisher = UnsecureZMQPublisher(address, name, min_port, max_port) + from posttroll.backends.zmq.publisher import SecureZMQPublisher + self._publisher = SecureZMQPublisher(address, name, min_port, max_port) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index ca4d993..85415c0 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -30,7 +30,6 @@ from datetime import datetime, timedelta from posttroll import config -from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber from posttroll.message import _MAGICK from posttroll.ns import get_pub_address @@ -66,8 +65,13 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": + from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber self._subscriber = UnsecureZMQSubscriber(addresses, topics=topics, message_filter=message_filter, translate=translate) + elif backend == "secure_zmq": + from posttroll.backends.zmq.subscriber import SecureZMQSubscriber + self._subscriber = SecureZMQSubscriber(addresses, topics=topics, + message_filter=message_filter, translate=translate) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index d04fe07..99623dd 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -605,8 +605,7 @@ def test_dict_config_full_nssubscriber(NSSubscriber_start): NSSubscriber_start.assert_called_once() -@mock.patch("posttroll.subscriber.UnsecureZMQSubscriber.update") -def test_dict_config_full_subscriber(Subscriber_update): +def test_dict_config_full_subscriber(): """Test that all Subscriber options are passed.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -614,7 +613,7 @@ def test_dict_config_full_subscriber(Subscriber_update): "services": "val1", "topics": "val2", "addr_listener": "val3", - "addresses": "val4", + "addresses": "ipc://bla.ipc", "timeout": "val5", "translate": "val6", "nameserver": False, @@ -622,6 +621,7 @@ def test_dict_config_full_subscriber(Subscriber_update): } _ = create_subscriber_from_dict_config(settings) + @pytest.fixture() def _tcp_keepalive_settings(monkeypatch): """Set TCP Keepalive settings.""" @@ -731,20 +731,22 @@ def test_ipc_pubsub_with_sec(): subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) from posttroll.publisher import Publisher - pub = Publisher("ipc://bla.ipc", secure=True) + pub = Publisher("ipc://bla.ipc") pub.start() def delayed_send(msg): time.sleep(.2) from posttroll.message import Message msg = Message(subject="/hi", atype="string", data=msg) pub.send(str(msg)) - pub.stop() from threading import Thread - Thread(target=delayed_send, args=["hi"]).start() + thr = Thread(target=delayed_send, args=["hi"]) + thr.start() for msg in sub.recv(): assert msg.data == "hi" break sub.stop() + thr.join() + pub.stop() def test_switch_to_unknown_backend(): """Test switching to unknown backend.""" @@ -766,3 +768,13 @@ def test_switch_to_secure_zmq_backend(): with config.set(backend="secure_zmq"): Publisher("ipc://bla.ipc") Subscriber("ipc://bla.ipc") + +def test_switch_to_unsecure_zmq_backend(): + """Test switching to the secure_zmq backend.""" + from posttroll import config + from posttroll.publisher import Publisher + from posttroll.subscriber import Subscriber + + with config.set(backend="unsecure_zmq"): + Publisher("ipc://bla.ipc") + Subscriber("ipc://bla.ipc") From f3b0b3a59510c577429e4ab08043b5f0cf1a7f7d Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Mon, 19 Feb 2024 10:42:18 +0100 Subject: [PATCH 18/45] Add some tests --- posttroll/__init__.py | 19 ++++---- posttroll/backends/zmq/publisher.py | 29 ++++++++++-- posttroll/backends/zmq/subscriber.py | 28 ++++++++++- posttroll/publisher.py | 4 +- posttroll/subscriber.py | 23 ++++++---- posttroll/testing.py | 2 +- posttroll/tests/test_pubsub.py | 69 ++++++++++++++++++++++++---- 7 files changed, 139 insertions(+), 35 deletions(-) diff --git a/posttroll/__init__.py b/posttroll/__init__.py index 3ed72f9..ec2c900 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -37,16 +37,17 @@ logger = logging.getLogger(__name__) -# def get_context(): -# """Provide the context to use. +def get_context(): + """Provide the context to use. -# This function takes care of creating new contexts in case of forks. -# """ -# pid = os.getpid() -# if pid not in context: -# context[pid] = zmq.Context() -# logger.debug("renewed context for PID %d", pid) -# return context[pid] + This function takes care of creating new contexts in case of forks. + """ + backend = config.get("backend", "unsecure_zmq") + if "zmq" in backend: + from posttroll.backends.zmq import get_context + return get_context() + else: + raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") def strp_isoformat(strg): diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index 8c5489e..a36e661 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -62,9 +62,9 @@ def stop(self): self.publish_socket.close() class SecureZMQPublisher: - """Unsecure ZMQ implementation of the publisher class.""" + """Secure ZMQ implementation of the publisher class.""" - def __init__(self, address, name="", min_port=None, max_port=None): + def __init__(self, address, name="", min_port=None, max_port=None, server_secret_key=None, public_keys_directory=None, authorized_sub_addresses=None): """Bind the publisher class to a port.""" self.name = name self.destination = address @@ -74,9 +74,31 @@ def __init__(self, address, name="", min_port=None, max_port=None): self.port_number = None self._pub_lock = Lock() + self._server_secret_key = server_secret_key + self._authorized_sub_addresses = authorized_sub_addresses or [] + self._pub_keys_dir = public_keys_directory + self._authenticator = None + def start(self): """Start the publisher.""" - self.publish_socket = get_context().socket(zmq.PUB) + ctx = get_context() + + # Start an authenticator for this context. + from zmq.auth.thread import ThreadAuthenticator + auth = ThreadAuthenticator(ctx) + auth.start() + auth.allow(*self._authorized_sub_addresses) + # Tell authenticator to use the certificate in a directory + auth.configure_curve(domain='*', location=self._pub_keys_dir) + self._authenticator = auth + + self.publish_socket = ctx.socket(zmq.PUB) + + server_public, server_secret =zmq.auth.load_certificate(self._server_secret_key) + self.publish_socket.curve_secretkey = server_secret + self.publish_socket.curve_publickey = server_public + self.publish_socket.curve_server = True + _set_tcp_keepalive(self.publish_socket) self._bind() @@ -110,3 +132,4 @@ def stop(self): """Stop the publisher.""" self.publish_socket.setsockopt(zmq.LINGER, 1) self.publish_socket.close() + self._authenticator.stop() diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index b0f3b63..96e6552 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -34,6 +34,11 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._loop = None + @property + def running(self): + """Check if suscriber is running.""" + return self._loop + def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. @@ -203,14 +208,17 @@ def __del__(self): class SecureZMQSubscriber: - """Unsecure ZMQ implementation of the subscriber.""" + """Secure ZMQ implementation of the subscriber, using Curve.""" - def __init__(self, addresses, topics="", message_filter=None, translate=False): + def __init__(self, addresses, topics="", message_filter=None, translate=False, client_secret_key_file=None, server_public_key_file=None): """Initialize the subscriber.""" self._topics = topics self._filter = message_filter self._translate = translate + self._client_secret_file = client_secret_key_file + self._server_public_key_file = server_public_key_file + self.sub_addr = {} self.addr_sub = {} @@ -224,6 +232,11 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._loop = None + @property + def running(self): + """Check if suscriber is running.""" + return self._loop + def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. @@ -241,7 +254,18 @@ def add(self, address, topics=None): self.addr_sub[address] = subscriber def _add_sub_socket(self, address, topics): + import zmq.auth subscriber = get_context().socket(SUB) + + client_public, client_secret = zmq.auth.load_certificate(self._client_secret_file) + subscriber.curve_secretkey = client_secret + subscriber.curve_publickey = client_public + + server_public, _ = zmq.auth.load_certificate(self._server_public_key_file) + # The client must know the server's public key to make a CURVE connection. + subscriber.curve_serverkey = server_public + + _set_tcp_keepalive(subscriber) for t__ in topics: subscriber.setsockopt_string(SUBSCRIBE, str(t__)) diff --git a/posttroll/publisher.py b/posttroll/publisher.py index a16d8d2..776a85d 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -85,7 +85,7 @@ class Publisher: """ - def __init__(self, address, name="", min_port=None, max_port=None): + def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): """Bind the publisher class to a port.""" # Limit port range or use the defaults when no port is defined # by the user @@ -100,7 +100,7 @@ def __init__(self, address, name="", min_port=None, max_port=None): self._publisher = UnsecureZMQPublisher(address, name, min_port, max_port) elif backend == "secure_zmq": from posttroll.backends.zmq.publisher import SecureZMQPublisher - self._publisher = SecureZMQPublisher(address, name, min_port, max_port) + self._publisher = SecureZMQPublisher(address, name, min_port, max_port, **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 85415c0..1647397 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -60,7 +60,7 @@ class Subscriber: """ - def __init__(self, addresses, topics="", message_filter=None, translate=False): + def __init__(self, addresses, topics="", message_filter=None, translate=False, **kwargs): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") @@ -71,7 +71,7 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): elif backend == "secure_zmq": from posttroll.backends.zmq.subscriber import SecureZMQSubscriber self._subscriber = SecureZMQSubscriber(addresses, topics=topics, - message_filter=message_filter, translate=translate) + message_filter=message_filter, translate=translate, **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") @@ -135,6 +135,11 @@ def close(self): """Close the subscriber: stop it and close the local subscribers.""" return self._subscriber.close() + @property + def running(self): + """Check if suscriber is running.""" + return self._subscriber.running + @staticmethod def _magickfy_topics(topics): """Add the magick to the topics if missing.""" @@ -343,12 +348,14 @@ def create_subscriber_from_dict_config(settings): def _get_subscriber_instance(settings): - addresses = settings["addresses"] - topics = settings.get("topics", "") - message_filter = settings.get("message_filter", None) - translate = settings.get("translate", False) - - return Subscriber(addresses, topics=topics, message_filter=message_filter, translate=translate) + addresses = settings.pop("addresses") + topics = settings.pop("topics", "") + message_filter = settings.pop("message_filter", None) + translate = settings.pop("translate", False) + _ = settings.pop("nameserver", None) + _ = settings.pop("port", None) + + return Subscriber(addresses, topics=topics, message_filter=message_filter, translate=translate, **settings) def _get_nssubscriber_instance(settings): diff --git a/posttroll/testing.py b/posttroll/testing.py index 54a7e51..f632c8f 100644 --- a/posttroll/testing.py +++ b/posttroll/testing.py @@ -10,7 +10,7 @@ def patched_subscriber_recv(messages): def interuptible_recv(self): """Yield message until the subscriber is closed.""" for msg in messages: - if self._loop is False: + if self.running is False: break yield msg diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 99623dd..ae28c1a 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -724,14 +724,61 @@ def delayed_send(msg): break sub.stop() -def test_ipc_pubsub_with_sec(): + +def create_keys(tmp_path): + """Test pub-sub on a secure ipc socket.""" + base_dir = tmp_path + keys_dir = base_dir / "certificates" + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + keys_dir.mkdir() + public_keys_dir.mkdir() + secret_keys_dir.mkdir() + + import zmq.auth + import os + import shutil + + # create new keys in certificates dir + server_public_file, server_secret_file = zmq.auth.create_certificates( + keys_dir, "server" + ) + client_public_file, client_secret_file = zmq.auth.create_certificates( + keys_dir, "client" + ) + + # move public keys to appropriate directory + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key"): + shutil.move( + os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.') + ) + + # move secret keys to appropriate directory + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key_secret"): + shutil.move( + os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.') + ) + + +def test_ipc_pubsub_with_sec(tmp_path): """Test pub-sub on a secure ipc socket.""" + base_dir = tmp_path + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + create_keys(tmp_path) + from posttroll import config with config.set(backend="secure_zmq"): - subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) + subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202, + client_secret_key_file=secret_keys_dir / "client.key_secret", + server_public_key_file=public_keys_dir / "server.key") sub = create_subscriber_from_dict_config(subscriber_settings) from posttroll.publisher import Publisher - pub = Publisher("ipc://bla.ipc") + pub = Publisher("ipc://bla.ipc", server_secret_key=secret_keys_dir / "server.key_secret", public_keys_directory=public_keys_dir) pub.start() def delayed_send(msg): time.sleep(.2) @@ -739,14 +786,16 @@ def delayed_send(msg): msg = Message(subject="/hi", atype="string", data=msg) pub.send(str(msg)) from threading import Thread - thr = Thread(target=delayed_send, args=["hi"]) + thr = Thread(target=delayed_send, args=["very sensitive message"]) thr.start() - for msg in sub.recv(): - assert msg.data == "hi" - break - sub.stop() - thr.join() - pub.stop() + try: + for msg in sub.recv(): + assert msg.data == "very sensitive message" + break + finally: + sub.stop() + thr.join() + pub.stop() def test_switch_to_unknown_backend(): """Test switching to unknown backend.""" From 7a1f08d2c628512c9f8cdc9ece2434c5c0884715 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 23 Apr 2024 14:54:03 +0200 Subject: [PATCH 19/45] Switch to pyproject.toml only --- posttroll/__init__.py | 6 - posttroll/version.py | 657 ------------------------------------------ pyproject.toml | 73 +++++ setup.cfg | 14 - setup.py | 59 ---- 5 files changed, 73 insertions(+), 736 deletions(-) delete mode 100644 posttroll/version.py create mode 100644 pyproject.toml delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/posttroll/__init__.py b/posttroll/__init__.py index fbc3dd8..aece644 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -30,8 +30,6 @@ from donfig import Config -from .version import get_versions - config = Config("posttroll") # context = {} logger = logging.getLogger(__name__) @@ -72,7 +70,3 @@ def strp_isoformat(strg): dat = dt.datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S") mis = int(float("." + mis)*1000000) return dat.replace(microsecond=mis) - - -__version__ = get_versions()["version"] -del get_versions diff --git a/posttroll/version.py b/posttroll/version.py deleted file mode 100644 index aabec7c..0000000 --- a/posttroll/version.py +++ /dev/null @@ -1,657 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "None" - cfg.versionfile_source = "posttroll/version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0c7fb78 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,73 @@ +[project] +name = "posttroll" +dynamic = ["version"] +description = "Messaging system for pytroll" +authors = [ + { name = "The Pytroll Team", email = "pytroll@googlegroups.com" } +] +dependencies = ["pyzmq", "netifaces", "donfig"] +readme = "README.md" +requires-python = ">=3.10" +license = { text = "GPLv3" } +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Communications" +] + +[project.scripts] +pytroll-logger = "posttroll.logger:run" + +[project.urls] +Homepage = "https://github.com/pytroll/posttroll" +"Bug Tracker" = "https://github.com/pytroll/posttroll/issues" +Documentation = "https://posttroll.readthedocs.io/" +"Source Code" = "https://github.com/pytroll/posttroll" +Organization = "https://pytroll.github.io/" +Slack = "https://pytroll.slack.com/" +"Release Notes" = "https://github.com/pytroll/posttroll/blob/main/CHANGELOG.md" + +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["posttroll"] + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.build.hooks.vcs] +version-file = "posttroll/version.py" + +[tool.isort] +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +profile = "black" +skip_gitignore = true +default_section = "THIRDPARTY" +known_first_party = "posttroll" +line_length = 120 + + +[tool.ruff] +# See https://docs.astral.sh/ruff/rules/ +# In the future, add "B", "S", "N" +select = ["A", "D", "E", "W", "F", "I", "PT", "TID", "C90", "Q", "T10", "T20"] +line-length = 120 +exclude = ["versioneer.py", + "posttroll/version.py", + "doc"] + +[tool.ruff.pydocstyle] +convention = "google" + +[tool.ruff.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 10 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 96fbf05..0000000 --- a/setup.cfg +++ /dev/null @@ -1,14 +0,0 @@ -[bdist_rpm] -requires=python-daemon pyzmq -release=1 - -[versioneer] -VCS = git -style = pep440 -versionfile_source = posttroll/version.py -versionfile_build = -tag_prefix = v -#parentdir_prefix = myproject- - -[flake8] -max-line-length = 120 diff --git a/setup.py b/setup.py deleted file mode 100644 index ad154e8..0000000 --- a/setup.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright (c) 2011, 2012, 2014, 2015, 2020. - -# Author(s): - -# The pytroll team: -# Martin Raspaud - -# This file is part of pytroll. - -# This is free software: you can redistribute it and/or modify it under the -# terms of the GNU General Public License as published by the Free Software -# Foundation, either version 3 of the License, or (at your option) any later -# version. - -# This program is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS -# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more -# details. - -# You should have received a copy of the GNU General Public License along with -# this program. If not, see . - -"""Set up the package.""" - -from setuptools import setup - -import versioneer - -requirements = ["pyzmq", "netifaces", "donfig"] - - -setup(name="posttroll", - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - description="Messaging system for pytroll", - author="The pytroll team", - author_email="pytroll@googlegroups.com", - url="http://github.com/pytroll/posttroll", - packages=["posttroll"], - entry_points={ - "console_scripts": ["pytroll-logger = posttroll.logger:run", ]}, - scripts=["bin/nameserver"], - zip_safe=False, - license="GPLv3", - install_requires=requirements, - classifiers=[ - "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", - "Programming Language :: Python", - "Operating System :: OS Independent", - "Intended Audience :: Science/Research", - "Topic :: Scientific/Engineering", - "Topic :: Communications" - ], - python_requires=">=3.10", - test_suite="posttroll.tests.suite", - ) From ce0f72efdfcbd3a56b2982db989212411644296c Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 23 Apr 2024 14:54:32 +0200 Subject: [PATCH 20/45] Update documentation --- doc/Makefile | 136 ++------------------- doc/source/conf.py | 256 +++------------------------------------- doc/source/index.rst | 9 -- posttroll/publisher.py | 2 +- posttroll/subscriber.py | 2 +- 5 files changed, 34 insertions(+), 371 deletions(-) diff --git a/doc/Makefile b/doc/Makefile index bd4a009..d0c3cbf 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -1,130 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source BUILDDIR = build -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest - +# Put it first so that "make" without argument is like "make help". help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - -clean: - -rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PostTroll.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PostTroll.qhc" - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/PostTroll" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PostTroll" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - make -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." +.PHONY: help Makefile -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/doc/source/conf.py b/doc/source/conf.py index fd519b5..d45d3d6 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -1,248 +1,30 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# PostTroll documentation build configuration file, created by -# sphinx-quickstart on Tue Sep 11 12:58:14 2012. -# -# This file is execfile()d with the current directory set to its containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import os -import sys - -from posttroll import __version__ - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath("../../")) -sys.path.insert(0, os.path.abspath("../../posttroll")) - - - -class Mock(object): - """A mocking class.""" - def __init__(self, *args, **kwargs): - pass - - def __call__(self, *args, **kwargs): - return Mock() - - @classmethod - def __getattr__(cls, name): - if name in ("__file__", "__path__"): - return "/dev/null" - elif name[0] == name[0].upper(): - mock_type = type(name, (), {}) - mock_type.__module__ = __name__ - return mock_type - else: - return Mock() - -MOCK_MODULES = ["zmq"] -for mod_name in MOCK_MODULES: - sys.modules[mod_name] = Mock() - -# -- General configuration ----------------------------------------------------- - -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be extensions -# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ["sphinx.ext.autodoc", "sphinx.ext.doctest"] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["sphinx_templates"] +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -# The suffix of source filenames. -source_suffix = ".rst" +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information +from posttroll.version import version -# The encoding of source files. -#source_encoding = 'utf-8-sig' +project = "Posttroll" +copyright = "2012, Pytroll Crew" +author = "Pytroll Crew" +release = version -# The master toctree document. -master_doc = "index" +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -# General information about the project. -project = u"PostTroll" -copyright = u"2012-2014, Pytroll crew" +extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"] +autodoc_mock_imports = ["pyzmq"] -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# - - -# The full version, including alpha/beta/rc tags. -release = __version__ -# The short X.Y version. -version = ".".join(release.split(".")[:2]) - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -#language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. +templates_path = ["_templates"] exclude_patterns = [] -# The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# -- Options for HTML output --------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -html_theme = "default" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -#html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -#html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["sphinx_static"] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Output file base name for HTML help builder. -htmlhelp_basename = "PostTrolldoc" - - -# -- Options for LaTeX output -------------------------------------------------- - -# The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' - -# The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, author, documentclass [howto/manual]). -latex_documents = [ - ("index", "PostTroll.tex", u"PostTroll Documentation", - u"Pytroll crew", "manual"), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Additional stuff for the LaTeX preamble. -#latex_preamble = '' - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True -# -- Options for manual page output -------------------------------------------- +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ("index", "posttroll", u"PostTroll Documentation", - [u"Pytroll crew"], 1) -] +html_theme = "alabaster" +html_static_path = ["_static"] diff --git a/doc/source/index.rst b/doc/source/index.rst index 57aa0f0..b556936 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -240,15 +240,6 @@ Multicast code :members: :undoc-members: - -Connections -~~~~~~~~~~~ - -.. automodule:: posttroll.connections - :members: - :undoc-members: - - Misc ~~~~ diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 3a3455c..b067e0e 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -245,7 +245,7 @@ class Publish: See :class:`Publisher` and :class:`NoisyPublisher` for more information on the arguments. - The publisher is selected based on the arguments, see :function:`create_publisher_from_dict_config` for + The publisher is selected based on the arguments, see :func:`create_publisher_from_dict_config` for information how the selection is done. Example on how to use the :class:`Publish` context:: diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 7d49115..1cd187b 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -247,7 +247,7 @@ class Subscribe: See :class:`NSSubscriber` and :class:`Subscriber` for initialization parameters. - The subscriber is selected based on the arguments, see :function:`create_subscriber_from_dict_config` for + The subscriber is selected based on the arguments, see :func:`create_subscriber_from_dict_config` for information how the selection is done. Example:: From 8351a86679a066db419adee1c20147be451d63fb Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 23 Apr 2024 14:54:50 +0200 Subject: [PATCH 21/45] Refactor bbmcast tests --- posttroll/tests/test_bbmcast.py | 42 +++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index 0baf65e..687ee8d 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -25,6 +25,7 @@ import random from socket import SO_BROADCAST, SOL_SOCKET, error from threading import Thread +import os import pytest @@ -47,18 +48,24 @@ def test_mcast_sender_works_with_valid_addresses(): socket.close() +def test_mcast_sender_uses_broadcast_for_0s(): + """Test mcast_sender uses broadcast for 0.0.0.0.""" mcgroup = "0.0.0.0" socket, group = bbmcast.mcast_sender(mcgroup) assert group == "" assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() +def test_mcast_sender_uses_broadcast_for_255s(): + """Test mcast_sender uses broadcast for 255.255.255.255.""" mcgroup = "255.255.255.255" socket, group = bbmcast.mcast_sender(mcgroup) assert group == "" assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() +def test_mcast_sender_raises_for_invalit_adresses(): + """Test mcast_sender uses broadcast for 0.0.0.0.""" mcgroup = (str(random.randint(0, 223)) + "." + str(random.randint(0, 255)) + "." + str(random.randint(0, 255)) + "." + @@ -111,7 +118,11 @@ def test_mcast_receiver_works_with_valid_addresses(): bbmcast.mcast_receiver(mcport, mcgroup) -def test_mcast_send_recv(reraise): +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST") != None, + reason="Multicast tests disabled.", +) +def test_multicast_roundtrip(reraise): """Test sending and receiving a multicast message.""" mcgroup = bbmcast.DEFAULT_MC_GROUP mcport = 5555 @@ -120,8 +131,8 @@ def test_mcast_send_recv(reraise): message = "Ho Ho Ho!" def check_message(sock, message): - data, _ = sock.recvfrom(1024) with reraise: + data, _ = sock.recvfrom(1024) assert data.decode() == message snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) @@ -135,6 +146,32 @@ def check_message(sock, message): rec_socket.close() snd_socket.close() + +def test_broadcast_roundtrip(reraise): + """Test sending and receiving a broadcast message.""" + mcgroup = "0.0.0.0" + mcport = 5555 + rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + + message = "Ho Ho Ho!" + + def check_message(sock, message): + with reraise: + data, _ = sock.recvfrom(1024) + assert data.decode() == message + + snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) + + thr = Thread(target=check_message, args=(rec_socket, message)) + thr.start() + + snd_socket.sendto(message.encode(), (mcgroup, mcport)) + + thr.join() + rec_socket.close() + snd_socket.close() + + def test_posttroll_mc_group_is_used(): """Test that configured mc_group is used.""" from posttroll import config @@ -144,6 +181,7 @@ def test_posttroll_mc_group_is_used(): socket.close() assert group == "226.0.0.13" + def test_pytroll_mc_group_is_deprecated(monkeypatch): """Test that PYTROLL_MC_GROUP is used but pending deprecation.""" other_group = "226.0.0.13" From 3536cd091c89eeea840b45396aee2ca18499e8e2 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 23 Apr 2024 15:09:53 +0200 Subject: [PATCH 22/45] Fix pyproject.toml --- pyproject.toml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0c7fb78..784768f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ version-file = "posttroll/version.py" [tool.isort] sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] -profile = "black" skip_gitignore = true default_section = "THIRDPARTY" known_first_party = "posttroll" @@ -59,15 +58,15 @@ line_length = 120 [tool.ruff] # See https://docs.astral.sh/ruff/rules/ # In the future, add "B", "S", "N" -select = ["A", "D", "E", "W", "F", "I", "PT", "TID", "C90", "Q", "T10", "T20"] +lint.select = ["A", "D", "E", "W", "F", "I", "PT", "TID", "C90", "Q", "T10", "T20"] line-length = 120 exclude = ["versioneer.py", "posttroll/version.py", "doc"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 From ba3189c3bfcbcf06e7b790fda0cce11df208aa6e Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 23 Apr 2024 15:10:46 +0200 Subject: [PATCH 23/45] Fix quotes --- posttroll/backends/zmq/publisher.py | 2 +- posttroll/message.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index a36e661..ede9211 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -89,7 +89,7 @@ def start(self): auth.start() auth.allow(*self._authorized_sub_addresses) # Tell authenticator to use the certificate in a directory - auth.configure_curve(domain='*', location=self._pub_keys_dir) + auth.configure_curve(domain="*", location=self._pub_keys_dir) self._authenticator = auth self.publish_socket = ctx.socket(zmq.PUB) diff --git a/posttroll/message.py b/posttroll/message.py index 9ea845b..315e030 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -117,7 +117,7 @@ class Message(object): - It will make a Message pickleable. """ - def __init__(self, subject='', atype='', data='', binary=False, rawstr=None): + def __init__(self, subject="", atype="", data="", binary=False, rawstr=None): """Initialize a Message from a subject, type and data, or from a raw string.""" if rawstr: self.__dict__ = _decode(rawstr) From 87d73bd9e5e45bf31b454c2fc9d331c4889c26f3 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 23 Apr 2024 15:11:44 +0200 Subject: [PATCH 24/45] Timeout in multicast test when not working --- posttroll/tests/test_bbmcast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index 687ee8d..c354c4a 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -22,10 +22,10 @@ """Test multicasting and broadcasting.""" +import os import random from socket import SO_BROADCAST, SOL_SOCKET, error from threading import Thread -import os import pytest @@ -127,6 +127,7 @@ def test_multicast_roundtrip(reraise): mcgroup = bbmcast.DEFAULT_MC_GROUP mcport = 5555 rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + rec_socket.settimeout(.1) message = "Ho Ho Ho!" From 59a620ef157349f06f53ee058a88a2353b4f49b1 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 24 Apr 2024 08:32:39 +0200 Subject: [PATCH 25/45] Make public and secret keys positional arguments --- posttroll/backends/zmq/publisher.py | 18 +- posttroll/backends/zmq/subscriber.py | 2 +- posttroll/publisher.py | 21 +- posttroll/subscriber.py | 4 +- posttroll/tests/test_pubsub.py | 278 ++++++++------------- posttroll/tests/test_secure_zmq_backend.py | 144 +++++++++++ 6 files changed, 272 insertions(+), 195 deletions(-) create mode 100644 posttroll/tests/test_secure_zmq_backend.py diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index ede9211..64e6e15 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -64,8 +64,22 @@ def stop(self): class SecureZMQPublisher: """Secure ZMQ implementation of the publisher class.""" - def __init__(self, address, name="", min_port=None, max_port=None, server_secret_key=None, public_keys_directory=None, authorized_sub_addresses=None): - """Bind the publisher class to a port.""" + def __init__(self, address, server_secret_key, public_keys_directory, name="", min_port=None, max_port=None, + authorized_sub_addresses=None): + """Set up the secure ZMQ publisher. + + Args: + address: the address to connect to. + server_secret_key: the secret key for this publisher. + public_keys_directory: the directory containing the public keys of the subscribers that are allowed to + connect. + name: the name of this publishing service. + min_port: the minimal port number to use. + max_port: the maximal port number to use. + authorized_sub_addresses: the list of addresse allowed to subscibe to this publisher. By default, all are + allowed. + + """ self.name = name self.destination = address self.publish_socket = None diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 96e6552..3351995 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -210,7 +210,7 @@ def __del__(self): class SecureZMQSubscriber: """Secure ZMQ implementation of the subscriber, using Curve.""" - def __init__(self, addresses, topics="", message_filter=None, translate=False, client_secret_key_file=None, server_public_key_file=None): + def __init__(self, addresses, client_secret_key_file, server_public_key_file, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" self._topics = topics self._filter = message_filter diff --git a/posttroll/publisher.py b/posttroll/publisher.py index b067e0e..d1e4b85 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -85,7 +85,7 @@ class Publisher: """ - def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): + def __init__(self, address, *args, name="", min_port=None, max_port=None, **kwargs): """Bind the publisher class to a port.""" # Limit port range or use the defaults when no port is defined # by the user @@ -97,10 +97,10 @@ def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - self._publisher = UnsecureZMQPublisher(address, name, min_port, max_port) + self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port, **kwargs) elif backend == "secure_zmq": from posttroll.backends.zmq.publisher import SecureZMQPublisher - self._publisher = SecureZMQPublisher(address, name, min_port, max_port, **kwargs) + self._publisher = SecureZMQPublisher(address, *args, name=name, min_port=min_port, max_port=max_port, **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") @@ -305,18 +305,19 @@ def create_publisher_from_dict_config(settings): described in the docstrings of the respective classes, namely :class:`~posttroll.publisher.Publisher` and :class:`~posttroll.publisher.NoisyPublisher`. """ - if settings.get("port") and settings.get("nameservers") is False: + if (settings.get("port") or settings.get("address")) and settings.get("nameservers") is False: return _get_publisher_instance(settings) return _get_noisypublisher_instance(settings) def _get_publisher_instance(settings): - publisher_address = _create_tcp_publish_address(settings["port"]) - publisher_name = settings.get("name", "") - min_port = settings.get("min_port") - max_port = settings.get("max_port") - - return Publisher(publisher_address, name=publisher_name, min_port=min_port, max_port=max_port) + settings = settings.copy() + publisher_address = settings.pop("address", None) + port = settings.pop("port", None) + if not publisher_address: + publisher_address = _create_tcp_publish_address(port) + settings.pop("nameservers", None) + return Publisher(publisher_address, **settings) def _get_noisypublisher_instance(settings): diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 1cd187b..6654dda 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -59,7 +59,7 @@ class Subscriber: """ - def __init__(self, addresses, topics="", message_filter=None, translate=False, **kwargs): + def __init__(self, addresses, *args, topics="", message_filter=None, translate=False, **kwargs): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") @@ -69,7 +69,7 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False, * message_filter=message_filter, translate=translate) elif backend == "secure_zmq": from posttroll.backends.zmq.subscriber import SecureZMQSubscriber - self._subscriber = SecureZMQSubscriber(addresses, topics=topics, + self._subscriber = SecureZMQSubscriber(addresses, *args, topics=topics, message_filter=message_filter, translate=translate, **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 8502dbc..7f3e58a 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -34,8 +34,9 @@ from donfig import Config import posttroll +from posttroll import config from posttroll.ns import NameServer -from posttroll.publisher import create_publisher_from_dict_config +from posttroll.publisher import Publisher, create_publisher_from_dict_config from posttroll.subscriber import Subscribe, Subscriber, create_subscriber_from_dict_config test_lock = Lock() @@ -243,7 +244,7 @@ def test_pub_address_timeout(self): def test_pub_suber(self): """Test publisher and subscriber.""" from posttroll.message import Message - from posttroll.publisher import Publisher, get_own_ip + from posttroll.publisher import get_own_ip from posttroll.subscriber import Subscriber pub_address = "tcp://" + str(get_own_ip()) + ":0" pub = Publisher(pub_address).start() @@ -410,7 +411,6 @@ def test_listener_container(self): """Test listener container.""" from posttroll.listener import ListenerContainer from posttroll.message import Message - from posttroll.publisher import Publisher pub_addr = "tcp://127.0.0.1:55000" pub = Publisher(pub_addr, name="test") @@ -451,91 +451,97 @@ def test_localhost_restriction(self, mcrec, pub, msg): adr.stop() -class TestPublisherDictConfig(unittest.TestCase): - """Test configuring publishers with a dictionary.""" - def test_publisher_is_selected(self): - """Test that Publisher is selected as publisher class.""" - from posttroll.publisher import Publisher - settings = {"port": 12345, "nameservers": False} +## Test create_publisher_from_config - pub = create_publisher_from_dict_config(settings) - assert isinstance(pub, Publisher) - assert pub is not None - - @mock.patch("posttroll.publisher.Publisher") - def test_publisher_all_arguments(self, Publisher): - """Test that only valid arguments are passed to Publisher.""" - settings = {"port": 12345, "nameservers": False, "name": "foo", - "min_port": 40000, "max_port": 41000, "invalid_arg": "bar"} +def test_publisher_with_invalid_arguments_crashes(): + """Test that only valid arguments are passed to Publisher.""" + settings = {"address": "ipc:///tmp/test.ipc", "nameservers": False, "invalid_arg": "bar"} + with pytest.raises(TypeError): _ = create_publisher_from_dict_config(settings) - _check_valid_settings_in_call(settings, Publisher, ignore=["port", "nameservers"]) - assert Publisher.call_args[0][0].startswith("tcp://*:") - assert Publisher.call_args[0][0].endswith(str(settings["port"])) - def test_no_name_raises_keyerror(self): - """Trying to create a NoisyPublisher without a given name will raise KeyError.""" - with pytest.raises(KeyError): - _ = create_publisher_from_dict_config(dict()) - def test_noisypublisher_is_selected_only_name(self): - """Test that NoisyPublisher is selected as publisher class.""" - from posttroll.publisher import NoisyPublisher +def test_publisher_is_selected(): + """Test that Publisher is selected as publisher class.""" + settings = {"port": 12345, "nameservers": False} - settings = {"name": "publisher_name"} + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, Publisher) + assert pub is not None - pub = create_publisher_from_dict_config(settings) - assert isinstance(pub, NoisyPublisher) +@mock.patch("posttroll.publisher.Publisher") +def test_publisher_all_arguments(Publisher): + """Test that only valid arguments are passed to Publisher.""" + settings = {"port": 12345, "nameservers": False, "name": "foo", + "min_port": 40000, "max_port": 41000} + _ = create_publisher_from_dict_config(settings) + _check_valid_settings_in_call(settings, Publisher, ignore=["port", "nameservers"]) + assert Publisher.call_args[0][0].startswith("tcp://*:") + assert Publisher.call_args[0][0].endswith(str(settings["port"])) - def test_noisypublisher_is_selected_name_and_port(self): - """Test that NoisyPublisher is selected as publisher class.""" - from posttroll.publisher import NoisyPublisher +def test_no_name_raises_keyerror(): + """Trying to create a NoisyPublisher without a given name will raise KeyError.""" + with pytest.raises(KeyError): + _ = create_publisher_from_dict_config(dict()) - settings = {"name": "publisher_name", "port": 40000} +def test_noisypublisher_is_selected_only_name(): + """Test that NoisyPublisher is selected as publisher class.""" + from posttroll.publisher import NoisyPublisher - pub = create_publisher_from_dict_config(settings) - assert isinstance(pub, NoisyPublisher) + settings = {"name": "publisher_name"} - @mock.patch("posttroll.publisher.NoisyPublisher") - def test_noisypublisher_all_arguments(self, NoisyPublisher): - """Test that only valid arguments are passed to NoisyPublisher.""" - from posttroll.publisher import create_publisher_from_dict_config + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, NoisyPublisher) - settings = {"port": 12345, "nameservers": ["foo"], "name": "foo", - "min_port": 40000, "max_port": 41000, "invalid_arg": "bar", - "aliases": ["alias1", "alias2"], "broadcast_interval": 42} - _ = create_publisher_from_dict_config(settings) - _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"]) - assert NoisyPublisher.call_args[0][0] == settings["name"] +def test_noisypublisher_is_selected_name_and_port(): + """Test that NoisyPublisher is selected as publisher class.""" + from posttroll.publisher import NoisyPublisher - def test_publish_is_not_noisy(self): - """Test that Publisher is selected with the context manager when it should be.""" - from posttroll.publisher import Publish, Publisher + settings = {"name": "publisher_name", "port": 40000} - with Publish("service_name", port=40000, nameservers=False) as pub: - assert isinstance(pub, Publisher) + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, NoisyPublisher) - def test_publish_is_noisy_only_name(self): - """Test that NoisyPublisher is selected with the context manager when only name is given.""" - from posttroll.publisher import NoisyPublisher, Publish +@mock.patch("posttroll.publisher.NoisyPublisher") +def test_noisypublisher_all_arguments(NoisyPublisher): + """Test that only valid arguments are passed to NoisyPublisher.""" + from posttroll.publisher import create_publisher_from_dict_config - with Publish("service_name") as pub: - assert isinstance(pub, NoisyPublisher) + settings = {"port": 12345, "nameservers": ["foo"], "name": "foo", + "min_port": 40000, "max_port": 41000, "invalid_arg": "bar", + "aliases": ["alias1", "alias2"], "broadcast_interval": 42} + _ = create_publisher_from_dict_config(settings) + _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"]) + assert NoisyPublisher.call_args[0][0] == settings["name"] - def test_publish_is_noisy_with_port(self): - """Test that NoisyPublisher is selected with the context manager when port is given.""" - from posttroll.publisher import NoisyPublisher, Publish +def test_publish_is_not_noisy(): + """Test that Publisher is selected with the context manager when it should be.""" + from posttroll.publisher import Publish - with Publish("service_name", port=40001) as pub: - assert isinstance(pub, NoisyPublisher) + with Publish("service_name", port=40000, nameservers=False) as pub: + assert isinstance(pub, Publisher) - def test_publish_is_noisy_with_nameservers(self): - """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" - from posttroll.publisher import NoisyPublisher, Publish +def test_publish_is_noisy_only_name(): + """Test that NoisyPublisher is selected with the context manager when only name is given.""" + from posttroll.publisher import NoisyPublisher, Publish - with Publish("service_name", nameservers=["a", "b"]) as pub: - assert isinstance(pub, NoisyPublisher) + with Publish("service_name") as pub: + assert isinstance(pub, NoisyPublisher) + +def test_publish_is_noisy_with_port(): + """Test that NoisyPublisher is selected with the context manager when port is given.""" + from posttroll.publisher import NoisyPublisher, Publish + + with Publish("service_name", port=40001) as pub: + assert isinstance(pub, NoisyPublisher) + +def test_publish_is_noisy_with_nameservers(): + """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" + from posttroll.publisher import NoisyPublisher, Publish + + with Publish("service_name", nameservers=["a", "b"]) as pub: + assert isinstance(pub, NoisyPublisher) def _check_valid_settings_in_call(settings, pub_class, ignore=None): @@ -625,7 +631,6 @@ def test_dict_config_full_subscriber(): @pytest.fixture() def _tcp_keepalive_settings(monkeypatch): """Set TCP Keepalive settings.""" - from posttroll import config with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1): yield @@ -641,7 +646,6 @@ def reset_config_for_tests(): @pytest.fixture() def _tcp_keepalive_no_settings(): """Set TCP Keepalive settings.""" - from posttroll import config with config.set(tcp_keepalive=None, tcp_keepalive_cnt=None, tcp_keepalive_idle=None, tcp_keepalive_intvl=None): yield @@ -702,13 +706,36 @@ def _assert_no_tcp_keepalive(socket): assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1 +def test_noisypublisher_heartbeat(): + """Test that the heartbeat in the NoisyPublisher works.""" + from posttroll.ns import NameServer + from posttroll.publisher import NoisyPublisher + from posttroll.subscriber import Subscribe + + ns_ = NameServer() + thr = Thread(target=ns_.run) + thr.start() + + pub = NoisyPublisher("test") + pub.start() + time.sleep(0.2) + + with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: + time.sleep(0.2) + pub.heartbeat(min_interval=10) + msg = next(sub.recv(1)) + assert msg.type == "beat" + assert msg.data == {"min_interval": 10} + pub.stop() + ns_.stop() + thr.join() + + def test_ipc_pubsub(): """Test pub-sub on an ipc socket.""" - from posttroll import config with config.set(backend="unsecure_zmq"): subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) - from posttroll.publisher import Publisher pub = Publisher("ipc://bla.ipc") pub.start() def delayed_send(msg): @@ -725,81 +752,8 @@ def delayed_send(msg): sub.stop() -def create_keys(tmp_path): - """Test pub-sub on a secure ipc socket.""" - base_dir = tmp_path - keys_dir = base_dir / "certificates" - public_keys_dir = base_dir / "public_keys" - secret_keys_dir = base_dir / "private_keys" - - keys_dir.mkdir() - public_keys_dir.mkdir() - secret_keys_dir.mkdir() - - import zmq.auth - import os - import shutil - - # create new keys in certificates dir - server_public_file, server_secret_file = zmq.auth.create_certificates( - keys_dir, "server" - ) - client_public_file, client_secret_file = zmq.auth.create_certificates( - keys_dir, "client" - ) - - # move public keys to appropriate directory - for key_file in os.listdir(keys_dir): - if key_file.endswith(".key"): - shutil.move( - os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, '.') - ) - - # move secret keys to appropriate directory - for key_file in os.listdir(keys_dir): - if key_file.endswith(".key_secret"): - shutil.move( - os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, '.') - ) - - -def test_ipc_pubsub_with_sec(tmp_path): - """Test pub-sub on a secure ipc socket.""" - base_dir = tmp_path - public_keys_dir = base_dir / "public_keys" - secret_keys_dir = base_dir / "private_keys" - - create_keys(tmp_path) - - from posttroll import config - with config.set(backend="secure_zmq"): - subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202, - client_secret_key_file=secret_keys_dir / "client.key_secret", - server_public_key_file=public_keys_dir / "server.key") - sub = create_subscriber_from_dict_config(subscriber_settings) - from posttroll.publisher import Publisher - pub = Publisher("ipc://bla.ipc", server_secret_key=secret_keys_dir / "server.key_secret", public_keys_directory=public_keys_dir) - pub.start() - def delayed_send(msg): - time.sleep(.2) - from posttroll.message import Message - msg = Message(subject="/hi", atype="string", data=msg) - pub.send(str(msg)) - from threading import Thread - thr = Thread(target=delayed_send, args=["very sensitive message"]) - thr.start() - try: - for msg in sub.recv(): - assert msg.data == "very sensitive message" - break - finally: - sub.stop() - thr.join() - pub.stop() - def test_switch_to_unknown_backend(): """Test switching to unknown backend.""" - from posttroll import config from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber with config.set(backend="unsecure_and_deprecated"): @@ -808,47 +762,11 @@ def test_switch_to_unknown_backend(): with pytest.raises(NotImplementedError): Subscriber("ipc://bla.ipc") -def test_switch_to_secure_zmq_backend(): - """Test switching to the secure_zmq backend.""" - from posttroll import config - from posttroll.publisher import Publisher - from posttroll.subscriber import Subscriber - - with config.set(backend="secure_zmq"): - Publisher("ipc://bla.ipc") - Subscriber("ipc://bla.ipc") - def test_switch_to_unsecure_zmq_backend(): """Test switching to the secure_zmq backend.""" - from posttroll import config from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber with config.set(backend="unsecure_zmq"): Publisher("ipc://bla.ipc") Subscriber("ipc://bla.ipc") - - -def test_noisypublisher_heartbeat(): - """Test that the heartbeat in the NoisyPublisher works.""" - from posttroll.ns import NameServer - from posttroll.publisher import NoisyPublisher - from posttroll.subscriber import Subscribe - - ns_ = NameServer() - thr = Thread(target=ns_.run) - thr.start() - - pub = NoisyPublisher("test") - pub.start() - time.sleep(0.2) - - with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: - time.sleep(0.2) - pub.heartbeat(min_interval=10) - msg = next(sub.recv(1)) - assert msg.type == "beat" - assert msg.data == {'min_interval': 10} - pub.stop() - ns_.stop() - thr.join() diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py new file mode 100644 index 0000000..be11691 --- /dev/null +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -0,0 +1,144 @@ + +import os +import shutil +import time + +import zmq.auth + +from posttroll import config +from posttroll.publisher import Publisher +from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config + + +def create_keys(tmp_path): + """Create keys.""" + base_dir = tmp_path + keys_dir = base_dir / "certificates" + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + keys_dir.mkdir() + public_keys_dir.mkdir() + secret_keys_dir.mkdir() + + # create new keys in certificates dir + server_public_file, server_secret_file = zmq.auth.create_certificates( + keys_dir, "server" + ) + client_public_file, client_secret_file = zmq.auth.create_certificates( + keys_dir, "client" + ) + + # move public keys to appropriate directory + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key"): + shutil.move( + os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, ".") + ) + + # move secret keys to appropriate directory + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key_secret"): + shutil.move( + os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, ".") + ) + + +def test_ipc_pubsub_with_sec(tmp_path): + """Test pub-sub on a secure ipc socket.""" + server_public_key, server_secret_key = zmq.auth.create_certificates(tmp_path, "server") + client_public_key, client_secret_key = zmq.auth.create_certificates(tmp_path, "client") + + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="secure_zmq"): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, + client_secret_key_file=client_secret_key, + server_public_key_file=server_public_key) + sub = create_subscriber_from_dict_config(subscriber_settings) + from posttroll.publisher import Publisher + + pub = Publisher(ipc_address, + server_secret_key=server_secret_key, + public_keys_directory=os.path.dirname(client_public_key)) + + + pub.start() + def delayed_send(msg): + time.sleep(.2) + from posttroll.message import Message + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + from threading import Thread + thr = Thread(target=delayed_send, args=["very sensitive message"]) + thr.start() + try: + for msg in sub.recv(): + assert msg.data == "very sensitive message" + break + finally: + sub.stop() + thr.join() + pub.stop() + + +def test_switch_to_secure_zmq_backend(tmp_path): + """Test switching to the secure_zmq backend.""" + create_keys(tmp_path) + + base_dir = tmp_path + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + server_secret_key = secret_keys_dir / "server.key_secret" + public_keys_directory = public_keys_dir + publisher_key_args = (server_secret_key, public_keys_directory) + + client_secret_key = secret_keys_dir / "client.key_secret" + server_public_key = public_keys_dir / "server.key" + subscriber_key_args = (client_secret_key, server_public_key) + + with config.set(backend="secure_zmq"): + Publisher("ipc://bla.ipc", *publisher_key_args) + Subscriber("ipc://bla.ipc", *subscriber_key_args) + + +def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): + """Test pub-sub on a secure ipc socket.""" + base_dir = tmp_path + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + create_keys(tmp_path) + + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="secure_zmq"): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, + client_secret_key_file=secret_keys_dir / "client.key_secret", + server_public_key_file=public_keys_dir / "server.key") + sub = create_subscriber_from_dict_config(subscriber_settings) + from posttroll.publisher import create_publisher_from_dict_config + pub_settings = dict(address=ipc_address, + server_secret_key=secret_keys_dir / "server.key_secret", + public_keys_directory=public_keys_dir, + nameservers=False, port=1789) + pub = create_publisher_from_dict_config(pub_settings) + + pub.start() + def delayed_send(msg): + time.sleep(.2) + from posttroll.message import Message + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + from threading import Thread + thr = Thread(target=delayed_send, args=["very sensitive message"]) + thr.start() + try: + for msg in sub.recv(): + assert msg.data == "very sensitive message" + break + finally: + sub.stop() + thr.join() + pub.stop() From 01f7d9132da7379928adc06b456e0d58a67a5260 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 24 Apr 2024 09:20:47 +0200 Subject: [PATCH 26/45] Fix test --- posttroll/tests/test_bbmcast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index c354c4a..1a47b40 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -119,7 +119,7 @@ def test_mcast_receiver_works_with_valid_addresses(): @pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST") != None, + os.getenv("DISABLED_MULTICAST"), reason="Multicast tests disabled.", ) def test_multicast_roundtrip(reraise): From 8e47d1a0fa68e6ce106472af45eb41b86b510e05 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 10:39:06 +0000 Subject: [PATCH 27/45] Improve tests --- posttroll/address_receiver.py | 9 +- posttroll/backends/zmq/address_receiver.py | 4 +- posttroll/backends/zmq/ns.py | 9 +- posttroll/ns.py | 13 +- posttroll/publisher.py | 4 +- posttroll/subscriber.py | 6 +- posttroll/tests/test_pubsub.py | 266 +-- pyproject.toml | 6 +- versioneer.py | 2146 -------------------- 9 files changed, 142 insertions(+), 2321 deletions(-) delete mode 100644 versioneer.py diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index cd03426..2d4a6c6 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -48,12 +48,15 @@ debug = os.environ.get("DEBUG", False) broadcast_port = 21200 -default_publish_port = 16543 +DEFAULT_ADDRESS_PUBLISH_PORT = 16543 ten_minutes = dt.timedelta(minutes=10) zero_seconds = dt.timedelta(seconds=0) +def get_configured_address_port(): + return config.get("address_publish_port", DEFAULT_ADDRESS_PUBLISH_PORT) + def get_local_ips(): """Get local IP addresses.""" inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET) @@ -72,14 +75,14 @@ def get_local_ips(): # ----------------------------------------------------------------------------- -class AddressReceiver(object): +class AddressReceiver: """General thread to receive broadcast addresses.""" def __init__(self, max_age=ten_minutes, port=None, do_heartbeat=True, multicast_enabled=True, restrict_to_localhost=False): """Set up the address receiver.""" self._max_age = max_age - self._port = port or default_publish_port + self._port = port or get_configured_address_port() self._address_lock = threading.Lock() self._addresses = {} self._subject = "/address" diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index 0052b3e..8eb22f6 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -2,7 +2,7 @@ from zmq import LINGER, REP -from posttroll.address_receiver import default_publish_port +from posttroll.address_receiver import get_configured_address_port from posttroll.backends.zmq import get_context @@ -11,7 +11,7 @@ class SimpleReceiver(object): def __init__(self, port=None): """Set up the receiver.""" - self._port = port or default_publish_port + self._port = port or get_configured_address_port() self._socket = get_context().socket(REP) self._socket.bind("tcp://*:" + str(port)) diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index bb4d0b0..f400ed9 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -7,7 +7,7 @@ from posttroll.backends.zmq import get_context from posttroll.message import Message -from posttroll.ns import PORT, get_active_address +from posttroll.ns import get_configured_nameserver_port, get_active_address logger = logging.getLogger("__name__") @@ -22,10 +22,11 @@ def unsecure_zmq_get_pub_address(name, timeout=10, nameserver="localhost"): # Socket to talk to server socket = get_context().socket(REQ) try: + port = get_configured_nameserver_port() socket.setsockopt(LINGER, int(timeout * 1000)) - socket.connect("tcp://" + nameserver + ":" + str(PORT)) + socket.connect("tcp://" + nameserver + ":" + str(port)) logger.debug("Connecting to %s", - "tcp://" + nameserver + ":" + str(PORT)) + "tcp://" + nameserver + ":" + str(port)) poller = Poller() poller.register(socket, POLLIN) @@ -55,7 +56,7 @@ def __init__(self): def run(self, arec): """Run the listener and answer to requests.""" - port = PORT + port = get_configured_nameserver_port() try: with nslock: diff --git a/posttroll/ns.py b/posttroll/ns.py index 585b089..bf7acf8 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -37,10 +37,21 @@ # pylint: enable=E0611 -PORT = int(os.environ.get("NAMESERVER_PORT", 5557)) +DEFAULT_NAMESERVER_PORT = 5557 logger = logging.getLogger(__name__) +def get_configured_nameserver_port(): + try: + port = int(os.environ["NAMESERVER_PORT"]) + warnings.warn("NAMESERVER_PORT is pending deprecation, please use POSTTROLL_NAMESERVER_PORT instead.", + PendingDeprecationWarning) + except KeyError: + port = DEFAULT_NAMESERVER_PORT + return config.get("nameserver_port", port) + + + # Client functions. diff --git a/posttroll/publisher.py b/posttroll/publisher.py index d1e4b85..e085753 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -196,7 +196,7 @@ def __init__(self, name, port=0, aliases=None, broadcast_interval=2, def start(self): """Start the publisher.""" pub_addr = _create_tcp_publish_address(self._port) - self._publisher = self._publisher_class(pub_addr, self._name, + self._publisher = self._publisher_class(pub_addr, name=self._name, min_port=self.min_port, max_port=self.max_port) self._publisher.start() @@ -317,6 +317,8 @@ def _get_publisher_instance(settings): if not publisher_address: publisher_address = _create_tcp_publish_address(port) settings.pop("nameservers", None) + settings.pop("aliases", None) + settings.pop("broadcast_interval", None) return Publisher(publisher_address, **settings) diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 6654dda..9d4a05e 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -31,6 +31,7 @@ from posttroll import config from posttroll.message import _MAGICK from posttroll.ns import get_pub_address +from posttroll.address_receiver import get_configured_address_port LOGGER = logging.getLogger(__name__) @@ -200,7 +201,7 @@ def _get_addr_loop(service, timeout): """Try to get the address of *service* until for *timeout* seconds.""" then = dt.datetime.now() + dt.timedelta(seconds=timeout) while dt.datetime.now() < then: - addrs = get_pub_address(service, nameserver=self._nameserver) + addrs = get_pub_address(service, self._timeout, nameserver=self._nameserver) if addrs: return [addr["URI"] for addr in addrs] time.sleep(1) @@ -304,7 +305,8 @@ def __init__(self, subscriber, services="", nameserver="localhost"): services = [services, ] self.services = services self.subscriber = subscriber - self.subscriber.add_hook_sub("tcp://" + nameserver + ":16543", + address_publish_port = get_configured_address_port() + self.subscriber.add_hook_sub("tcp://" + nameserver + ":" + str(address_publish_port), ["pytroll://address", ], self.handle_msg) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 7f3e58a..983e9a9 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -42,122 +42,66 @@ test_lock = Lock() -class TestNS(unittest.TestCase): - """Test the nameserver.""" +def free_port(): + """Get a free port. - def setUp(self): - """Set up the testing class.""" - test_lock.acquire() - self.ns = NameServer(max_age=timedelta(seconds=3)) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() - - def test_pub_addresses(self): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish + From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0 - with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): - time.sleep(.3) - res = get_pub_addresses(["this_data"], timeout=.5) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses([str("data_provider")]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] + Returns a factory that finds the next free port that is available on the OS + This is a bit of a hack, it does this by creating a new socket, and calling + bind with the 0 port. The operating system will assign a brand new port, + which we can find out using getsockname(). Once we have the new port + information we close the socket thereby returning it to the free pool. + This means it is technically possible for this function to return the same + port twice (for example if run in very quick succession), however operating + systems return a random port number in the default range (1024 - 65535), + and it is highly unlikely for two processes to get the same port number. + In other words, it is possible to flake, but incredibly unlikely. + """ + import socket - def test_pub_sub_ctx(self): - """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", 0)) + portnum = s.getsockname()[1] + s.close() - with Publish("data_provider", 0, ["this_data"]) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(1) - msg = next(sub.recv(2)) - if msg is not None: - assert str(msg) == str(message) - tested = True - sub.close() - assert tested + return portnum - def test_pub_sub_add_rm(self): - """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - time.sleep(4) - with Subscribe("this_data", "counter", True) as sub: - assert len(sub.addresses) == 0 - with Publish("data_provider", 0, ["this_data"]): - time.sleep(4) - next(sub.recv(2)) - assert len(sub.addresses) == 1 - time.sleep(3) - for msg in sub.recv(2): - if msg is None: - break - time.sleep(3) - assert len(sub.addresses) == 0 - with Publish("data_provider_2", 0, ["another_data"]): - time.sleep(4) - next(sub.recv(2)) - assert len(sub.addresses) == 0 - sub.close() - - -class TestNSWithoutMulticasting: - """Test the nameserver.""" +@contextmanager +def create_nameserver_instance(max_age=3, multicast_enabled=True): + config.set(nameserver_port=free_port()) + config.set(address_publish_port=free_port()) + ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) + thr = Thread(target=ns.run) + thr.start() - def setup_method(self): - """Set up the testing class.""" - test_lock.acquire() - self.nameservers = ["localhost"] - self.max_age = .3 - self.ns = NameServer(max_age=timedelta(seconds=self.max_age), - multicast_enabled=False) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def teardown_method(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() + try: + yield + finally: + ns.stop() + thr.join() + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_addresses(multicast_enabled): + """Test retrieving addresses.""" + from posttroll.ns import get_pub_addresses + from posttroll.publisher import Publish - def test_pub_addresses(self): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers, broadcast_interval=.1): - time.sleep(.2) - res = get_pub_addresses(["this_data"]) + with create_nameserver_instance(multicast_enabled=multicast_enabled): + if multicast_enabled: + nameservers = None + else: + nameservers = ["localhost"] + with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): + time.sleep(.3) + res = get_pub_addresses(["this_data"], timeout=.5) assert len(res) == 1 expected = {u"status": True, u"service": [u"data_provider", u"this_data"], @@ -166,7 +110,7 @@ def test_pub_addresses(self): assert res[0][key] == val assert "receive_time" in res[0] assert "URI" in res[0] - res = get_pub_addresses(["data_provider"]) + res = get_pub_addresses([str("data_provider")]) assert len(res) == 1 expected = {u"status": True, u"service": [u"data_provider", u"this_data"], @@ -176,52 +120,67 @@ def test_pub_addresses(self): assert "receive_time" in res[0] assert "URI" in res[0] - def test_pub_sub_ctx(self): - """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_ctx(multicast_enabled): + """Test publish and subscribe.""" + from posttroll.message import Message + from posttroll.publisher import Publish + from posttroll.subscriber import Subscribe - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers, broadcast_interval=.1) as pub: + with create_nameserver_instance(multicast_enabled=multicast_enabled): + if multicast_enabled: + nameservers = None + else: + nameservers = ["localhost"] + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: with Subscribe("this_data", "counter") as sub: for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) time.sleep(.1) - msg = next(sub.recv(2)) + msg = next(sub.recv(.2)) if msg is not None: assert str(msg) == str(message) tested = True sub.close() assert tested - def test_pub_sub_add_rm(self): - """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_add_rm(multicast_enabled): + """Test adding and removing publishers.""" + from posttroll.publisher import Publish + from posttroll.subscriber import Subscribe + + max_age = 0.5 - with Subscribe("this_data", "counter", True, timeout=.1) as sub: + with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): + if multicast_enabled: + nameservers = None + else: + nameservers = ["localhost"] + with Subscribe("this_data", "counter", True, timeout=.2) as sub: assert len(sub.addresses) == 0 - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers, broadcast_interval=.1): - time.sleep(4) - next(sub.recv(.2)) + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) assert len(sub.addresses) == 1 - - time.sleep(3) - - for msg in sub.recv(.2): + time.sleep(max_age * 2) + for msg in sub.recv(.1): if msg is None: break - - time.sleep(3) + time.sleep(.1) assert len(sub.addresses) == 0 - with Publish("data_provider_2", 0, ["another_data"], - nameservers=self.nameservers, broadcast_interval=.1): - time.sleep(4) - next(sub.recv(.2)) + with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) assert len(sub.addresses) == 0 + sub.close() class TestPubSub(unittest.TestCase): @@ -305,7 +264,7 @@ def test_pub_supports_unicode(self): from posttroll.publisher import Publish message = Message("/pџтяöll", "info", "hej") - with Publish("a_service", 9000) as pub: + with Publish("a_service", 0) as pub: try: pub.send(message.encode()) except UnicodeDecodeError: @@ -356,28 +315,13 @@ def _get_port_from_publish_instance(min_port=None, max_port=None): return False -class TestListenerContainer(unittest.TestCase): - """Testing listener container.""" - - def setUp(self): - """Set up the testing class.""" - test_lock.acquire() - self.ns = NameServer(max_age=timedelta(seconds=3)) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - test_lock.release() - - def test_listener_container(self): - """Test listener container.""" - from posttroll.listener import ListenerContainer - from posttroll.message import Message - from posttroll.publisher import NoisyPublisher +def test_listener_container(): + """Test listener container.""" + from posttroll.listener import ListenerContainer + from posttroll.message import Message + from posttroll.publisher import NoisyPublisher + with create_nameserver_instance(): pub = NoisyPublisher("test", broadcast_interval=0.1) pub.start() sub = ListenerContainer(topics=["/counter"]) @@ -654,7 +598,7 @@ def _tcp_keepalive_no_settings(): def test_publisher_tcp_keepalive(): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher("tcp://127.0.0.1:9001").start() + pub = UnsecureZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() _assert_tcp_keepalive(pub.publish_socket) pub.stop() @@ -663,7 +607,7 @@ def test_publisher_tcp_keepalive(): def test_publisher_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher("tcp://127.0.0.1:9002").start() + pub = UnsecureZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() _assert_no_tcp_keepalive(pub.publish_socket) pub.stop() @@ -672,7 +616,7 @@ def test_publisher_tcp_keepalive_not_set(): def test_subscriber_tcp_keepalive(): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber("tcp://127.0.0.1:9000") + sub = UnsecureZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") assert len(sub.addr_sub.values()) == 1 _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) sub.stop() @@ -682,7 +626,7 @@ def test_subscriber_tcp_keepalive(): def test_subscriber_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber("tcp://127.0.0.1:9000") + sub = UnsecureZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") assert len(sub.addr_sub.values()) == 1 _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0]) sub.close() diff --git a/pyproject.toml b/pyproject.toml index 784768f..5f1c8e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,11 @@ description = "Messaging system for pytroll" authors = [ { name = "The Pytroll Team", email = "pytroll@googlegroups.com" } ] -dependencies = ["pyzmq", "netifaces", "donfig"] +dependencies = [ + "pyzmq", + "netifaces-plus", + "donfig", +] readme = "README.md" requires-python = ">=3.10" license = { text = "GPLv3" } diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 070e384..0000000 --- a/versioneer.py +++ /dev/null @@ -1,2146 +0,0 @@ - -# Version: 0.23 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/python-versioneer/python-versioneer -* Brian Warner -* License: Public Domain (CC0-1.0) -* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3 -* [![Latest Version][pypi-image]][pypi-url] -* [![Build Status][travis-image]][travis-url] - -This is a tool for managing a recorded version number in distutils/setuptools-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere in your $PATH -* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md)) -* run `versioneer install` in your source tree, commit the results -* Verify version information with `python setup.py version` - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes). - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/python-versioneer/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - -## Similar projects - -* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time - dependency -* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of - versioneer -* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools - plugin - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg -[pypi-url]: https://pypi.python.org/pypi/versioneer/ -[travis-image]: -https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg -[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer - -""" -# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring -# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements -# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error -# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with -# pylint:disable=attribute-defined-outside-init,too-many-arguments - -import configparser -import errno -import json -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - my_path = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(my_path)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise OSError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.ConfigParser() - with open(setup_cfg, "r") as cfg_file: - parser.read_file(cfg_file) - VCS = parser.get("versioneer", "VCS") # mandatory - - # Dict-like interface for non-mandatory entries - section = parser["versioneer"] - - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = section.get("style", "") - cfg.versionfile_source = section.get("versionfile_source") - cfg.versionfile_build = section.get("versionfile_build") - cfg.tag_prefix = section.get("tag_prefix") - if cfg.tag_prefix in ("''", '""', None): - cfg.tag_prefix = "" - cfg.parentdir_prefix = section.get("parentdir_prefix") - cfg.verbose = section.get("verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - HANDLERS.setdefault(vcs, {})[method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -LONG_VERSION_PY['git'] = r''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%%d" %% (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [versionfile_source] - if ipy: - files.append(ipy) - try: - my_path = __file__ - if my_path.endswith(".pyc") or my_path.endswith(".pyo"): - my_path = os.path.splitext(my_path)[0] + ".py" - versioneer_file = os.path.relpath(my_path) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - with open(".gitattributes", "r") as fobj: - for line in fobj: - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - break - except OSError: - pass - if not present: - with open(".gitattributes", "a+") as fobj: - fobj.write(f"{versionfile_source} export-subst\n") - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.23) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except OSError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(cmdclass=None): - """Get the custom setuptools subclasses used by Versioneer. - - If the package uses a different cmdclass (e.g. one from numpy), it - should be provide as an argument. - """ - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - - cmds = {} if cmdclass is None else cmdclass.copy() - - # we add "version" to setuptools - from setuptools import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # pip install -e . and setuptool/editable_wheel will invoke build_py - # but the build_py command is not expected to copy any files. - - # we override different "build_py" commands for both environments - if 'build_py' in cmds: - _build_py = cmds['build_py'] - else: - from setuptools.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - if getattr(self, "editable_mode", False): - # During editable installs `.py` and data files are - # not copied to build_lib - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if 'build_ext' in cmds: - _build_ext = cmds['build_ext'] - else: - from setuptools.command.build_ext import build_ext as _build_ext - - class cmd_build_ext(_build_ext): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_ext.run(self) - if self.inplace: - # build_ext --inplace will only build extensions in - # build/lib<..> dir with no _version.py to write to. - # As in place builds will already have a _version.py - # in the module dir, we do not need to write one. - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - if not os.path.exists(target_versionfile): - print(f"Warning: {target_versionfile} does not exist, skipping " - "version update. This can happen if you are running build_ext " - "without first running build_py.") - return - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_ext"] = cmd_build_ext - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? - from py2exe.distutils_buildexe import py2exe as _py2exe - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # sdist farms its file list building out to egg_info - if 'egg_info' in cmds: - _sdist = cmds['egg_info'] - else: - from setuptools.command.egg_info import egg_info as _egg_info - - class cmd_egg_info(_egg_info): - def find_sources(self): - # egg_info.find_sources builds the manifest list and writes it - # in one shot - super().find_sources() - - # Modify the filelist and normalize it - root = get_root() - cfg = get_config_from_root(root) - self.filelist.append('versioneer.py') - if cfg.versionfile_source: - # There are rare cases where versionfile_source might not be - # included by default, so we must be explicit - self.filelist.append(cfg.versionfile_source) - self.filelist.sort() - self.filelist.remove_duplicates() - - # The write method is hidden in the manifest_maker instance that - # generated the filelist and was thrown away - # We will instead replicate their final normalization (to unicode, - # and POSIX-style paths) - from setuptools import unicode_utils - normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') - for f in self.filelist.files] - - manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') - with open(manifest_filename, 'w') as fobj: - fobj.write('\n'.join(normalized)) - - cmds['egg_info'] = cmd_egg_info - - # we override different "sdist" commands for both environments - if 'sdist' in cmds: - _sdist = cmds['sdist'] - else: - from setuptools.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -OLD_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - -INIT_PY_SNIPPET = """ -from . import {0} -__version__ = {0}.get_versions()['version'] -""" - - -def do_setup(): - """Do main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (OSError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except OSError: - old = "" - module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] - snippet = INIT_PY_SNIPPET.format(module) - if OLD_SNIPPET in old: - print(" replacing boilerplate in %s" % ipy) - with open(ipy, "w") as f: - f.write(old.replace(OLD_SNIPPET, snippet)) - elif snippet not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(snippet) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) From 720511fc46984e0102747ffd620098936fcfbd0e Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 12:52:53 +0200 Subject: [PATCH 28/45] Allow multicast-dependent tests to be skipped --- posttroll/backends/zmq/ns.py | 2 +- posttroll/subscriber.py | 2 +- posttroll/tests/test_pubsub.py | 41 ++++++++++++++++++++++++++-------- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index f400ed9..4f7214c 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -7,7 +7,7 @@ from posttroll.backends.zmq import get_context from posttroll.message import Message -from posttroll.ns import get_configured_nameserver_port, get_active_address +from posttroll.ns import get_active_address, get_configured_nameserver_port logger = logging.getLogger("__name__") diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 9d4a05e..a59f590 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -29,9 +29,9 @@ import time from posttroll import config +from posttroll.address_receiver import get_configured_address_port from posttroll.message import _MAGICK from posttroll.ns import get_pub_address -from posttroll.address_receiver import get_configured_address_port LOGGER = logging.getLogger(__name__) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 983e9a9..305bd27 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -23,6 +23,7 @@ """Test the publishing and subscribing facilities.""" +import os import time import unittest from contextlib import contextmanager @@ -83,7 +84,10 @@ def create_nameserver_instance(max_age=3, multicast_enabled=True): ns.stop() thr.join() - +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) @pytest.mark.parametrize( "multicast_enabled", [True, False] @@ -93,12 +97,13 @@ def test_pub_addresses(multicast_enabled): from posttroll.ns import get_pub_addresses from posttroll.publisher import Publish + if multicast_enabled: + nameservers = None + else: + nameservers = ["localhost"] + with create_nameserver_instance(multicast_enabled=multicast_enabled): - if multicast_enabled: - nameservers = None - else: - nameservers = ["localhost"] with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): time.sleep(.3) res = get_pub_addresses(["this_data"], timeout=.5) @@ -120,6 +125,10 @@ def test_pub_addresses(multicast_enabled): assert "receive_time" in res[0] assert "URI" in res[0] +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) @pytest.mark.parametrize( "multicast_enabled", [True, False] @@ -148,6 +157,11 @@ def test_pub_sub_ctx(multicast_enabled): sub.close() assert tested + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) @pytest.mark.parametrize( "multicast_enabled", [True, False] @@ -159,11 +173,12 @@ def test_pub_sub_add_rm(multicast_enabled): max_age = 0.5 + if multicast_enabled: + nameservers = None + else: + nameservers = ["localhost"] + with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): - if multicast_enabled: - nameservers = None - else: - nameservers = ["localhost"] with Subscribe("this_data", "counter", True, timeout=.2) as sub: assert len(sub.addresses) == 0 with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): @@ -315,6 +330,10 @@ def _get_port_from_publish_instance(min_port=None, max_port=None): return False +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) def test_listener_container(): """Test listener container.""" from posttroll.listener import ListenerContainer @@ -650,6 +669,10 @@ def _assert_no_tcp_keepalive(socket): assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1 +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) def test_noisypublisher_heartbeat(): """Test that the heartbeat in the NoisyPublisher works.""" from posttroll.ns import NameServer From 82954846e48e6b2145220cefb9d1b20a82e43f64 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 15:00:19 +0200 Subject: [PATCH 29/45] Fix linting issues --- posttroll/address_receiver.py | 53 ++-- posttroll/backends/__init__.py | 1 + posttroll/backends/zmq/__init__.py | 2 + posttroll/backends/zmq/message_broadcaster.py | 3 + posttroll/backends/zmq/subscriber.py | 268 ++++-------------- posttroll/logger.py | 2 +- posttroll/message.py | 55 ++-- posttroll/ns.py | 6 +- posttroll/publisher.py | 3 +- posttroll/tests/test_pubsub.py | 1 + posttroll/tests/test_secure_zmq_backend.py | 2 + 11 files changed, 130 insertions(+), 266 deletions(-) diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 2d4a6c6..2f0a24d 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -151,30 +151,7 @@ def _check_age(self, pub, min_interval=zero_seconds): def _run(self): """Run the receiver.""" port = broadcast_port - nameservers = False - if self._multicast_enabled: - while True: - try: - recv = MulticastReceiver(port) - except IOError as err: - if err.errno == errno.ENODEV: - LOGGER.error("Receiver initialization failed " - "(no such device). " - "Trying again in %d s", - 10) - time.sleep(10) - else: - raise - else: - recv.settimeout(tout=2.0) - LOGGER.info("Receiver initialized.") - break - - else: - if config.get("backend", "unsecure_zmq") == "unsecure_zmq": - from posttroll.backends.zmq.address_receiver import SimpleReceiver - recv = SimpleReceiver(port) - nameservers = ["localhost"] + nameservers, recv = self.set_up_address_receiver(port) self._is_running = True with Publish("address_receiver", self._port, ["addresses"], @@ -217,6 +194,34 @@ def _run(self): self._is_running = False recv.close() + def set_up_address_receiver(self, port): + """Set up the address receiver depending on if it is multicast or not.""" + nameservers = False + if self._multicast_enabled: + while True: + try: + recv = MulticastReceiver(port) + except IOError as err: + if err.errno == errno.ENODEV: + LOGGER.error("Receiver initialization failed " + "(no such device). " + "Trying again in %d s", + 10) + time.sleep(10) + else: + raise + else: + recv.settimeout(tout=2.0) + LOGGER.info("Receiver initialized.") + break + + else: + if config.get("backend", "unsecure_zmq") == "unsecure_zmq": + from posttroll.backends.zmq.address_receiver import SimpleReceiver + recv = SimpleReceiver(port) + nameservers = ["localhost"] + return nameservers,recv + def _add(self, adr, metadata): """Add an address.""" with self._address_lock: diff --git a/posttroll/backends/__init__.py b/posttroll/backends/__init__.py index e69de29..982ba70 100644 --- a/posttroll/backends/__init__.py +++ b/posttroll/backends/__init__.py @@ -0,0 +1 @@ +"""Init file for the backends.""" diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index f086f98..17a60f9 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -1,3 +1,4 @@ +"""Main module for the zmq backend.""" import logging import os @@ -21,6 +22,7 @@ def get_context(): return context[pid] def _set_tcp_keepalive(socket): + """Set the tcp keepalive parameters on *socket*.""" _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py index 5fbff8d..060e9ae 100644 --- a/posttroll/backends/zmq/message_broadcaster.py +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -1,3 +1,5 @@ +"""Message broadcaster implementation using zmq.""" + import logging import threading @@ -12,6 +14,7 @@ class UnsecureZMQDesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): + """Set up the sender.""" self.default_port = default_port self.receivers = receivers diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 3351995..313f041 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -12,8 +12,8 @@ LOGGER = logging.getLogger(__name__) -class UnsecureZMQSubscriber: - """Unsecure ZMQ implementation of the subscriber.""" + +class _ZMQSubscriber: def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" @@ -55,17 +55,6 @@ def add(self, address, topics=None): self.sub_addr[subscriber] = address self.addr_sub[address] = subscriber - def _add_sub_socket(self, address, topics): - subscriber = get_context().socket(SUB) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber - def remove(self, address): """Remove *address* from the subscribing list for *topics*.""" with self._lock: @@ -151,35 +140,41 @@ def recv(self, timeout=None): try: while self._loop: sleep(0) - try: - socks = dict(self.poller.poll(timeout=timeout)) - if socks: - for sub in self.subscribers: - if sub in socks and socks[sub] == POLLIN: - received = sub.recv_string(NOBLOCK) - m__ = Message.decode(received) - if not self._filter or self._filter(m__): - if self._translate: - url = urlsplit(self.sub_addr[sub]) - host = url[1].split(":")[0] - m__.sender = (m__.sender.split("@")[0] - + "@" + host) - yield m__ - - for sub in self._hooks: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - self._hooks_cb[sub](m__) - else: - # timeout - yield None - except ZMQError as err: - if self._loop: - LOGGER.exception("Receive failed: %s", str(err)) + yield from self._new_messages(timeout) finally: for sub in list(self.subscribers) + self._hooks: self.poller.unregister(sub) + def _new_messages(self, timeout): + """Check for new messages to yield and pass to the callbacks.""" + try: + socks = dict(self.poller.poll(timeout=timeout)) + if socks: + for sub in self.subscribers: + if sub in socks and socks[sub] == POLLIN: + received = sub.recv_string(NOBLOCK) + m__ = Message.decode(received) + if not self._filter or self._filter(m__): + if self._translate: + url = urlsplit(self.sub_addr[sub]) + host = url[1].split(":")[0] + m__.sender = (m__.sender.split("@")[0] + + "@" + host) + yield m__ + + for sub in self._hooks: + if sub in socks and socks[sub] == POLLIN: + m__ = Message.decode(sub.recv_string(NOBLOCK)) + self._hooks_cb[sub](m__) + else: + # timeout + yield None + except ZMQError as err: + if self._loop: + LOGGER.exception("Receive failed: %s", str(err)) + + + def __call__(self, **kwargs): """Handle calls with class instance.""" return self.recv(**kwargs) @@ -207,51 +202,30 @@ def __del__(self): pass -class SecureZMQSubscriber: - """Secure ZMQ implementation of the subscriber, using Curve.""" - - def __init__(self, addresses, client_secret_key_file, server_public_key_file, topics="", message_filter=None, translate=False): - """Initialize the subscriber.""" - self._topics = topics - self._filter = message_filter - self._translate = translate - - self._client_secret_file = client_secret_key_file - self._server_public_key_file = server_public_key_file - - self.sub_addr = {} - self.addr_sub = {} - - self._hooks = [] - self._hooks_cb = {} - - self.poller = Poller() - self._lock = Lock() +class UnsecureZMQSubscriber(_ZMQSubscriber): + """Unsecure ZMQ implementation of the subscriber.""" - self.update(addresses) + def _add_sub_socket(self, address, topics): + subscriber = get_context().socket(SUB) + _set_tcp_keepalive(subscriber) + for t__ in topics: + subscriber.setsockopt_string(SUBSCRIBE, str(t__)) + subscriber.connect(address) - self._loop = None + if self.poller: + self.poller.register(subscriber, POLLIN) + return subscriber - @property - def running(self): - """Check if suscriber is running.""" - return self._loop - def add(self, address, topics=None): - """Add *address* to the subscribing list for *topics*. +class SecureZMQSubscriber(_ZMQSubscriber): + """Secure ZMQ implementation of the subscriber, using Curve.""" - It topics is None we will subscribe to already specified topics. - """ - with self._lock: - if address in self.addresses: - return + def __init__(self, addresses, client_secret_key_file, server_public_key_file, **kwargs): + """Initialize the subscriber.""" + self._client_secret_file = client_secret_key_file + self._server_public_key_file = server_public_key_file - topics = topics or self._topics - LOGGER.info("Subscriber adding address %s with topics %s", - str(address), str(topics)) - subscriber = self._add_sub_socket(address, topics) - self.sub_addr[subscriber] = address - self.addr_sub[address] = subscriber + super().__init__(addresses, **kwargs) def _add_sub_socket(self, address, topics): import zmq.auth @@ -274,143 +248,3 @@ def _add_sub_socket(self, address, topics): if self.poller: self.poller.register(subscriber, POLLIN) return subscriber - - def remove(self, address): - """Remove *address* from the subscribing list for *topics*.""" - with self._lock: - try: - subscriber = self.addr_sub[address] - except KeyError: - return - LOGGER.info("Subscriber removing address %s", str(address)) - del self.addr_sub[address] - del self.sub_addr[subscriber] - self._remove_sub_socket(subscriber) - - def _remove_sub_socket(self, subscriber): - if self.poller: - self.poller.unregister(subscriber) - subscriber.close() - - def update(self, addresses): - """Update with a set of addresses.""" - if isinstance(addresses, str): - addresses = [addresses, ] - current_addresses, new_addresses = set(self.addresses), set(addresses) - addresses_to_remove = current_addresses.difference(new_addresses) - addresses_to_add = new_addresses.difference(current_addresses) - for addr in addresses_to_remove: - self.remove(addr) - for addr in addresses_to_add: - self.add(addr) - return bool(addresses_to_remove or addresses_to_add) - - def add_hook_sub(self, address, topics, callback): - """Specify a SUB *callback* in the same stream (thread) as the main receive loop. - - The callback will be called with the received messages from the - specified subscription. - - Good for operations, which is required to be done in the same thread as - the main recieve loop (e.q operations on the underlying sockets). - """ - topics = topics - LOGGER.info("Subscriber adding SUB hook %s for topics %s", - str(address), str(topics)) - socket = self._add_sub_socket(address, topics) - self._add_hook(socket, callback) - - def add_hook_pull(self, address, callback): - """Specify a PULL *callback* in the same stream (thread) as the main receive loop. - - The callback will be called with the received messages from the - specified subscription. Good for pushed 'inproc' messages from another thread. - """ - LOGGER.info("Subscriber adding PULL hook %s", str(address)) - socket = get_context().socket(PULL) - socket.connect(address) - if self.poller: - self.poller.register(socket, POLLIN) - self._add_hook(socket, callback) - - def _add_hook(self, socket, callback): - """Add a generic hook. The passed socket has to be "receive only".""" - self._hooks.append(socket) - self._hooks_cb[socket] = callback - - - @property - def addresses(self): - """Get the addresses.""" - return self.sub_addr.values() - - @property - def subscribers(self): - """Get the subscribers.""" - return self.sub_addr.keys() - - def recv(self, timeout=None): - """Receive, optionally with *timeout* in seconds.""" - if timeout: - timeout *= 1000. - - for sub in list(self.subscribers) + self._hooks: - self.poller.register(sub, POLLIN) - self._loop = True - try: - while self._loop: - sleep(0) - try: - socks = dict(self.poller.poll(timeout=timeout)) - if socks: - for sub in self.subscribers: - if sub in socks and socks[sub] == POLLIN: - received = sub.recv_string(NOBLOCK) - m__ = Message.decode(received) - if not self._filter or self._filter(m__): - if self._translate: - url = urlsplit(self.sub_addr[sub]) - host = url[1].split(":")[0] - m__.sender = (m__.sender.split("@")[0] - + "@" + host) - yield m__ - - for sub in self._hooks: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - self._hooks_cb[sub](m__) - else: - # timeout - yield None - except ZMQError as err: - if self._loop: - LOGGER.exception("Receive failed: %s", str(err)) - finally: - for sub in list(self.subscribers) + self._hooks: - self.poller.unregister(sub) - - def __call__(self, **kwargs): - """Handle calls with class instance.""" - return self.recv(**kwargs) - - def stop(self): - """Stop the subscriber.""" - self._loop = False - - def close(self): - """Close the subscriber: stop it and close the local subscribers.""" - self.stop() - for sub in list(self.subscribers) + self._hooks: - try: - sub.setsockopt(LINGER, 1) - sub.close() - except ZMQError: - pass - - def __del__(self): - """Clean up after the instance is deleted.""" - for sub in list(self.subscribers) + self._hooks: - try: - sub.close() - except Exception: # noqa: E722 - pass diff --git a/posttroll/logger.py b/posttroll/logger.py index 5d3b321..2155a76 100644 --- a/posttroll/logger.py +++ b/posttroll/logger.py @@ -201,7 +201,7 @@ def run(): time.sleep(1) except KeyboardInterrupt: tlogger.stop() - print("Thanks for using pytroll/logger. See you soon on www.pytroll.org!") + print("Thanks for using pytroll/logger. See you soon on www.pytroll.org!") # noqa if __name__ == "__main__": diff --git a/posttroll/message.py b/posttroll/message.py index 315e030..541c0af 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -243,29 +243,10 @@ def datetime_decoder(dct): def _decode(rawstr): """Convert a raw string to a Message.""" - # Check for the magick word. - try: - rawstr = rawstr.decode("utf-8") - except (AttributeError, UnicodeEncodeError): - pass - except (UnicodeDecodeError): - try: - rawstr = rawstr.decode("iso-8859-1") - except (UnicodeDecodeError): - rawstr = rawstr.decode("utf-8", "ignore") - if not rawstr.startswith(_MAGICK): - raise MessageError("This is not a '%s' message (wrong magick word)" - % _MAGICK) - rawstr = rawstr[len(_MAGICK):] + rawstr = _check_for_magic_word(rawstr) - # Check for element count and version - raw = re.split(r"\s+", rawstr, maxsplit=6) - if len(raw) < 5: - raise MessageError("Could node decode raw string: '%s ...'" - % str(rawstr[:36])) - version = raw[4][:len(_VERSION)] - if not _is_valid_version(version): - raise MessageError("Invalid Message version: '%s'" % str(version)) + raw = _check_for_element_count(rawstr) + version = _check_for_version(raw) # Start to build message msg = dict((("subject", raw[0].strip()), @@ -301,6 +282,36 @@ def _decode(rawstr): return msg +def _check_for_version(raw): + version = raw[4][:len(_VERSION)] + if not _is_valid_version(version): + raise MessageError("Invalid Message version: '%s'" % str(version)) + return version + +def _check_for_element_count(rawstr): + raw = re.split(r"\s+", rawstr, maxsplit=6) + if len(raw) < 5: + raise MessageError("Could node decode raw string: '%s ...'" + % str(rawstr[:36])) + + return raw + +def _check_for_magic_word(rawstr): + """Check for the magick word.""" + try: + rawstr = rawstr.decode("utf-8") + except (AttributeError, UnicodeEncodeError): + pass + except (UnicodeDecodeError): + try: + rawstr = rawstr.decode("iso-8859-1") + except (UnicodeDecodeError): + rawstr = rawstr.decode("utf-8", "ignore") + if not rawstr.startswith(_MAGICK): + raise MessageError("This is not a '%s' message (wrong magick word)" + % _MAGICK) + return rawstr[len(_MAGICK):] + def datetime_encoder(obj): """Encode datetimes into iso format.""" diff --git a/posttroll/ns.py b/posttroll/ns.py index bf7acf8..9296dd2 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -29,6 +29,7 @@ import logging import os import time +import warnings from posttroll import config from posttroll.address_receiver import AddressReceiver @@ -41,7 +42,9 @@ logger = logging.getLogger(__name__) + def get_configured_nameserver_port(): + """Get the configured nameserver port.""" try: port = int(os.environ["NAMESERVER_PORT"]) warnings.warn("NAMESERVER_PORT is pending deprecation, please use POSTTROLL_NAMESERVER_PORT instead.", @@ -78,10 +81,11 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): def get_pub_address(name, timeout=10, nameserver="localhost"): - """Get the address of the named publisher + """Get the address of the named publisher. Args: name: name of the publishers + timeout: how long to wait for an address, in seconds. nameserver: nameserver address to query the publishers from (default: localhost). """ backend = config.get("backend", "unsecure_zmq") diff --git a/posttroll/publisher.py b/posttroll/publisher.py index e085753..5d5abdc 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -100,7 +100,8 @@ def __init__(self, address, *args, name="", min_port=None, max_port=None, **kwar self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port, **kwargs) elif backend == "secure_zmq": from posttroll.backends.zmq.publisher import SecureZMQPublisher - self._publisher = SecureZMQPublisher(address, *args, name=name, min_port=min_port, max_port=max_port, **kwargs) + self._publisher = SecureZMQPublisher(address, *args, name=name, min_port=min_port, max_port=max_port, + **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 305bd27..ba900c9 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -72,6 +72,7 @@ def free_port(): @contextmanager def create_nameserver_instance(max_age=3, multicast_enabled=True): + """Create a nameserver instance.""" config.set(nameserver_port=free_port()) config.set(address_publish_port=free_port()) ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index be11691..7cdc25b 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -1,4 +1,6 @@ +"""Test the curve-based zmq backend.""" + import os import shutil import time From a942d5440bb4d033d71e9d73b737f2542592e8a4 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 15:32:03 +0200 Subject: [PATCH 30/45] Refactor shared parts between secure and unsecure zmq --- posttroll/backends/zmq/publisher.py | 58 +++++------------ posttroll/publisher.py | 6 +- posttroll/subscriber.py | 10 +-- posttroll/tests/test_pubsub.py | 32 +--------- posttroll/tests/test_unsecure_zmq_backend.py | 66 ++++++++++++++++++++ 5 files changed, 94 insertions(+), 78 deletions(-) create mode 100644 posttroll/tests/test_unsecure_zmq_backend.py diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index 64e6e15..d7e48a4 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -5,6 +5,7 @@ from urllib.parse import urlsplit, urlunsplit import zmq +from zmq.auth.thread import ThreadAuthenticator from posttroll.backends.zmq import _set_tcp_keepalive, get_context @@ -15,7 +16,15 @@ class UnsecureZMQPublisher: """Unsecure ZMQ implementation of the publisher class.""" def __init__(self, address, name="", min_port=None, max_port=None): - """Bind the publisher class to a port.""" + """Set up the publisher. + + Args: + address: the address to connect to. + name: the name of this publishing service. + min_port: the minimal port number to use. + max_port: the maximal port number to use. + + """ self.name = name self.destination = address self.publish_socket = None @@ -61,11 +70,11 @@ def stop(self): self.publish_socket.setsockopt(zmq.LINGER, 1) self.publish_socket.close() -class SecureZMQPublisher: + +class SecureZMQPublisher(UnsecureZMQPublisher): """Secure ZMQ implementation of the publisher class.""" - def __init__(self, address, server_secret_key, public_keys_directory, name="", min_port=None, max_port=None, - authorized_sub_addresses=None): + def __init__(self, address, server_secret_key, public_keys_directory, authorized_sub_addresses=None, **kwargs): # noqa """Set up the secure ZMQ publisher. Args: @@ -73,32 +82,23 @@ def __init__(self, address, server_secret_key, public_keys_directory, name="", m server_secret_key: the secret key for this publisher. public_keys_directory: the directory containing the public keys of the subscribers that are allowed to connect. - name: the name of this publishing service. - min_port: the minimal port number to use. - max_port: the maximal port number to use. authorized_sub_addresses: the list of addresse allowed to subscibe to this publisher. By default, all are allowed. + kwargs: passed to the underlying UnsecureZMQPublisher instance. """ - self.name = name - self.destination = address - self.publish_socket = None - self.min_port = min_port - self.max_port = max_port - self.port_number = None - self._pub_lock = Lock() - self._server_secret_key = server_secret_key self._authorized_sub_addresses = authorized_sub_addresses or [] self._pub_keys_dir = public_keys_directory self._authenticator = None + super().__init__(address=address, **kwargs) + def start(self): """Start the publisher.""" ctx = get_context() # Start an authenticator for this context. - from zmq.auth.thread import ThreadAuthenticator auth = ThreadAuthenticator(ctx) auth.start() auth.allow(*self._authorized_sub_addresses) @@ -119,31 +119,7 @@ def start(self): LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") return self - def _bind(self): - # Check for port 0 (random port) - u__ = urlsplit(self.destination) - port = u__.port - if port == 0: - dest = urlunsplit((u__.scheme, u__.hostname, - u__.path, u__.query, u__.fragment)) - self.port_number = self.publish_socket.bind_to_random_port( - dest, - min_port=self.min_port, - max_port=self.max_port) - netloc = u__.hostname + ":" + str(self.port_number) - self.destination = urlunsplit((u__.scheme, netloc, u__.path, - u__.query, u__.fragment)) - else: - self.publish_socket.bind(self.destination) - self.port_number = port - - def send(self, msg): - """Send the given message.""" - with self._pub_lock: - self.publish_socket.send_string(msg) - def stop(self): """Stop the publisher.""" - self.publish_socket.setsockopt(zmq.LINGER, 1) - self.publish_socket.close() + super().stop() self._authenticator.stop() diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 5d5abdc..e6e0875 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -96,8 +96,12 @@ def __init__(self, address, *args, name="", min_port=None, max_port=None, **kwar backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": + if args: + raise TypeError(f"Unexpected arguments: {args}") + if kwargs: + raise TypeError(f"Unexpected keyword arguments: {kwargs}") from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port, **kwargs) + self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) elif backend == "secure_zmq": from posttroll.backends.zmq.publisher import SecureZMQPublisher self._publisher = SecureZMQPublisher(address, *args, name=name, min_port=min_port, max_port=max_port, diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index a59f590..4ced19b 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -65,6 +65,10 @@ def __init__(self, addresses, *args, topics="", message_filter=None, translate=F topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": + if args: + raise TypeError(f"Unexpected arguments: {args}") + if kwargs: + raise TypeError(f"Unexpected keyword arguments: {kwargs}") from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber self._subscriber = UnsecureZMQSubscriber(addresses, topics=topics, message_filter=message_filter, translate=translate) @@ -350,14 +354,10 @@ def create_subscriber_from_dict_config(settings): def _get_subscriber_instance(settings): - addresses = settings.pop("addresses") - topics = settings.pop("topics", "") - message_filter = settings.pop("message_filter", None) - translate = settings.pop("translate", False) _ = settings.pop("nameserver", None) _ = settings.pop("port", None) - return Subscriber(addresses, topics=topics, message_filter=message_filter, translate=translate, **settings) + return Subscriber(**settings) def _get_nssubscriber_instance(settings): diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index ba900c9..81782ed 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -38,7 +38,7 @@ from posttroll import config from posttroll.ns import NameServer from posttroll.publisher import Publisher, create_publisher_from_dict_config -from posttroll.subscriber import Subscribe, Subscriber, create_subscriber_from_dict_config +from posttroll.subscriber import Subscribe, Subscriber test_lock = Lock() @@ -699,27 +699,6 @@ def test_noisypublisher_heartbeat(): thr.join() -def test_ipc_pubsub(): - """Test pub-sub on an ipc socket.""" - with config.set(backend="unsecure_zmq"): - subscriber_settings = dict(addresses="ipc://bla.ipc", topics="", nameserver=False, port=10202) - sub = create_subscriber_from_dict_config(subscriber_settings) - pub = Publisher("ipc://bla.ipc") - pub.start() - def delayed_send(msg): - time.sleep(.2) - from posttroll.message import Message - msg = Message(subject="/hi", atype="string", data=msg) - pub.send(str(msg)) - pub.stop() - from threading import Thread - Thread(target=delayed_send, args=["hi"]).start() - for msg in sub.recv(): - assert msg.data == "hi" - break - sub.stop() - - def test_switch_to_unknown_backend(): """Test switching to unknown backend.""" from posttroll.publisher import Publisher @@ -729,12 +708,3 @@ def test_switch_to_unknown_backend(): Publisher("ipc://bla.ipc") with pytest.raises(NotImplementedError): Subscriber("ipc://bla.ipc") - -def test_switch_to_unsecure_zmq_backend(): - """Test switching to the secure_zmq backend.""" - from posttroll.publisher import Publisher - from posttroll.subscriber import Subscriber - - with config.set(backend="unsecure_zmq"): - Publisher("ipc://bla.ipc") - Subscriber("ipc://bla.ipc") diff --git a/posttroll/tests/test_unsecure_zmq_backend.py b/posttroll/tests/test_unsecure_zmq_backend.py new file mode 100644 index 0000000..66dbd6e --- /dev/null +++ b/posttroll/tests/test_unsecure_zmq_backend.py @@ -0,0 +1,66 @@ +"""Tests for the unsecure zmq backend.""" + +import time + +import pytest + +from posttroll import config +from posttroll.publisher import Publisher, create_publisher_from_dict_config +from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config + + +def test_ipc_pubsub(tmp_path): + """Test pub-sub on an ipc socket.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) + sub = create_subscriber_from_dict_config(subscriber_settings) + pub = Publisher(ipc_address) + pub.start() + def delayed_send(msg): + time.sleep(.2) + from posttroll.message import Message + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + pub.stop() + from threading import Thread + Thread(target=delayed_send, args=["hi"]).start() + for msg in sub.recv(): + assert msg.data == "hi" + break + sub.stop() + + +def test_switch_to_unsecure_zmq_backend(tmp_path): + """Test switching to the secure_zmq backend.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + Publisher(ipc_address) + Subscriber(ipc_address) + + +def test_ipc_pub_crashes_when_passed_key_files(tmp_path): + """Test pub-sub on an ipc socket.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, + client_secret_key_file="my_secret_key", + server_public_key_file="server_public_key") + with pytest.raises(TypeError): + create_subscriber_from_dict_config(subscriber_settings) + + +def test_ipc_sub_crashes_when_passed_key_files(tmp_path): + """Test pub-sub on a secure ipc socket.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + pub_settings = dict(address=ipc_address, + server_secret_key="server.key_secret", + public_keys_directory="public_keys_dir", + nameservers=False, port=1789) + with pytest.raises(TypeError): + create_publisher_from_dict_config(pub_settings) From 10ee5cc0fcbee8d0df5a0cdd97665bd9295115ed Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 15:55:04 +0200 Subject: [PATCH 31/45] Fix subscriber settings dropping --- posttroll/subscriber.py | 3 +++ posttroll/tests/test_pubsub.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 4ced19b..6ec51a8 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -356,6 +356,9 @@ def create_subscriber_from_dict_config(settings): def _get_subscriber_instance(settings): _ = settings.pop("nameserver", None) _ = settings.pop("port", None) + _ = settings.pop("services", None) + _ = settings.pop("addr_listener", None), + _ = settings.pop("timeout", None) return Subscriber(**settings) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 81782ed..5e7c153 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -224,7 +224,7 @@ def test_pub_suber(self): pub_address = "tcp://" + str(get_own_ip()) + ":0" pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) - sub = Subscriber([addr], "/counter") + sub = Subscriber([addr], topics="/counter") # wait a bit before sending the first message so that the subscriber is ready time.sleep(.002) From 7f6eb9d33d7d450107e56656ae00f7be9e5cc9c5 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 16:00:54 +0200 Subject: [PATCH 32/45] Fix tests --- posttroll/tests/test_pubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 5e7c153..1bb6574 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -146,7 +146,7 @@ def test_pub_sub_ctx(multicast_enabled): else: nameservers = ["localhost"] with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: - with Subscribe("this_data", "counter") as sub: + with Subscribe("this_data", topics="counter") as sub: for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) @@ -180,7 +180,7 @@ def test_pub_sub_add_rm(multicast_enabled): nameservers = ["localhost"] with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): - with Subscribe("this_data", "counter", True, timeout=.2) as sub: + with Subscribe("this_data", topics="counter", addr_listener=True, timeout=.2) as sub: assert len(sub.addresses) == 0 with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): time.sleep(.1) From c43ff9e5b09881a6eb8c696932dbe6a921bab598 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 25 Apr 2024 19:01:37 +0200 Subject: [PATCH 33/45] Fix backwards compatibility --- posttroll/backends/zmq/publisher.py | 8 ++++++-- posttroll/backends/zmq/subscriber.py | 8 ++++++-- posttroll/publisher.py | 6 ++---- posttroll/subscriber.py | 8 +++----- posttroll/tests/test_pubsub.py | 12 ++++++------ posttroll/tests/test_secure_zmq_backend.py | 10 ++++++---- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index d7e48a4..ffbf2d6 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -74,7 +74,7 @@ def stop(self): class SecureZMQPublisher(UnsecureZMQPublisher): """Secure ZMQ implementation of the publisher class.""" - def __init__(self, address, server_secret_key, public_keys_directory, authorized_sub_addresses=None, **kwargs): # noqa + def __init__(self, *args, server_secret_key=None, public_keys_directory=None, authorized_sub_addresses=None, **kwargs): # noqa """Set up the secure ZMQ publisher. Args: @@ -87,12 +87,16 @@ def __init__(self, address, server_secret_key, public_keys_directory, authorized kwargs: passed to the underlying UnsecureZMQPublisher instance. """ + if server_secret_key is None: + raise TypeError("Missing server_secret_key argument.") + if public_keys_directory is None: + raise TypeError("Missing public_keys_directory argument.") self._server_secret_key = server_secret_key self._authorized_sub_addresses = authorized_sub_addresses or [] self._pub_keys_dir = public_keys_directory self._authenticator = None - super().__init__(address=address, **kwargs) + super().__init__(*args, **kwargs) def start(self): """Start the publisher.""" diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 313f041..5b04e4d 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -220,12 +220,16 @@ def _add_sub_socket(self, address, topics): class SecureZMQSubscriber(_ZMQSubscriber): """Secure ZMQ implementation of the subscriber, using Curve.""" - def __init__(self, addresses, client_secret_key_file, server_public_key_file, **kwargs): + def __init__(self, *args, client_secret_key_file=None, server_public_key_file=None, **kwargs): """Initialize the subscriber.""" + if client_secret_key_file is None: + raise TypeError("Missing client_secret_key_file argument.") + if server_public_key_file is None: + raise TypeError("Missing server_public_key_file argument.") self._client_secret_file = client_secret_key_file self._server_public_key_file = server_public_key_file - super().__init__(addresses, **kwargs) + super().__init__(*args, **kwargs) def _add_sub_socket(self, address, topics): import zmq.auth diff --git a/posttroll/publisher.py b/posttroll/publisher.py index e6e0875..51a89c9 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -85,7 +85,7 @@ class Publisher: """ - def __init__(self, address, *args, name="", min_port=None, max_port=None, **kwargs): + def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): """Bind the publisher class to a port.""" # Limit port range or use the defaults when no port is defined # by the user @@ -96,15 +96,13 @@ def __init__(self, address, *args, name="", min_port=None, max_port=None, **kwar backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": - if args: - raise TypeError(f"Unexpected arguments: {args}") if kwargs: raise TypeError(f"Unexpected keyword arguments: {kwargs}") from posttroll.backends.zmq.publisher import UnsecureZMQPublisher self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) elif backend == "secure_zmq": from posttroll.backends.zmq.publisher import SecureZMQPublisher - self._publisher = SecureZMQPublisher(address, *args, name=name, min_port=min_port, max_port=max_port, + self._publisher = SecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port, **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 6ec51a8..5528fa4 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -60,13 +60,11 @@ class Subscriber: """ - def __init__(self, addresses, *args, topics="", message_filter=None, translate=False, **kwargs): + def __init__(self, addresses, topics="", message_filter=None, translate=False, **kwargs): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": - if args: - raise TypeError(f"Unexpected arguments: {args}") if kwargs: raise TypeError(f"Unexpected keyword arguments: {kwargs}") from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber @@ -74,8 +72,8 @@ def __init__(self, addresses, *args, topics="", message_filter=None, translate=F message_filter=message_filter, translate=translate) elif backend == "secure_zmq": from posttroll.backends.zmq.subscriber import SecureZMQSubscriber - self._subscriber = SecureZMQSubscriber(addresses, *args, topics=topics, - message_filter=message_filter, translate=translate, **kwargs) + self._subscriber = SecureZMQSubscriber(addresses, topics=topics, + message_filter=message_filter, translate=translate, **kwargs) else: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 1bb6574..eeecba7 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -126,10 +126,10 @@ def test_pub_addresses(multicast_enabled): assert "receive_time" in res[0] assert "URI" in res[0] -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) +# @pytest.mark.skipif( +# os.getenv("DISABLED_MULTICAST"), +# reason="Multicast tests disabled.", +# ) @pytest.mark.parametrize( "multicast_enabled", [True, False] @@ -146,7 +146,7 @@ def test_pub_sub_ctx(multicast_enabled): else: nameservers = ["localhost"] with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: - with Subscribe("this_data", topics="counter") as sub: + with Subscribe("this_data", "counter") as sub: for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) @@ -180,7 +180,7 @@ def test_pub_sub_add_rm(multicast_enabled): nameservers = ["localhost"] with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): - with Subscribe("this_data", topics="counter", addr_listener=True, timeout=.2) as sub: + with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: assert len(sub.addresses) == 0 with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): time.sleep(.1) diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 7cdc25b..4c6a903 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -94,15 +94,17 @@ def test_switch_to_secure_zmq_backend(tmp_path): server_secret_key = secret_keys_dir / "server.key_secret" public_keys_directory = public_keys_dir - publisher_key_args = (server_secret_key, public_keys_directory) + publisher_key_args = dict(server_secret_key=server_secret_key, + public_keys_directory=public_keys_directory) client_secret_key = secret_keys_dir / "client.key_secret" server_public_key = public_keys_dir / "server.key" - subscriber_key_args = (client_secret_key, server_public_key) + subscriber_key_args = dict(client_secret_key_file=client_secret_key, + server_public_key_file=server_public_key) with config.set(backend="secure_zmq"): - Publisher("ipc://bla.ipc", *publisher_key_args) - Subscriber("ipc://bla.ipc", *subscriber_key_args) + Publisher("ipc://bla.ipc", **publisher_key_args) + Subscriber("ipc://bla.ipc", **subscriber_key_args) def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): From 24b2c9af68826eeffc81af62cd2d2bb2cc0bd811 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 26 Apr 2024 09:50:48 +0200 Subject: [PATCH 34/45] Do not skip too much --- posttroll/tests/test_pubsub.py | 43 ++++++++++++++-------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index eeecba7..00b45eb 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -36,8 +36,9 @@ import posttroll from posttroll import config +from posttroll.message import Message from posttroll.ns import NameServer -from posttroll.publisher import Publisher, create_publisher_from_dict_config +from posttroll.publisher import Publish, Publisher, create_publisher_from_dict_config from posttroll.subscriber import Subscribe, Subscriber test_lock = Lock() @@ -85,10 +86,7 @@ def create_nameserver_instance(max_age=3, multicast_enabled=True): ns.stop() thr.join() -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) + @pytest.mark.parametrize( "multicast_enabled", [True, False] @@ -99,6 +97,8 @@ def test_pub_addresses(multicast_enabled): from posttroll.publisher import Publish if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") nameservers = None else: nameservers = ["localhost"] @@ -126,25 +126,21 @@ def test_pub_addresses(multicast_enabled): assert "receive_time" in res[0] assert "URI" in res[0] -# @pytest.mark.skipif( -# os.getenv("DISABLED_MULTICAST"), -# reason="Multicast tests disabled.", -# ) + @pytest.mark.parametrize( "multicast_enabled", [True, False] ) def test_pub_sub_ctx(multicast_enabled): """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] with create_nameserver_instance(multicast_enabled=multicast_enabled): - if multicast_enabled: - nameservers = None - else: - nameservers = ["localhost"] with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: with Subscribe("this_data", "counter") as sub: for counter in range(5): @@ -159,26 +155,21 @@ def test_pub_sub_ctx(multicast_enabled): assert tested -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) @pytest.mark.parametrize( "multicast_enabled", [True, False] ) def test_pub_sub_add_rm(multicast_enabled): """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - max_age = 0.5 - if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") nameservers = None else: nameservers = ["localhost"] + max_age = 0.5 + with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: assert len(sub.addresses) == 0 @@ -190,7 +181,7 @@ def test_pub_sub_add_rm(multicast_enabled): for msg in sub.recv(.1): if msg is None: break - time.sleep(.1) + time.sleep(.3) assert len(sub.addresses) == 0 with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): time.sleep(.1) From fb303540f90296498e7e2b3a81152670bc80033d Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 26 Apr 2024 09:57:58 +0200 Subject: [PATCH 35/45] Fix wait time --- .github/workflows/ci.yaml | 2 +- posttroll/tests/test_pubsub.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0dc2507..924a954 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install -U pytest pytest-cov pyzmq netifaces donfig pytest-reraise + pip install -U pytest pytest-cov pyzmq netifaces-plus donfig pytest-reraise - name: Install posttroll run: | pip install --no-deps -e . diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 00b45eb..2f88943 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -177,7 +177,7 @@ def test_pub_sub_add_rm(multicast_enabled): time.sleep(.1) next(sub.recv(.1)) assert len(sub.addresses) == 1 - time.sleep(max_age * 2) + time.sleep(max_age * 4) for msg in sub.recv(.1): if msg is None: break @@ -209,9 +209,7 @@ def test_pub_address_timeout(self): def test_pub_suber(self): """Test publisher and subscriber.""" - from posttroll.message import Message from posttroll.publisher import get_own_ip - from posttroll.subscriber import Subscriber pub_address = "tcp://" + str(get_own_ip()) + ":0" pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) @@ -234,9 +232,6 @@ def test_pub_suber(self): def test_pub_sub_ctx_no_nameserver(self): """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - with Publish("data_provider", 40000, nameservers=False) as pub: with Subscribe(topics="counter", nameserver=False, addresses=["tcp://127.0.0.1:40000"]) as sub: assert isinstance(sub, Subscriber) From 7bc00e66d9884cd1befbe109afa671c4b430090e Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 3 May 2024 15:52:08 +0200 Subject: [PATCH 36/45] Refactor --- posttroll/__init__.py | 4 +- posttroll/address_receiver.py | 28 +- posttroll/backends/zmq/__init__.py | 55 +++- posttroll/backends/zmq/address_receiver.py | 24 +- posttroll/backends/zmq/message_broadcaster.py | 21 +- posttroll/backends/zmq/ns.py | 108 ++++---- posttroll/backends/zmq/publisher.py | 97 +------ posttroll/backends/zmq/socket.py | 106 ++++++++ posttroll/backends/zmq/subscriber.py | 120 +++------ posttroll/bbmcast.py | 10 +- posttroll/message_broadcaster.py | 15 +- posttroll/ns.py | 17 +- posttroll/publisher.py | 15 +- posttroll/subscriber.py | 18 +- posttroll/tests/test_nameserver.py | 253 ++++++++++++++++++ posttroll/tests/test_pubsub.py | 219 +-------------- posttroll/tests/test_secure_zmq_backend.py | 67 +++-- 17 files changed, 658 insertions(+), 519 deletions(-) create mode 100644 posttroll/backends/zmq/socket.py create mode 100644 posttroll/tests/test_nameserver.py diff --git a/posttroll/__init__.py b/posttroll/__init__.py index aece644..46b0eaf 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -30,7 +30,7 @@ from donfig import Config -config = Config("posttroll") +config = Config("posttroll", defaults=[dict(backend="unsecure_zmq")]) # context = {} logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def get_context(): This function takes care of creating new contexts in case of forks. """ - backend = config.get("backend", "unsecure_zmq") + backend = config["backend"] if "zmq" in backend: from posttroll.backends.zmq import get_context return get_context() diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 2f0a24d..d2cef04 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -37,16 +37,16 @@ import netifaces from posttroll import config -from posttroll.bbmcast import MulticastReceiver, SocketTimeout +from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port from posttroll.message import Message from posttroll.publisher import Publish +from zmq import ZMQError __all__ = ("AddressReceiver", "getaddress") LOGGER = logging.getLogger(__name__) debug = os.environ.get("DEBUG", False) -broadcast_port = 21200 DEFAULT_ADDRESS_PUBLISH_PORT = 16543 @@ -144,13 +144,13 @@ def _check_age(self, pub, min_interval=zero_seconds): msg = Message("/address/" + metadata["name"], "info", mda) to_del.append(addr) LOGGER.info(f"publish remove '{msg}'") - pub.send(msg.encode()) + pub.send(str(msg.encode())) for addr in to_del: del self._addresses[addr] def _run(self): """Run the receiver.""" - port = broadcast_port + port = get_configured_broadcast_port() nameservers, recv = self.set_up_address_receiver(port) self._is_running = True @@ -159,7 +159,16 @@ def _run(self): try: while self._do_run: try: - data, fromaddr = recv() + rerun = True + while rerun: + try: + data, fromaddr = recv() + rerun = False + except TimeoutError: + if self._do_run: + continue + else: + raise if self._multicast_enabled: ip_, port = fromaddr if self._restrict_to_localhost and ip_ not in self._local_ips: @@ -171,6 +180,8 @@ def _run(self): if self._multicast_enabled: LOGGER.debug("Multicast socket timed out on recv!") continue + except ZMQError: + return finally: self._check_age(pub, min_interval=self._max_age / 20) if self._do_heartbeat: @@ -216,9 +227,10 @@ def set_up_address_receiver(self, port): break else: - if config.get("backend", "unsecure_zmq") == "unsecure_zmq": - from posttroll.backends.zmq.address_receiver import SimpleReceiver - recv = SimpleReceiver(port) + if config["backend"] not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError + from posttroll.backends.zmq.address_receiver import SimpleReceiver + recv = SimpleReceiver(port, timeout=2) nameservers = ["localhost"] return nameservers,recv diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index 17a60f9..2cd6597 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -5,6 +5,7 @@ import zmq from posttroll import config +from posttroll.message import Message logger = logging.getLogger(__name__) context = {} @@ -21,14 +22,54 @@ def get_context(): logger.debug("renewed context for PID %d", pid) return context[pid] +def destroy_context(linger=None): + pid = os.getpid() + context.pop(pid).destroy(linger) + def _set_tcp_keepalive(socket): """Set the tcp keepalive parameters on *socket*.""" - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None)) + keepalive_options = get_tcp_keepalive_options() + for param, value in keepalive_options.items(): + socket.setsockopt(param, value) + +def get_tcp_keepalive_options(): + """Get the tcp_keepalive options from config.""" + keepalive_options = dict() + for opt in ("tcp_keepalive", + "tcp_keepalive_cnt", + "tcp_keepalive_idle", + "tcp_keepalive_intvl"): + try: + value = int(config[opt]) + except (KeyError, TypeError): + continue + param = getattr(zmq, opt.upper()) + keepalive_options[param] = value + return keepalive_options + + +class SocketReceiver: + + def __init__(self): + self._poller = zmq.Poller() + + def register(self, socket): + """Register the socket.""" + self._poller.register(socket, zmq.POLLIN) + def unregister(self, socket): + """Unregister the socket.""" + self._poller.unregister(socket) -def _set_int_sockopt(socket, param, value): - if value is not None: - socket.setsockopt(param, int(value)) + def receive(self, *sockets, timeout=None): + """Timeout is in seconds.""" + if timeout: + timeout *= 1000 + socks = dict(self._poller.poll(timeout=timeout)) + if socks: + for sock in sockets: + if socks.get(sock) == zmq.POLLIN: + received = sock.recv_string(zmq.NOBLOCK) + yield Message.decode(received), sock + else: + raise TimeoutError("Did not receive anything on sockets.") diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index 8eb22f6..f926747 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -3,25 +3,35 @@ from zmq import LINGER, REP from posttroll.address_receiver import get_configured_address_port -from posttroll.backends.zmq import get_context +from posttroll.backends.zmq.socket import set_up_server_socket class SimpleReceiver(object): """Simple listing on port for address messages.""" - def __init__(self, port=None): + def __init__(self, port=None, timeout=2): """Set up the receiver.""" self._port = port or get_configured_address_port() - self._socket = get_context().socket(REP) - self._socket.bind("tcp://*:" + str(port)) + address = "tcp://*:" + str(port) + self._socket, _, self._authenticator = set_up_server_socket(REP, address) + self._running = True + self.timeout = timeout def __call__(self): """Receive a message.""" - message = self._socket.recv_string() - self._socket.send_string("ok") - return message, None + while self._running: + try: + message = self._socket.recv_string(self.timeout) + except TimeoutError: + continue + else: + self._socket.send_string("ok") + return message, None def close(self): """Close the receiver.""" + self._running = False self._socket.setsockopt(LINGER, 1) self._socket.close() + if self._authenticator: + self._authenticator.stop() diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py index 060e9ae..fe2ddfe 100644 --- a/posttroll/backends/zmq/message_broadcaster.py +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -3,20 +3,19 @@ import logging import threading +from posttroll.backends.zmq.socket import set_up_client_socket from zmq import LINGER, NOBLOCK, REQ, ZMQError -from posttroll.backends.zmq import get_context logger = logging.getLogger(__name__) -class UnsecureZMQDesignatedReceiversSender: +class ZMQDesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): """Set up the sender.""" self.default_port = default_port - self.receivers = receivers self._shutdown_event = threading.Event() @@ -28,13 +27,14 @@ def __call__(self, data): def _send_to_address(self, address, data, timeout=10): """Send data to *address* and *port* without verification of response.""" # Socket to talk to server - socket = get_context().socket(REQ) + if address.find(":") == -1: + full_address = "tcp://%s:%d" % (address, self.default_port) + else: + full_address = "tcp://%s" % address + options = {LINGER: int(timeout * 1000)} + socket = set_up_client_socket(REQ, full_address, options) try: - socket.setsockopt(LINGER, timeout * 1000) - if address.find(":") == -1: - socket.connect("tcp://%s:%d" % (address, self.default_port)) - else: - socket.connect("tcp://%s" % address) + socket.send_string(data) while not self._shutdown_event.is_set(): try: @@ -43,10 +43,11 @@ def _send_to_address(self, address, data, timeout=10): self._shutdown_event.wait(.1) continue if message != "ok": - logger.warn("invalid acknowledge received: %s" % message) + logger.warning("invalid acknowledge received: %s" % message) break finally: + socket.setsockopt(LINGER, 1) socket.close() def close(self): diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 4f7214c..3325272 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -1,11 +1,13 @@ """ZMQ implexentation of ns.""" import logging +from contextlib import suppress from threading import Lock -from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller +from posttroll.backends.zmq.socket import set_up_client_socket, set_up_server_socket +from zmq import LINGER, REP, REQ +from posttroll.backends.zmq import SocketReceiver -from posttroll.backends.zmq import get_context from posttroll.message import Message from posttroll.ns import get_active_address, get_configured_nameserver_port @@ -14,78 +16,94 @@ nslock = Lock() -def unsecure_zmq_get_pub_address(name, timeout=10, nameserver="localhost"): +def zmq_get_pub_address(name, timeout=10, nameserver="localhost"): """Get the address of the publisher. For a given publisher *name* from the nameserver on *nameserver* (localhost by default). """ + nameserver_address = create_nameserver_address(nameserver) # Socket to talk to server - socket = get_context().socket(REQ) + logger.debug(f"Connecting to {nameserver_address}") + socket = create_req_socket(timeout, nameserver_address) + return _fetch_address_using_socket(socket, name, timeout) + + +def create_nameserver_address(nameserver): + port = get_configured_nameserver_port() + nameserver_address = "tcp://" + nameserver + ":" + str(port) + return nameserver_address + + +def _fetch_address_using_socket(socket, name, timeout): try: - port = get_configured_nameserver_port() - socket.setsockopt(LINGER, int(timeout * 1000)) - socket.connect("tcp://" + nameserver + ":" + str(port)) - logger.debug("Connecting to %s", - "tcp://" + nameserver + ":" + str(port)) - poller = Poller() - poller.register(socket, POLLIN) + socket_receiver = SocketReceiver() + socket_receiver.register(socket) message = Message("/oper/ns", "request", {"service": name}) socket.send_string(str(message)) # Get the reply. - sock = poller.poll(timeout=timeout * 1000) - if sock: - if sock[0][0] == socket: - message = Message.decode(socket.recv_string(NOBLOCK)) - return message.data - else: - raise TimeoutError("Didn't get an address after %d seconds." - % timeout) + #socket.poll(timeout) + #message = socket.recv(timeout) + for message, _ in socket_receiver.receive(socket, timeout=timeout): + return message.data + except TimeoutError: + raise TimeoutError("Didn't get an address after %d seconds." + % timeout) finally: + socket_receiver.unregister(socket) + socket.setsockopt(LINGER, 1) socket.close() +def create_req_socket(timeout, nameserver_address): + options = {LINGER: int(timeout * 1000)} + socket = set_up_client_socket(REQ, nameserver_address, options) + return socket -class UnsecureZMQNameServer: +class ZMQNameServer: """The name server.""" def __init__(self): """Set up the nameserver.""" - self.loop = True + self.running = True self.listener = None - def run(self, arec): + def run(self, address_receiver): """Run the listener and answer to requests.""" port = get_configured_nameserver_port() try: - with nslock: - self.listener = get_context().socket(REP) - self.listener.bind("tcp://*:" + str(port)) - logger.debug(f"Nameserver listening on port {port}") - poller = Poller() - poller.register(self.listener, POLLIN) - while self.loop: - with nslock: - socks = dict(poller.poll(1000)) - if socks: - if socks.get(self.listener) == POLLIN: - msg = self.listener.recv_string() - else: - continue - logger.debug("Replying to request: " + str(msg)) - msg = Message.decode(msg) - active_address = get_active_address(msg.data["service"], arec) - self.listener.send_unicode(str(active_address)) + # stop was called before we could start running, exit + if not self.running: + return + address = "tcp://*:" + str(port) + self.listener, _, self._authenticator = set_up_server_socket(REP, address) + logger.debug(f"Nameserver listening on port {port}") + socket_receiver = SocketReceiver() + socket_receiver.register(self.listener) + while self.running: + try: + for msg, _ in socket_receiver.receive(self.listener, timeout=1): + logger.debug("Replying to request: " + str(msg)) + active_address = get_active_address(msg.data["service"], address_receiver) + self.listener.send_unicode(str(active_address)) + except TimeoutError: + continue except KeyboardInterrupt: # Needed to stop the nameserver. pass finally: - self.stop() + socket_receiver.unregister(self.listener) + self.close_sockets_and_threads() + + def close_sockets_and_threads(self): + with suppress(AttributeError): + self.listener.setsockopt(LINGER, 1) + self.listener.close() + with suppress(AttributeError): + self._authenticator.stop() + def stop(self): """Stop the name server.""" - self.listener.setsockopt(LINGER, 1) - self.loop = False - with nslock: - self.listener.close() + self.running = False diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index ffbf2d6..37a4898 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -1,18 +1,18 @@ """ZMQ implementation of the publisher.""" +from contextlib import suppress import logging from threading import Lock -from urllib.parse import urlsplit, urlunsplit +from posttroll.backends.zmq.socket import set_up_server_socket import zmq -from zmq.auth.thread import ThreadAuthenticator -from posttroll.backends.zmq import _set_tcp_keepalive, get_context +from posttroll.backends.zmq import get_tcp_keepalive_options LOGGER = logging.getLogger(__name__) -class UnsecureZMQPublisher: +class ZMQPublisher: """Unsecure ZMQ implementation of the publisher class.""" def __init__(self, address, name="", min_port=None, max_port=None): @@ -32,33 +32,20 @@ def __init__(self, address, name="", min_port=None, max_port=None): self.max_port = max_port self.port_number = None self._pub_lock = Lock() + self._authenticator = None def start(self): """Start the publisher.""" - self.publish_socket = get_context().socket(zmq.PUB) - _set_tcp_keepalive(self.publish_socket) - - self._bind() + self._create_socket() LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") + return self - def _bind(self): - # Check for port 0 (random port) - u__ = urlsplit(self.destination) - port = u__.port - if port == 0: - dest = urlunsplit((u__.scheme, u__.hostname, - u__.path, u__.query, u__.fragment)) - self.port_number = self.publish_socket.bind_to_random_port( - dest, - min_port=self.min_port, - max_port=self.max_port) - netloc = u__.hostname + ":" + str(self.port_number) - self.destination = urlunsplit((u__.scheme, netloc, u__.path, - u__.query, u__.fragment)) - else: - self.publish_socket.bind(self.destination) - self.port_number = port + def _create_socket(self): + options = get_tcp_keepalive_options() + self.publish_socket, port, self._authenticator = set_up_server_socket(zmq.PUB, self.destination, options, + (self.min_port, self.max_port)) + self.port_number = port def send(self, msg): """Send the given message.""" @@ -69,61 +56,5 @@ def stop(self): """Stop the publisher.""" self.publish_socket.setsockopt(zmq.LINGER, 1) self.publish_socket.close() - - -class SecureZMQPublisher(UnsecureZMQPublisher): - """Secure ZMQ implementation of the publisher class.""" - - def __init__(self, *args, server_secret_key=None, public_keys_directory=None, authorized_sub_addresses=None, **kwargs): # noqa - """Set up the secure ZMQ publisher. - - Args: - address: the address to connect to. - server_secret_key: the secret key for this publisher. - public_keys_directory: the directory containing the public keys of the subscribers that are allowed to - connect. - authorized_sub_addresses: the list of addresse allowed to subscibe to this publisher. By default, all are - allowed. - kwargs: passed to the underlying UnsecureZMQPublisher instance. - - """ - if server_secret_key is None: - raise TypeError("Missing server_secret_key argument.") - if public_keys_directory is None: - raise TypeError("Missing public_keys_directory argument.") - self._server_secret_key = server_secret_key - self._authorized_sub_addresses = authorized_sub_addresses or [] - self._pub_keys_dir = public_keys_directory - self._authenticator = None - - super().__init__(*args, **kwargs) - - def start(self): - """Start the publisher.""" - ctx = get_context() - - # Start an authenticator for this context. - auth = ThreadAuthenticator(ctx) - auth.start() - auth.allow(*self._authorized_sub_addresses) - # Tell authenticator to use the certificate in a directory - auth.configure_curve(domain="*", location=self._pub_keys_dir) - self._authenticator = auth - - self.publish_socket = ctx.socket(zmq.PUB) - - server_public, server_secret =zmq.auth.load_certificate(self._server_secret_key) - self.publish_socket.curve_secretkey = server_secret - self.publish_socket.curve_publickey = server_public - self.publish_socket.curve_server = True - - _set_tcp_keepalive(self.publish_socket) - - self._bind() - LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") - return self - - def stop(self): - """Stop the publisher.""" - super().stop() - self._authenticator.stop() + with suppress(AttributeError): + self._authenticator.stop() diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py new file mode 100644 index 0000000..16ae3da --- /dev/null +++ b/posttroll/backends/zmq/socket.py @@ -0,0 +1,106 @@ +from posttroll import get_context, config +import zmq +from zmq.auth.thread import ThreadAuthenticator +from urllib.parse import urlsplit, urlunsplit + + + +def set_up_client_socket(socket_type, address, options=None): + backend = config["backend"] + if backend == "unsecure_zmq": + sock = create_unsecure_client_socket(socket_type) + elif backend == "secure_zmq": + sock = create_secure_client_socket(socket_type) + add_options(sock, options) + sock.connect(address) + return sock + + +def create_unsecure_client_socket(socket_type): + return get_context().socket(socket_type) + + +def add_options(sock, options=None): + if not options: + return + for param, val in options.items(): + sock.setsockopt(param, val) + + +def create_secure_client_socket(socket_type): + subscriber = get_context().socket(socket_type) + + client_secret_key_file = config["client_secret_key_file"] + server_public_key_file = config["server_public_key_file"] + client_public, client_secret = zmq.auth.load_certificate(client_secret_key_file) + subscriber.curve_secretkey = client_secret + subscriber.curve_publickey = client_public + + server_public, _ = zmq.auth.load_certificate(server_public_key_file) + # The client must know the server's public key to make a CURVE connection. + subscriber.curve_serverkey = server_public + return subscriber + + +def set_up_server_socket(socket_type, destination, options=None, port_interval=(None, None)): + if options is None: + options = {} + backend = config["backend"] + if backend == "unsecure_zmq": + sock = create_unsecure_server_socket(socket_type) + authenticator = None + elif backend == "secure_zmq": + sock, authenticator = create_secure_server_socket(socket_type) + + add_options(sock, options) + + port = bind(sock, destination, port_interval) + return sock, port, authenticator + + +def create_unsecure_server_socket(socket_type): + return get_context().socket(socket_type) + + +def bind(sock, destination, port_interval): + # Check for port 0 (random port) + min_port, max_port = port_interval + u__ = urlsplit(destination) + port = u__.port + if port == 0: + dest = urlunsplit((u__.scheme, u__.hostname, + u__.path, u__.query, u__.fragment)) + port_number = sock.bind_to_random_port(dest, + min_port=min_port, + max_port=max_port) + netloc = u__.hostname + ":" + str(port_number) + destination = urlunsplit((u__.scheme, netloc, u__.path, + u__.query, u__.fragment)) + else: + sock.bind(destination) + port_number = port + return port_number + + +def create_secure_server_socket(socket_type): + server_secret_key = config["server_secret_key_file"] + clients_public_keys_directory = config["clients_public_keys_directory"] + authorized_sub_addresses = config.get("authorized_client_addresses", []) + + ctx = get_context() + + # Start an authenticator for this context. + authenticator_thread = ThreadAuthenticator(ctx) + authenticator_thread.start() + authenticator_thread.allow(*authorized_sub_addresses) + # Tell authenticator to use the certificate in a directory + authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) + + + server_socket = ctx.socket(socket_type) + + server_public, server_secret =zmq.auth.load_certificate(server_secret_key) + server_socket.curve_secretkey = server_secret + server_socket.curve_publickey = server_public + server_socket.curve_server = True + return server_socket, authenticator_thread diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 5b04e4d..8186f69 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -5,15 +5,15 @@ from time import sleep from urllib.parse import urlsplit -from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError +from zmq import LINGER, PULL, SUB, SUBSCRIBE, ZMQError +from posttroll.backends.zmq.socket import set_up_client_socket -from posttroll.backends.zmq import _set_tcp_keepalive, get_context -from posttroll.message import Message +from posttroll.backends.zmq import SocketReceiver, get_tcp_keepalive_options LOGGER = logging.getLogger(__name__) -class _ZMQSubscriber: +class ZMQSubscriber: def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" @@ -27,7 +27,8 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._hooks = [] self._hooks_cb = {} - self.poller = Poller() + #self.poller = Poller() + self._sock_receiver = SocketReceiver() self._lock = Lock() self.update(addresses) @@ -68,8 +69,8 @@ def remove(self, address): self._remove_sub_socket(subscriber) def _remove_sub_socket(self, subscriber): - if self.poller: - self.poller.unregister(subscriber) + if self._sock_receiver: + self._sock_receiver.unregister(subscriber) subscriber.close() def update(self, addresses): @@ -107,10 +108,9 @@ def add_hook_pull(self, address, callback): specified subscription. Good for pushed 'inproc' messages from another thread. """ LOGGER.info("Subscriber adding PULL hook %s", str(address)) - socket = get_context().socket(PULL) - socket.connect(address) - if self.poller: - self.poller.register(socket, POLLIN) + socket = self._create_socket(PULL, address) + if self._sock_receiver: + self._sock_receiver.register(socket) self._add_hook(socket, callback) def _add_hook(self, socket, callback): @@ -131,11 +131,9 @@ def subscribers(self): def recv(self, timeout=None): """Receive, optionally with *timeout* in seconds.""" - if timeout: - timeout *= 1000. for sub in list(self.subscribers) + self._hooks: - self.poller.register(sub, POLLIN) + self._sock_receiver.register(sub) self._loop = True try: while self._loop: @@ -143,38 +141,33 @@ def recv(self, timeout=None): yield from self._new_messages(timeout) finally: for sub in list(self.subscribers) + self._hooks: - self.poller.unregister(sub) + self._sock_receiver.unregister(sub) + # self.poller.unregister(sub) def _new_messages(self, timeout): """Check for new messages to yield and pass to the callbacks.""" + all_subs = list(self.subscribers) + self._hooks try: - socks = dict(self.poller.poll(timeout=timeout)) - if socks: - for sub in self.subscribers: - if sub in socks and socks[sub] == POLLIN: - received = sub.recv_string(NOBLOCK) - m__ = Message.decode(received) - if not self._filter or self._filter(m__): - if self._translate: - url = urlsplit(self.sub_addr[sub]) - host = url[1].split(":")[0] - m__.sender = (m__.sender.split("@")[0] - + "@" + host) - yield m__ - - for sub in self._hooks: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - self._hooks_cb[sub](m__) - else: - # timeout - yield None + for m__, sock in self._sock_receiver.receive(*all_subs, timeout=timeout): + if sock in self.subscribers: + if not self._filter or self._filter(m__): + if self._translate: + url = urlsplit(self.sub_addr[sock]) + host = url[1].split(":")[0] + m__.sender = (m__.sender.split("@")[0] + + "@" + host) + yield m__ + elif sock in self._hooks: + self._hooks_cb[sock](m__) + except TimeoutError: + yield None except ZMQError as err: if self._loop: LOGGER.exception("Receive failed: %s", str(err)) + def __call__(self, **kwargs): """Handle calls with class instance.""" return self.recv(**kwargs) @@ -201,54 +194,21 @@ def __del__(self): except Exception: # noqa: E722 pass - -class UnsecureZMQSubscriber(_ZMQSubscriber): - """Unsecure ZMQ implementation of the subscriber.""" - def _add_sub_socket(self, address, topics): - subscriber = get_context().socket(SUB) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber - - -class SecureZMQSubscriber(_ZMQSubscriber): - """Secure ZMQ implementation of the subscriber, using Curve.""" - def __init__(self, *args, client_secret_key_file=None, server_public_key_file=None, **kwargs): - """Initialize the subscriber.""" - if client_secret_key_file is None: - raise TypeError("Missing client_secret_key_file argument.") - if server_public_key_file is None: - raise TypeError("Missing server_public_key_file argument.") - self._client_secret_file = client_secret_key_file - self._server_public_key_file = server_public_key_file - - super().__init__(*args, **kwargs) - - def _add_sub_socket(self, address, topics): - import zmq.auth - subscriber = get_context().socket(SUB) + options = get_tcp_keepalive_options() - client_public, client_secret = zmq.auth.load_certificate(self._client_secret_file) - subscriber.curve_secretkey = client_secret - subscriber.curve_publickey = client_public + subscriber = self._create_socket(SUB, address, options) + add_subscriptions(subscriber, topics) - server_public, _ = zmq.auth.load_certificate(self._server_public_key_file) - # The client must know the server's public key to make a CURVE connection. - subscriber.curve_serverkey = server_public + if self._sock_receiver: + self._sock_receiver.register(subscriber) + return subscriber + def _create_socket(self, socket_type, address, options): + return set_up_client_socket(socket_type, address, options) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber +def add_subscriptions(socket, topics): + for t__ in topics: + socket.setsockopt_string(SUBSCRIBE, str(t__)) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index da759f5..c2cf7b3 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -68,6 +68,14 @@ SocketTimeout = timeout # for easy access to socket.timeout +DEFAULT_BROADCAST_PORT = 21200 + +def get_configured_broadcast_port(): + """Get the configured nameserver port.""" + return config.get("broadcast_port", DEFAULT_BROADCAST_PORT) + + + # ----------------------------------------------------------------------------- # # Sender. @@ -139,7 +147,7 @@ def get_mc_group(): # ----------------------------------------------------------------------------- -class MulticastReceiver(object): +class MulticastReceiver: """Multicast receiver on *port* for an *mcgroup*.""" BUFSIZE = 1024 diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index b3e1501..d72dd4c 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -28,23 +28,20 @@ import threading from posttroll import config, message -from posttroll.bbmcast import MulticastSender +from posttroll.bbmcast import MulticastSender, get_configured_broadcast_port __all__ = ("MessageBroadcaster", "AddressBroadcaster", "sendaddress") LOGGER = logging.getLogger(__name__) -broadcast_port = 21200 - - class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): """Set settings.""" backend = config.get("backend", "unsecure_zmq") if backend == "unsecure_zmq": - from posttroll.backends.zmq.message_broadcaster import UnsecureZMQDesignatedReceiversSender - self._sender = UnsecureZMQDesignatedReceiversSender(default_port, receivers) + from posttroll.backends.zmq.message_broadcaster import ZMQDesignatedReceiversSender + self._sender = ZMQDesignatedReceiversSender(default_port, receivers) def __call__(self, data): """Send messages from all receivers.""" @@ -61,7 +58,7 @@ def close(self): # ---------------------------------------------------------------------------- -class MessageBroadcaster(object): +class MessageBroadcaster: """Class to broadcast stuff. If *interval* is 0 or negative, no broadcasting is done. @@ -135,7 +132,7 @@ def __init__(self, name, address, interval, nameservers): """Set up the Address broadcaster.""" msg = message.Message("/address/%s" % name, "info", {"URI": "%s:%d" % address}).encode() - MessageBroadcaster.__init__(self, msg, broadcast_port, interval, + MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) @@ -158,7 +155,7 @@ def __init__(self, name, address, data_type, interval=2, nameservers=None): msg = message.Message("/address/%s" % name, "info", {"URI": address, "service": data_type}).encode() - MessageBroadcaster.__init__(self, msg, broadcast_port, interval, + MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) diff --git a/posttroll/ns.py b/posttroll/ns.py index 9296dd2..0221bf1 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -88,10 +88,10 @@ def get_pub_address(name, timeout=10, nameserver="localhost"): timeout: how long to wait for an address, in seconds. nameserver: nameserver address to query the publishers from (default: localhost). """ - backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - from posttroll.backends.zmq.ns import unsecure_zmq_get_pub_address - return unsecure_zmq_get_pub_address(name, timeout, nameserver) + if config["backend"] not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"Did not recognize backend: {config['backend']}") + from posttroll.backends.zmq.ns import zmq_get_pub_address + return zmq_get_pub_address(name, timeout, nameserver) # Server part. @@ -116,10 +116,11 @@ def __init__(self, max_age=None, multicast_enabled=True, restrict_to_localhost=F self._max_age = max_age or dt.timedelta(minutes=10) self._multicast_enabled = multicast_enabled self._restrict_to_localhost = restrict_to_localhost - backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - from posttroll.backends.zmq.ns import UnsecureZMQNameServer - self._ns = UnsecureZMQNameServer() + backend = config["backend"] + if backend not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"Did not recognize backend: {backend}") + from posttroll.backends.zmq.ns import ZMQNameServer + self._ns = ZMQNameServer() def run(self, *args): """Run the listener and answer to requests.""" diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 51a89c9..dee85cc 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -85,7 +85,7 @@ class Publisher: """ - def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): + def __init__(self, address, name="", min_port=None, max_port=None): """Bind the publisher class to a port.""" # Limit port range or use the defaults when no port is defined # by the user @@ -95,17 +95,10 @@ def __init__(self, address, name="", min_port=None, max_port=None, **kwargs): self._heartbeat = None backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - if kwargs: - raise TypeError(f"Unexpected keyword arguments: {kwargs}") - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - self._publisher = UnsecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) - elif backend == "secure_zmq": - from posttroll.backends.zmq.publisher import SecureZMQPublisher - self._publisher = SecureZMQPublisher(address, name=name, min_port=min_port, max_port=max_port, - **kwargs) - else: + if backend not in ["unsecure_zmq", "secure_zmq"]: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") + from posttroll.backends.zmq.publisher import ZMQPublisher + self._publisher = ZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) def start(self): """Start the publisher.""" diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 5528fa4..9a04008 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -60,23 +60,17 @@ class Subscriber: """ - def __init__(self, addresses, topics="", message_filter=None, translate=False, **kwargs): + def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" topics = self._magickfy_topics(topics) backend = config.get("backend", "unsecure_zmq") - if backend == "unsecure_zmq": - if kwargs: - raise TypeError(f"Unexpected keyword arguments: {kwargs}") - from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - self._subscriber = UnsecureZMQSubscriber(addresses, topics=topics, - message_filter=message_filter, translate=translate) - elif backend == "secure_zmq": - from posttroll.backends.zmq.subscriber import SecureZMQSubscriber - self._subscriber = SecureZMQSubscriber(addresses, topics=topics, - message_filter=message_filter, translate=translate, **kwargs) - else: + if backend not in ["unsecure_zmq", "secure_zmq"]: raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + self._subscriber = ZMQSubscriber(addresses, topics=topics, + message_filter=message_filter, translate=translate) + def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py new file mode 100644 index 0000000..123ea53 --- /dev/null +++ b/posttroll/tests/test_nameserver.py @@ -0,0 +1,253 @@ +"""Tests for communication involving the nameserver for service discovery.""" + +import os +import time +import unittest +from contextlib import contextmanager +from datetime import timedelta +from threading import Thread +from unittest import mock + +import pytest + +from posttroll import config +from posttroll.message import Message +from posttroll.ns import NameServer, get_pub_address +from posttroll.publisher import Publish +from posttroll.subscriber import Subscribe + + +def free_port(): + """Get a free port. + + From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0 + + Returns a factory that finds the next free port that is available on the OS + This is a bit of a hack, it does this by creating a new socket, and calling + bind with the 0 port. The operating system will assign a brand new port, + which we can find out using getsockname(). Once we have the new port + information we close the socket thereby returning it to the free pool. + This means it is technically possible for this function to return the same + port twice (for example if run in very quick succession), however operating + systems return a random port number in the default range (1024 - 65535), + and it is highly unlikely for two processes to get the same port number. + In other words, it is possible to flake, but incredibly unlikely. + """ + import socket + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", 0)) + portnum = s.getsockname()[1] + s.close() + + return portnum + + +@contextmanager +def create_nameserver_instance(max_age=3, multicast_enabled=True): + """Create a nameserver instance.""" + config.set(nameserver_port=free_port()) + config.set(address_publish_port=free_port()) + ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) + thr = Thread(target=ns.run) + thr.start() + + try: + yield + finally: + ns.stop() + thr.join() + + + +class TestAddressReceiver(unittest.TestCase): + """Test the AddressReceiver.""" + + @mock.patch("posttroll.address_receiver.Message") + @mock.patch("posttroll.address_receiver.Publish") + @mock.patch("posttroll.address_receiver.MulticastReceiver") + def test_localhost_restriction(self, mcrec, pub, msg): + """Test address receiver restricted only to localhost.""" + mocked_publish_instance = mock.Mock() + pub.return_value.__enter__.return_value = mocked_publish_instance + mcr_instance = mock.Mock() + mcrec.return_value = mcr_instance + mcr_instance.return_value = "blabla", ("255.255.255.255", 12) + + from posttroll.address_receiver import AddressReceiver + adr = AddressReceiver(restrict_to_localhost=True) + adr.start() + time.sleep(3) + try: + msg.decode.assert_not_called() + mocked_publish_instance.send.assert_not_called() + finally: + adr.stop() + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_addresses(multicast_enabled): + """Test retrieving addresses.""" + from posttroll.ns import get_pub_addresses + from posttroll.publisher import Publish + + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(multicast_enabled=multicast_enabled): + with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): + time.sleep(.3) + res = get_pub_addresses(["this_data"], timeout=.5) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + res = get_pub_addresses([str("data_provider")]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_ctx(multicast_enabled): + """Test publish and subscribe.""" + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(multicast_enabled=multicast_enabled): + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: + with Subscribe("this_data", "counter") as sub: + for counter in range(5): + message = Message("/counter", "info", str(counter)) + pub.send(str(message)) + time.sleep(.1) + msg = next(sub.recv(.2)) + if msg is not None: + assert str(msg) == str(message) + tested = True + assert tested + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_add_rm(multicast_enabled): + """Test adding and removing publishers.""" + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + + max_age = 0.5 + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): + with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: + assert len(sub.addresses) == 0 + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) + assert len(sub.addresses) == 1 + time.sleep(max_age * 4) + for msg in sub.recv(.1): + if msg is None: + break + time.sleep(.3) + assert len(sub.addresses) == 0 + with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) + assert len(sub.addresses) == 0 + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_listener_container(): + """Test listener container.""" + from posttroll.listener import ListenerContainer + from posttroll.message import Message + from posttroll.publisher import NoisyPublisher + + with create_nameserver_instance(): + pub = NoisyPublisher("test", broadcast_interval=0.1) + pub.start() + sub = ListenerContainer(topics=["/counter"]) + time.sleep(.1) + for counter in range(5): + tested = False + msg_out = Message("/counter", "info", str(counter)) + pub.send(str(msg_out)) + + msg_in = sub.output_queue.get(True, 1) + if msg_in is not None: + assert str(msg_in) == str(msg_out) + tested = True + assert tested + pub.stop() + sub.stop() + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_noisypublisher_heartbeat(): + """Test that the heartbeat in the NoisyPublisher works.""" + from posttroll.publisher import NoisyPublisher + from posttroll.subscriber import Subscribe + + ns_ = NameServer() + thr = Thread(target=ns_.run) + thr.start() + + pub = NoisyPublisher("test") + pub.start() + time.sleep(0.2) + + with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: + time.sleep(0.2) + pub.heartbeat(min_interval=10) + msg = next(sub.recv(1)) + assert msg.type == "beat" + assert msg.data == {"min_interval": 10} + pub.stop() + ns_.stop() + thr.join() + + +def test_switch_backend_for_nameserver(): + """Test switching backend for nameserver.""" + with config.set(backend="spurious_backend"): + with pytest.raises(NotImplementedError): + NameServer() + with pytest.raises(NotImplementedError): + get_pub_address("some_name") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 2f88943..c72011d 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -23,12 +23,10 @@ """Test the publishing and subscribing facilities.""" -import os import time import unittest from contextlib import contextmanager -from datetime import timedelta -from threading import Lock, Thread +from threading import Lock from unittest import mock import pytest @@ -37,7 +35,6 @@ import posttroll from posttroll import config from posttroll.message import Message -from posttroll.ns import NameServer from posttroll.publisher import Publish, Publisher, create_publisher_from_dict_config from posttroll.subscriber import Subscribe, Subscriber @@ -71,125 +68,6 @@ def free_port(): return portnum -@contextmanager -def create_nameserver_instance(max_age=3, multicast_enabled=True): - """Create a nameserver instance.""" - config.set(nameserver_port=free_port()) - config.set(address_publish_port=free_port()) - ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) - thr = Thread(target=ns.run) - thr.start() - - try: - yield - finally: - ns.stop() - thr.join() - - -@pytest.mark.parametrize( - "multicast_enabled", - [True, False] -) -def test_pub_addresses(multicast_enabled): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish - - if multicast_enabled: - if os.getenv("DISABLED_MULTICAST"): - pytest.skip("Multicast tests disabled.") - nameservers = None - else: - nameservers = ["localhost"] - - - with create_nameserver_instance(multicast_enabled=multicast_enabled): - with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): - time.sleep(.3) - res = get_pub_addresses(["this_data"], timeout=.5) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses([str("data_provider")]) - assert len(res) == 1 - expected = {u"status": True, - u"service": [u"data_provider", u"this_data"], - u"name": u"address"} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - - -@pytest.mark.parametrize( - "multicast_enabled", - [True, False] -) -def test_pub_sub_ctx(multicast_enabled): - """Test publish and subscribe.""" - if multicast_enabled: - if os.getenv("DISABLED_MULTICAST"): - pytest.skip("Multicast tests disabled.") - nameservers = None - else: - nameservers = ["localhost"] - - with create_nameserver_instance(multicast_enabled=multicast_enabled): - with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(.1) - msg = next(sub.recv(.2)) - if msg is not None: - assert str(msg) == str(message) - tested = True - sub.close() - assert tested - - -@pytest.mark.parametrize( - "multicast_enabled", - [True, False] -) -def test_pub_sub_add_rm(multicast_enabled): - """Test adding and removing publishers.""" - if multicast_enabled: - if os.getenv("DISABLED_MULTICAST"): - pytest.skip("Multicast tests disabled.") - nameservers = None - else: - nameservers = ["localhost"] - - max_age = 0.5 - - with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): - with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: - assert len(sub.addresses) == 0 - with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): - time.sleep(.1) - next(sub.recv(.1)) - assert len(sub.addresses) == 1 - time.sleep(max_age * 4) - for msg in sub.recv(.1): - if msg is None: - break - time.sleep(.3) - assert len(sub.addresses) == 0 - with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): - time.sleep(.1) - next(sub.recv(.1)) - assert len(sub.addresses) == 0 - sub.close() - - class TestPubSub(unittest.TestCase): """Testing the publishing and subscribing capabilities.""" @@ -317,35 +195,6 @@ def _get_port_from_publish_instance(min_port=None, max_port=None): return False -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) -def test_listener_container(): - """Test listener container.""" - from posttroll.listener import ListenerContainer - from posttroll.message import Message - from posttroll.publisher import NoisyPublisher - - with create_nameserver_instance(): - pub = NoisyPublisher("test", broadcast_interval=0.1) - pub.start() - sub = ListenerContainer(topics=["/counter"]) - time.sleep(.1) - for counter in range(5): - tested = False - msg_out = Message("/counter", "info", str(counter)) - pub.send(str(msg_out)) - - msg_in = sub.output_queue.get(True, 1) - if msg_in is not None: - assert str(msg_in) == str(msg_out) - tested = True - assert tested - pub.stop() - sub.stop() - - class TestListenerContainerNoNameserver(unittest.TestCase): """Testing listener container without nameserver.""" @@ -382,27 +231,6 @@ def test_listener_container(self): sub.stop() -class TestAddressReceiver(unittest.TestCase): - """Test the AddressReceiver.""" - - @mock.patch("posttroll.address_receiver.Message") - @mock.patch("posttroll.address_receiver.Publish") - @mock.patch("posttroll.address_receiver.MulticastReceiver") - def test_localhost_restriction(self, mcrec, pub, msg): - """Test address receiver restricted only to localhost.""" - mcr_instance = mock.Mock() - mcrec.return_value = mcr_instance - mcr_instance.return_value = "blabla", ("255.255.255.255", 12) - from posttroll.address_receiver import AddressReceiver - adr = AddressReceiver(restrict_to_localhost=True) - adr.start() - time.sleep(3) - msg.decode.assert_not_called() - adr.stop() - - - - ## Test create_publisher_from_config def test_publisher_with_invalid_arguments_crashes(): @@ -603,8 +431,8 @@ def _tcp_keepalive_no_settings(): @pytest.mark.usefixtures("_tcp_keepalive_settings") def test_publisher_tcp_keepalive(): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() + from posttroll.backends.zmq.publisher import ZMQPublisher + pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() _assert_tcp_keepalive(pub.publish_socket) pub.stop() @@ -612,8 +440,8 @@ def test_publisher_tcp_keepalive(): @pytest.mark.usefixtures("_tcp_keepalive_no_settings") def test_publisher_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" - from posttroll.backends.zmq.publisher import UnsecureZMQPublisher - pub = UnsecureZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() + from posttroll.backends.zmq.publisher import ZMQPublisher + pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() _assert_no_tcp_keepalive(pub.publish_socket) pub.stop() @@ -621,8 +449,8 @@ def test_publisher_tcp_keepalive_not_set(): @pytest.mark.usefixtures("_tcp_keepalive_settings") def test_subscriber_tcp_keepalive(): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" - from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") assert len(sub.addr_sub.values()) == 1 _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) sub.stop() @@ -631,8 +459,8 @@ def test_subscriber_tcp_keepalive(): @pytest.mark.usefixtures("_tcp_keepalive_no_settings") def test_subscriber_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" - from posttroll.backends.zmq.subscriber import UnsecureZMQSubscriber - sub = UnsecureZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") assert len(sub.addr_sub.values()) == 1 _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0]) sub.close() @@ -656,35 +484,6 @@ def _assert_no_tcp_keepalive(socket): assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1 -@pytest.mark.skipif( - os.getenv("DISABLED_MULTICAST"), - reason="Multicast tests disabled.", -) -def test_noisypublisher_heartbeat(): - """Test that the heartbeat in the NoisyPublisher works.""" - from posttroll.ns import NameServer - from posttroll.publisher import NoisyPublisher - from posttroll.subscriber import Subscribe - - ns_ = NameServer() - thr = Thread(target=ns_.run) - thr.start() - - pub = NoisyPublisher("test") - pub.start() - time.sleep(0.2) - - with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: - time.sleep(0.2) - pub.heartbeat(min_interval=10) - msg = next(sub.recv(1)) - assert msg.type == "beat" - assert msg.data == {"min_interval": 10} - pub.stop() - ns_.stop() - thr.join() - - def test_switch_to_unknown_backend(): """Test switching to unknown backend.""" from posttroll.publisher import Publisher diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 4c6a903..38d442e 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -10,6 +10,8 @@ from posttroll import config from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config +from posttroll.tests.test_nameserver import create_nameserver_instance +from posttroll.ns import get_pub_address def create_keys(tmp_path): @@ -48,21 +50,21 @@ def create_keys(tmp_path): def test_ipc_pubsub_with_sec(tmp_path): """Test pub-sub on a secure ipc socket.""" - server_public_key, server_secret_key = zmq.auth.create_certificates(tmp_path, "server") - client_public_key, client_secret_key = zmq.auth.create_certificates(tmp_path, "client") + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" - with config.set(backend="secure_zmq"): - subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, - client_secret_key_file=client_secret_key, - server_public_key_file=server_public_key) + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) from posttroll.publisher import Publisher - pub = Publisher(ipc_address, - server_secret_key=server_secret_key, - public_keys_directory=os.path.dirname(client_public_key)) + pub = Publisher(ipc_address) pub.start() @@ -94,38 +96,37 @@ def test_switch_to_secure_zmq_backend(tmp_path): server_secret_key = secret_keys_dir / "server.key_secret" public_keys_directory = public_keys_dir - publisher_key_args = dict(server_secret_key=server_secret_key, - public_keys_directory=public_keys_directory) client_secret_key = secret_keys_dir / "client.key_secret" server_public_key = public_keys_dir / "server.key" - subscriber_key_args = dict(client_secret_key_file=client_secret_key, - server_public_key_file=server_public_key) - with config.set(backend="secure_zmq"): - Publisher("ipc://bla.ipc", **publisher_key_args) - Subscriber("ipc://bla.ipc", **subscriber_key_args) + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key, + clients_public_keys_directory=public_keys_directory, + server_public_key_file=server_public_key, + server_secret_key_file=server_secret_key): + Publisher("ipc://bla.ipc") + Subscriber("ipc://bla.ipc") def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): """Test pub-sub on a secure ipc socket.""" - base_dir = tmp_path - public_keys_dir = base_dir / "public_keys" - secret_keys_dir = base_dir / "private_keys" + #create_keys(tmp_path) - create_keys(tmp_path) + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" - with config.set(backend="secure_zmq"): - subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, - client_secret_key_file=secret_keys_dir / "client.key_secret", - server_public_key_file=public_keys_dir / "server.key") + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) from posttroll.publisher import create_publisher_from_dict_config pub_settings = dict(address=ipc_address, - server_secret_key=secret_keys_dir / "server.key_secret", - public_keys_directory=public_keys_dir, nameservers=False, port=1789) pub = create_publisher_from_dict_config(pub_settings) @@ -146,3 +147,17 @@ def delayed_send(msg): sub.stop() thr.join() pub.stop() + +def test_switch_to_secure_backend_for_nameserver(tmp_path): + """Test switching backend for nameserver.""" + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + + with create_nameserver_instance(): + res = get_pub_address("some_name") + assert res == "" From 19fb35d0e43d657aaa5c845b0b590fa968677954 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 3 May 2024 16:50:37 +0200 Subject: [PATCH 37/45] Clean up --- posttroll/address_receiver.py | 2 +- posttroll/backends/zmq/__init__.py | 29 +--------- posttroll/backends/zmq/address_receiver.py | 7 ++- posttroll/backends/zmq/message_broadcaster.py | 5 +- posttroll/backends/zmq/ns.py | 13 ++--- posttroll/backends/zmq/publisher.py | 7 ++- posttroll/backends/zmq/socket.py | 54 ++++++++++++++++++- posttroll/backends/zmq/subscriber.py | 14 ++--- posttroll/tests/test_secure_zmq_backend.py | 2 +- 9 files changed, 77 insertions(+), 56 deletions(-) diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index d2cef04..5d58858 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -35,12 +35,12 @@ import time import netifaces +from zmq import ZMQError from posttroll import config from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port from posttroll.message import Message from posttroll.publisher import Publish -from zmq import ZMQError __all__ = ("AddressReceiver", "getaddress") diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index 2cd6597..d59a127 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -5,7 +5,6 @@ import zmq from posttroll import config -from posttroll.message import Message logger = logging.getLogger(__name__) context = {} @@ -23,6 +22,7 @@ def get_context(): return context[pid] def destroy_context(linger=None): + """Destroy the context.""" pid = os.getpid() context.pop(pid).destroy(linger) @@ -46,30 +46,3 @@ def get_tcp_keepalive_options(): param = getattr(zmq, opt.upper()) keepalive_options[param] = value return keepalive_options - - -class SocketReceiver: - - def __init__(self): - self._poller = zmq.Poller() - - def register(self, socket): - """Register the socket.""" - self._poller.register(socket, zmq.POLLIN) - - def unregister(self, socket): - """Unregister the socket.""" - self._poller.unregister(socket) - - def receive(self, *sockets, timeout=None): - """Timeout is in seconds.""" - if timeout: - timeout *= 1000 - socks = dict(self._poller.poll(timeout=timeout)) - if socks: - for sock in sockets: - if socks.get(sock) == zmq.POLLIN: - received = sock.recv_string(zmq.NOBLOCK) - yield Message.decode(received), sock - else: - raise TimeoutError("Did not receive anything on sockets.") diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py index f926747..ef58dfa 100644 --- a/posttroll/backends/zmq/address_receiver.py +++ b/posttroll/backends/zmq/address_receiver.py @@ -1,9 +1,9 @@ """ZMQ implementation of the the simple receiver.""" -from zmq import LINGER, REP +from zmq import REP from posttroll.address_receiver import get_configured_address_port -from posttroll.backends.zmq.socket import set_up_server_socket +from posttroll.backends.zmq.socket import close_socket, set_up_server_socket class SimpleReceiver(object): @@ -31,7 +31,6 @@ def __call__(self): def close(self): """Close the receiver.""" self._running = False - self._socket.setsockopt(LINGER, 1) - self._socket.close() + close_socket(self._socket) if self._authenticator: self._authenticator.stop() diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py index fe2ddfe..238d9eb 100644 --- a/posttroll/backends/zmq/message_broadcaster.py +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -3,9 +3,9 @@ import logging import threading -from posttroll.backends.zmq.socket import set_up_client_socket from zmq import LINGER, NOBLOCK, REQ, ZMQError +from posttroll.backends.zmq.socket import close_socket, set_up_client_socket logger = logging.getLogger(__name__) @@ -47,8 +47,7 @@ def _send_to_address(self, address, data, timeout=10): break finally: - socket.setsockopt(LINGER, 1) - socket.close() + close_socket(socket) def close(self): """Close the sender.""" diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 3325272..5827920 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -4,10 +4,9 @@ from contextlib import suppress from threading import Lock -from posttroll.backends.zmq.socket import set_up_client_socket, set_up_server_socket from zmq import LINGER, REP, REQ -from posttroll.backends.zmq import SocketReceiver +from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket, set_up_server_socket from posttroll.message import Message from posttroll.ns import get_active_address, get_configured_nameserver_port @@ -29,6 +28,7 @@ def zmq_get_pub_address(name, timeout=10, nameserver="localhost"): def create_nameserver_address(nameserver): + """Create the nameserver address.""" port = get_configured_nameserver_port() nameserver_address = "tcp://" + nameserver + ":" + str(port) return nameserver_address @@ -52,10 +52,11 @@ def _fetch_address_using_socket(socket, name, timeout): % timeout) finally: socket_receiver.unregister(socket) - socket.setsockopt(LINGER, 1) - socket.close() + close_socket(socket) + def create_req_socket(timeout, nameserver_address): + """Create a REQ socket.""" options = {LINGER: int(timeout * 1000)} socket = set_up_client_socket(REQ, nameserver_address, options) return socket @@ -97,9 +98,9 @@ def run(self, address_receiver): self.close_sockets_and_threads() def close_sockets_and_threads(self): + """Close all sockets and threads.""" with suppress(AttributeError): - self.listener.setsockopt(LINGER, 1) - self.listener.close() + close_socket(self.listener) with suppress(AttributeError): self._authenticator.stop() diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py index 37a4898..8d2bec5 100644 --- a/posttroll/backends/zmq/publisher.py +++ b/posttroll/backends/zmq/publisher.py @@ -1,13 +1,13 @@ """ZMQ implementation of the publisher.""" -from contextlib import suppress import logging +from contextlib import suppress from threading import Lock -from posttroll.backends.zmq.socket import set_up_server_socket import zmq from posttroll.backends.zmq import get_tcp_keepalive_options +from posttroll.backends.zmq.socket import close_socket, set_up_server_socket LOGGER = logging.getLogger(__name__) @@ -54,7 +54,6 @@ def send(self, msg): def stop(self): """Stop the publisher.""" - self.publish_socket.setsockopt(zmq.LINGER, 1) - self.publish_socket.close() + close_socket(self.publish_socket) with suppress(AttributeError): self._authenticator.stop() diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index 16ae3da..132bb08 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -1,11 +1,22 @@ -from posttroll import get_context, config +"""ZMQ socket handling functions.""" + +from urllib.parse import urlsplit, urlunsplit + import zmq from zmq.auth.thread import ThreadAuthenticator -from urllib.parse import urlsplit, urlunsplit + +from posttroll import config, get_context +from posttroll.message import Message +def close_socket(sock): + """Close a zmq socket.""" + sock.setsockopt(zmq.LINGER, 1) + sock.close() + def set_up_client_socket(socket_type, address, options=None): + """Set up a client (connecting) zmq socket.""" backend = config["backend"] if backend == "unsecure_zmq": sock = create_unsecure_client_socket(socket_type) @@ -17,10 +28,12 @@ def set_up_client_socket(socket_type, address, options=None): def create_unsecure_client_socket(socket_type): + """Create an unsecure client socket.""" return get_context().socket(socket_type) def add_options(sock, options=None): + """Add options to a socket.""" if not options: return for param, val in options.items(): @@ -28,6 +41,7 @@ def add_options(sock, options=None): def create_secure_client_socket(socket_type): + """Create a secure client socket.""" subscriber = get_context().socket(socket_type) client_secret_key_file = config["client_secret_key_file"] @@ -43,6 +57,7 @@ def create_secure_client_socket(socket_type): def set_up_server_socket(socket_type, destination, options=None, port_interval=(None, None)): + """Set up a server (binding) socket.""" if options is None: options = {} backend = config["backend"] @@ -59,10 +74,15 @@ def set_up_server_socket(socket_type, destination, options=None, port_interval=( def create_unsecure_server_socket(socket_type): + """Create an unsecure server socket.""" return get_context().socket(socket_type) def bind(sock, destination, port_interval): + """Bind the socket to a destination. + + If a random port is to be chosen, the port_interval is used. + """ # Check for port 0 (random port) min_port, max_port = port_interval u__ = urlsplit(destination) @@ -83,6 +103,7 @@ def bind(sock, destination, port_interval): def create_secure_server_socket(socket_type): + """Create a secure server socket.""" server_secret_key = config["server_secret_key_file"] clients_public_keys_directory = config["clients_public_keys_directory"] authorized_sub_addresses = config.get("authorized_client_addresses", []) @@ -104,3 +125,32 @@ def create_secure_server_socket(socket_type): server_socket.curve_publickey = server_public server_socket.curve_server = True return server_socket, authenticator_thread + + +class SocketReceiver: + """A receiver for mulitple sockets.""" + + def __init__(self): + """Set up the receiver.""" + self._poller = zmq.Poller() + + def register(self, socket): + """Register the socket.""" + self._poller.register(socket, zmq.POLLIN) + + def unregister(self, socket): + """Unregister the socket.""" + self._poller.unregister(socket) + + def receive(self, *sockets, timeout=None): + """Timeout is in seconds.""" + if timeout: + timeout *= 1000 + socks = dict(self._poller.poll(timeout=timeout)) + if socks: + for sock in sockets: + if socks.get(sock) == zmq.POLLIN: + received = sock.recv_string(zmq.NOBLOCK) + yield Message.decode(received), sock + else: + raise TimeoutError("Did not receive anything on sockets.") diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index 8186f69..afdaebd 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -5,15 +5,16 @@ from time import sleep from urllib.parse import urlsplit -from zmq import LINGER, PULL, SUB, SUBSCRIBE, ZMQError -from posttroll.backends.zmq.socket import set_up_client_socket +from zmq import PULL, SUB, SUBSCRIBE, ZMQError -from posttroll.backends.zmq import SocketReceiver, get_tcp_keepalive_options +from posttroll.backends.zmq import get_tcp_keepalive_options +from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket LOGGER = logging.getLogger(__name__) class ZMQSubscriber: + """A ZMQ subscriber class.""" def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" @@ -131,7 +132,6 @@ def subscribers(self): def recv(self, timeout=None): """Receive, optionally with *timeout* in seconds.""" - for sub in list(self.subscribers) + self._hooks: self._sock_receiver.register(sub) self._loop = True @@ -181,8 +181,7 @@ def close(self): self.stop() for sub in list(self.subscribers) + self._hooks: try: - sub.setsockopt(LINGER, 1) - sub.close() + close_socket(sub) except ZMQError: pass @@ -190,7 +189,7 @@ def __del__(self): """Clean up after the instance is deleted.""" for sub in list(self.subscribers) + self._hooks: try: - sub.close() + close_socket(sub) except Exception: # noqa: E722 pass @@ -210,5 +209,6 @@ def _create_socket(self, socket_type, address, options): def add_subscriptions(socket, topics): + """Add subscriptions to a socket.""" for t__ in topics: socket.setsockopt_string(SUBSCRIBE, str(t__)) diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 38d442e..460f48c 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -8,10 +8,10 @@ import zmq.auth from posttroll import config +from posttroll.ns import get_pub_address from posttroll.publisher import Publisher from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config from posttroll.tests.test_nameserver import create_nameserver_instance -from posttroll.ns import get_pub_address def create_keys(tmp_path): From 93c78e5adcf40ed38d95acf8fdcd696ac7773429 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 14 May 2024 15:44:59 +0200 Subject: [PATCH 38/45] Rename --- posttroll/ns.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/posttroll/ns.py b/posttroll/ns.py index 0221bf1..1bf05b4 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -23,7 +23,7 @@ """Manage other's subscriptions. -Default port is 5557, if $NAMESERVER_PORT is not defined. +Default port is 5557, if $POSTTROLL_NAMESERVER_PORT is not defined. """ import datetime as dt import logging @@ -54,7 +54,6 @@ def get_configured_nameserver_port(): return config.get("nameserver_port", port) - # Client functions. @@ -78,8 +77,6 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): return addrs - - def get_pub_address(name, timeout=10, nameserver="localhost"): """Get the address of the named publisher. @@ -93,6 +90,7 @@ def get_pub_address(name, timeout=10, nameserver="localhost"): from posttroll.backends.zmq.ns import zmq_get_pub_address return zmq_get_pub_address(name, timeout, nameserver) + # Server part. @@ -105,7 +103,6 @@ def get_active_address(name, arec): return Message("/oper/ns", "info", "") - class NameServer: """The name server.""" From c7b23ede91c27c9995473f10495fadb49c40e3ec Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 14 May 2024 15:45:29 +0200 Subject: [PATCH 39/45] Rename --- posttroll/subscriber.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index 9a04008..c32ccbf 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -180,9 +180,9 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False, Default is to listen to all available services. """ - self._services = _to_array(services) - self._topics = _to_array(topics) - self._addresses = _to_array(addresses) + self._services = _to_list(services) + self._topics = _to_list(topics) + self._addresses = _to_list(addresses) self._timeout = timeout self._translate = translate @@ -283,7 +283,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.subscriber.stop() -def _to_array(obj): +def _to_list(obj): """Convert *obj* to list if not already one.""" if isinstance(obj, str): return [obj, ] From 6364282a5f75c3cc9bfd3cd1353231a4faa60e06 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 21 May 2024 09:33:09 +0000 Subject: [PATCH 40/45] Improve tests --- posttroll/address_receiver.py | 31 ++++++++--------- posttroll/tests/test_nameserver.py | 39 ++++++++++++---------- posttroll/tests/test_secure_zmq_backend.py | 8 ++--- 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 5d58858..6ca006e 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -159,23 +159,13 @@ def _run(self): try: while self._do_run: try: - rerun = True - while rerun: - try: - data, fromaddr = recv() - rerun = False - except TimeoutError: - if self._do_run: - continue - else: - raise - if self._multicast_enabled: - ip_, port = fromaddr - if self._restrict_to_localhost and ip_ not in self._local_ips: - # discard external message - LOGGER.debug("Discard external message") - continue - LOGGER.debug("data %s", data) + data, fromaddr = recv() + except TimeoutError: + if self._do_run: + continue + else: + raise + except SocketTimeout: if self._multicast_enabled: LOGGER.debug("Multicast socket timed out on recv!") @@ -186,6 +176,13 @@ def _run(self): self._check_age(pub, min_interval=self._max_age / 20) if self._do_heartbeat: pub.heartbeat(min_interval=29) + if self._multicast_enabled: + ip_, port = fromaddr + if self._restrict_to_localhost and ip_ not in self._local_ips: + # discard external message + LOGGER.debug("Discard external message") + continue + LOGGER.debug("data %s", data) msg = Message.decode(data) name = msg.subject.split("/")[1] if msg.type == "info" and msg.subject.lower().startswith(self._subject): diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py index 123ea53..8f8689b 100644 --- a/posttroll/tests/test_nameserver.py +++ b/posttroll/tests/test_nameserver.py @@ -179,7 +179,7 @@ def test_pub_sub_add_rm(multicast_enabled): for msg in sub.recv(.1): if msg is None: break - time.sleep(.3) + time.sleep(0.3) assert len(sub.addresses) == 0 with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): time.sleep(.1) @@ -225,23 +225,28 @@ def test_noisypublisher_heartbeat(): from posttroll.publisher import NoisyPublisher from posttroll.subscriber import Subscribe - ns_ = NameServer() - thr = Thread(target=ns_.run) - thr.start() + min_interval = 10 - pub = NoisyPublisher("test") - pub.start() - time.sleep(0.2) - - with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: - time.sleep(0.2) - pub.heartbeat(min_interval=10) - msg = next(sub.recv(1)) - assert msg.type == "beat" - assert msg.data == {"min_interval": 10} - pub.stop() - ns_.stop() - thr.join() + try: + with config.set(address_publish_port=free_port(), nameserver_port=free_port()): + ns_ = NameServer() + thr = Thread(target=ns_.run) + thr.start() + + pub = NoisyPublisher("test") + pub.start() + time.sleep(0.2) + + with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: + time.sleep(0.2) + pub.heartbeat(min_interval=min_interval) + msg = next(sub.recv(1)) + assert msg.type == "beat" + assert msg.data == {"min_interval": min_interval} + finally: + pub.stop() + ns_.stop() + thr.join() def test_switch_backend_for_nameserver(): diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 460f48c..01d7f97 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -153,10 +153,10 @@ def test_switch_to_secure_backend_for_nameserver(tmp_path): server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") with config.set(backend="secure_zmq", - client_secret_key_file=client_secret_key_file, - clients_public_keys_directory=os.path.dirname(client_public_key_file), - server_public_key_file=server_public_key_file, - server_secret_key_file=server_secret_key_file): + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): with create_nameserver_instance(): res = get_pub_address("some_name") From 96dea2fd7905f796a5a939e30f10662e35e9fd58 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 22 May 2024 09:16:52 +0200 Subject: [PATCH 41/45] Add script for generating keys and documentation --- doc/source/index.rst | 49 ++++++++++++++++++++++ posttroll/backends/zmq/__init__.py | 22 ++++++++++ posttroll/tests/test_secure_zmq_backend.py | 11 +++++ pyproject.toml | 1 + 4 files changed, 83 insertions(+) diff --git a/doc/source/index.rst b/doc/source/index.rst index b556936..7ff0e12 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -150,6 +150,55 @@ relevant socket options. .. _zmq_setsockopts: http://api.zeromq.org/master:zmq-setsockopt +Using secure ZeroMQ backend +--------------------------- + +To use securely authenticated sockets with posttroll (uses ZMQ's curve authentication), the backend needs to be defined +through posttroll config system, for example using an environment variable:: + + POSTTROLL_BACKEND=secure_zmq + +On the server side (for example a publisher), we need to define the server's secret key and the directory where the +accepted client keys are provided:: + + POSTTROLL_SERVER_SECRET_KEY_FILE=/path/to/server.key_secret + POSTTROLL_PUBLIC_SECRET_KEYS_DIRECTORY=/path/to/client_public_keys/ + +On the client side (for example a subscriber), we need to define the server's public key file and the client's secret +key file:: + + POSTTROLL_CLIENT_SECRET_KEY_FILE=/path/to/client.key_secret + POSTTROLL_SERVER_PUBLIC_KEY_FILE=/path/to/server.key + +These settings can also be set using the posttroll config object, for example:: + + >>> from posttroll import config + >>> with config.set(backend="secure_zmq", server_pubic_key_file="..."): + ... + +The posttroll configuration uses donfig, for more information, check https://donfig.readthedocs.io/en/latest/. + + +Generating the public and secret key pairs +****************************************** + +In order for the secure ZMQ backend to work, public/secret key pairs need to be generated, one for the client side and +one for the server side. A command-line script is provided for this purpose:: + + > posttroll-generate-keys -h + usage: posttroll-generate-keys [-h] [-d DIRECTORY] name + + Create a public/secret key pair for the secure zmq backend. This will create two files (in the current directory if not otherwise specified) with the suffixes '.key' and '.key_secret'. The name of the files will be the one provided. + + positional arguments: + name Name of the file. + + options: + -h, --help show this help message and exit + -d DIRECTORY, --directory DIRECTORY + Directory to place the keys in. + + Converting from older posttroll versions ---------------------------------------- diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py index d59a127..c943737 100644 --- a/posttroll/backends/zmq/__init__.py +++ b/posttroll/backends/zmq/__init__.py @@ -1,8 +1,11 @@ """Main module for the zmq backend.""" +import argparse import logging import os +from pathlib import Path import zmq +from zmq.auth.certs import create_certificates from posttroll import config @@ -21,17 +24,20 @@ def get_context(): logger.debug("renewed context for PID %d", pid) return context[pid] + def destroy_context(linger=None): """Destroy the context.""" pid = os.getpid() context.pop(pid).destroy(linger) + def _set_tcp_keepalive(socket): """Set the tcp keepalive parameters on *socket*.""" keepalive_options = get_tcp_keepalive_options() for param, value in keepalive_options.items(): socket.setsockopt(param, value) + def get_tcp_keepalive_options(): """Get the tcp_keepalive options from config.""" keepalive_options = dict() @@ -46,3 +52,19 @@ def get_tcp_keepalive_options(): param = getattr(zmq, opt.upper()) keepalive_options[param] = value return keepalive_options + + +def generate_keys(args=None): + """Generate a public/secret key pair.""" + parser = argparse.ArgumentParser( + prog="posttroll-generate-keys", + description=("Create a public/secret key pair for the secure zmq backend. This will create two " + "files (in the current directory if not otherwise specified) with the suffixes '.key'" + " and '.key_secret'. The name of the files will be the one provided.")) + + parser.add_argument("name", type=str, help="Name of the file.") + parser.add_argument("-d", "--directory", help="Directory to place the keys in.", default=".", type=Path) + + parsed = parser.parse_args(args) + + create_certificates(parsed.directory, parsed.name) diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 01d7f97..74619c9 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -161,3 +161,14 @@ def test_switch_to_secure_backend_for_nameserver(tmp_path): with create_nameserver_instance(): res = get_pub_address("some_name") assert res == "" + + + +def test_create_certificates_cli(tmp_path): + """Test the certificate creation cli.""" + from posttroll.backends.zmq import generate_keys + name = "server" + args = [name, "-d", str(tmp_path)] + generate_keys(args) + assert (tmp_path / (name + ".key")).exists() + assert (tmp_path / (name + ".key_secret")).exists() diff --git a/pyproject.toml b/pyproject.toml index 5f1c8e3..09c9663 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ [project.scripts] pytroll-logger = "posttroll.logger:run" +posttroll-generate-keys = "posttroll.backends.zmq:generate_keys" [project.urls] Homepage = "https://github.com/pytroll/posttroll" From 97b2f94e2069b3ff878ae99bd65b3489f2b768c7 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 22 May 2024 10:14:47 +0200 Subject: [PATCH 42/45] Fix style --- posttroll/__init__.py | 2 +- posttroll/address_receiver.py | 12 +++++------- posttroll/backends/zmq/ns.py | 6 ++---- posttroll/backends/zmq/socket.py | 3 +-- posttroll/backends/zmq/subscriber.py | 5 ----- posttroll/bbmcast.py | 7 ++++--- posttroll/message.py | 3 +++ posttroll/message_broadcaster.py | 3 ++- posttroll/subscriber.py | 1 + posttroll/tests/test_bbmcast.py | 13 ++++++++----- posttroll/tests/test_message.py | 12 ------------ posttroll/tests/test_nameserver.py | 1 - posttroll/tests/test_pubsub.py | 14 ++++++++++++-- posttroll/tests/test_secure_zmq_backend.py | 11 ++++++----- posttroll/tests/test_unsecure_zmq_backend.py | 1 + 15 files changed, 46 insertions(+), 48 deletions(-) diff --git a/posttroll/__init__.py b/posttroll/__init__.py index 46b0eaf..df053e3 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -68,5 +68,5 @@ def strp_isoformat(strg): else: dat, mis = strg.split(".") dat = dt.datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S") - mis = int(float("." + mis)*1000000) + mis = int(float("." + mis) * 1000000) return dat.replace(microsecond=mis) diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index 6ca006e..42c0d4c 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -38,7 +38,7 @@ from zmq import ZMQError from posttroll import config -from posttroll.bbmcast import MulticastReceiver, SocketTimeout, get_configured_broadcast_port +from posttroll.bbmcast import MulticastReceiver, get_configured_broadcast_port from posttroll.message import Message from posttroll.publisher import Publish @@ -57,6 +57,7 @@ def get_configured_address_port(): return config.get("address_publish_port", DEFAULT_ADDRESS_PUBLISH_PORT) + def get_local_ips(): """Get local IP addresses.""" inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET) @@ -162,14 +163,11 @@ def _run(self): data, fromaddr = recv() except TimeoutError: if self._do_run: + if self._multicast_enabled: + LOGGER.debug("Multicast socket timed out on recv!") continue else: raise - - except SocketTimeout: - if self._multicast_enabled: - LOGGER.debug("Multicast socket timed out on recv!") - continue except ZMQError: return finally: @@ -229,7 +227,7 @@ def set_up_address_receiver(self, port): from posttroll.backends.zmq.address_receiver import SimpleReceiver recv = SimpleReceiver(port, timeout=2) nameservers = ["localhost"] - return nameservers,recv + return nameservers, recv def _add(self, adr, metadata): """Add an address.""" diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py index 5827920..dc0fcfb 100644 --- a/posttroll/backends/zmq/ns.py +++ b/posttroll/backends/zmq/ns.py @@ -43,10 +43,8 @@ def _fetch_address_using_socket(socket, name, timeout): socket.send_string(str(message)) # Get the reply. - #socket.poll(timeout) - #message = socket.recv(timeout) for message, _ in socket_receiver.receive(socket, timeout=timeout): - return message.data + return message.data except TimeoutError: raise TimeoutError("Didn't get an address after %d seconds." % timeout) @@ -61,6 +59,7 @@ def create_req_socket(timeout, nameserver_address): socket = set_up_client_socket(REQ, nameserver_address, options) return socket + class ZMQNameServer: """The name server.""" @@ -104,7 +103,6 @@ def close_sockets_and_threads(self): with suppress(AttributeError): self._authenticator.stop() - def stop(self): """Stop the name server.""" self.running = False diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index 132bb08..7adb295 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -117,10 +117,9 @@ def create_secure_server_socket(socket_type): # Tell authenticator to use the certificate in a directory authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) - server_socket = ctx.socket(socket_type) - server_public, server_secret =zmq.auth.load_certificate(server_secret_key) + server_public, server_secret = zmq.auth.load_certificate(server_secret_key) server_socket.curve_secretkey = server_secret server_socket.curve_publickey = server_public server_socket.curve_server = True diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py index afdaebd..836f590 100644 --- a/posttroll/backends/zmq/subscriber.py +++ b/posttroll/backends/zmq/subscriber.py @@ -28,7 +28,6 @@ def __init__(self, addresses, topics="", message_filter=None, translate=False): self._hooks = [] self._hooks_cb = {} - #self.poller = Poller() self._sock_receiver = SocketReceiver() self._lock = Lock() @@ -119,7 +118,6 @@ def _add_hook(self, socket, callback): self._hooks.append(socket) self._hooks_cb[socket] = callback - @property def addresses(self): """Get the addresses.""" @@ -165,9 +163,6 @@ def _new_messages(self, timeout): if self._loop: LOGGER.exception("Receive failed: %s", str(err)) - - - def __call__(self, **kwargs): """Handle calls with class instance.""" return self.recv(**kwargs) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index c2cf7b3..d9f3ae0 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -70,12 +70,12 @@ DEFAULT_BROADCAST_PORT = 21200 + def get_configured_broadcast_port(): """Get the configured nameserver port.""" return config.get("broadcast_port", DEFAULT_BROADCAST_PORT) - # ----------------------------------------------------------------------------- # # Sender. @@ -114,8 +114,8 @@ def mcast_sender(mcgroup=None): if _is_broadcast_group(mcgroup): group = "" sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1) - elif((int(mcgroup.split(".")[0]) > 239) or - (int(mcgroup.split(".")[0]) < 224)): + elif ((int(mcgroup.split(".")[0]) > 239) or + (int(mcgroup.split(".")[0]) < 224)): raise IOError(f"Invalid multicast address {mcgroup}") else: group = mcgroup @@ -130,6 +130,7 @@ def mcast_sender(mcgroup=None): raise return sock, group + def get_mc_group(): try: mcgroup = os.environ["PYTROLL_MC_GROUP"] diff --git a/posttroll/message.py b/posttroll/message.py index 541c0af..ab68484 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -282,12 +282,14 @@ def _decode(rawstr): return msg + def _check_for_version(raw): version = raw[4][:len(_VERSION)] if not _is_valid_version(version): raise MessageError("Invalid Message version: '%s'" % str(version)) return version + def _check_for_element_count(rawstr): raw = re.split(r"\s+", rawstr, maxsplit=6) if len(raw) < 5: @@ -296,6 +298,7 @@ def _check_for_element_count(rawstr): return raw + def _check_for_magic_word(rawstr): """Check for the magick word.""" try: diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index d72dd4c..4990c36 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -34,6 +34,7 @@ LOGGER = logging.getLogger(__name__) + class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" def __init__(self, default_port, receivers): @@ -51,7 +52,7 @@ def close(self): """Close the sender.""" return self._sender.close() -#----------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- # # General thread to broadcast messages. # diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index c32ccbf..fc3a8c1 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -155,6 +155,7 @@ def _magickfy_topics(topics): ts_.append(t__) return ts_ + class NSSubscriber: """Automatically subscribe to *services*. diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index 1a47b40..b61b26c 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -48,6 +48,7 @@ def test_mcast_sender_works_with_valid_addresses(): socket.close() + def test_mcast_sender_uses_broadcast_for_0s(): """Test mcast_sender uses broadcast for 0.0.0.0.""" mcgroup = "0.0.0.0" @@ -56,6 +57,7 @@ def test_mcast_sender_uses_broadcast_for_0s(): assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() + def test_mcast_sender_uses_broadcast_for_255s(): """Test mcast_sender uses broadcast for 255.255.255.255.""" mcgroup = "255.255.255.255" @@ -64,6 +66,7 @@ def test_mcast_sender_uses_broadcast_for_255s(): assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 socket.close() + def test_mcast_sender_raises_for_invalit_adresses(): """Test mcast_sender uses broadcast for 0.0.0.0.""" mcgroup = (str(random.randint(0, 223)) + "." + @@ -78,7 +81,7 @@ def test_mcast_sender_raises_for_invalit_adresses(): str(random.randint(0, 255)) + "." + str(random.randint(0, 255))) with pytest.raises(OSError, match="Invalid multicast address .*"): - bbmcast.mcast_sender(mcgroup) + bbmcast.mcast_sender(mcgroup) def test_mcast_receiver_works_with_valid_addresses(): @@ -126,7 +129,7 @@ def test_multicast_roundtrip(reraise): """Test sending and receiving a multicast message.""" mcgroup = bbmcast.DEFAULT_MC_GROUP mcport = 5555 - rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup) rec_socket.settimeout(.1) message = "Ho Ho Ho!" @@ -136,7 +139,7 @@ def check_message(sock, message): data, _ = sock.recvfrom(1024) assert data.decode() == message - snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) + snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup) thr = Thread(target=check_message, args=(rec_socket, message)) thr.start() @@ -152,7 +155,7 @@ def test_broadcast_roundtrip(reraise): """Test sending and receiving a broadcast message.""" mcgroup = "0.0.0.0" mcport = 5555 - rec_socket, rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup) message = "Ho Ho Ho!" @@ -161,7 +164,7 @@ def check_message(sock, message): data, _ = sock.recvfrom(1024) assert data.decode() == message - snd_socket, snd_group = bbmcast.mcast_sender(mcgroup) + snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup) thr = Thread(target=check_message, args=(rec_socket, message)) thr.start() diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py index 5aa88bf..af97236 100644 --- a/posttroll/tests/test_message.py +++ b/posttroll/tests/test_message.py @@ -154,15 +154,3 @@ def test_serialization(self): msg = json.loads(local_dump) for key, val in msg.items(): assert val == metadata.get(key) - - -def suite(): - """Create the suite for test_message.""" - loader = unittest.TestLoader() - mysuite = unittest.TestSuite() - mysuite.addTest(loader.loadTestsFromTestCase(Test)) - - return mysuite - -if __name__ == "__main__": - unittest.main() diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py index 8f8689b..f4fe81d 100644 --- a/posttroll/tests/test_nameserver.py +++ b/posttroll/tests/test_nameserver.py @@ -60,7 +60,6 @@ def create_nameserver_instance(max_age=3, multicast_enabled=True): thr.join() - class TestAddressReceiver(unittest.TestCase): """Test the AddressReceiver.""" diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index c72011d..c152bf8 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -170,7 +170,7 @@ def test_pub_minmax_port_from_instanciation(self): # Using range of ports defined at instantation time, this # should override environment variables for port in range(50000, 60000): - res = _get_port_from_publish_instance(min_port=port, max_port=port+1) + res = _get_port_from_publish_instance(min_port=port, max_port=port + 1) if res is False: # The port wasn't free, try again continue @@ -231,7 +231,7 @@ def test_listener_container(self): sub.stop() -## Test create_publisher_from_config +# Test create_publisher_from_config def test_publisher_with_invalid_arguments_crashes(): """Test that only valid arguments are passed to Publisher.""" @@ -248,6 +248,7 @@ def test_publisher_is_selected(): assert isinstance(pub, Publisher) assert pub is not None + @mock.patch("posttroll.publisher.Publisher") def test_publisher_all_arguments(Publisher): """Test that only valid arguments are passed to Publisher.""" @@ -258,11 +259,13 @@ def test_publisher_all_arguments(Publisher): assert Publisher.call_args[0][0].startswith("tcp://*:") assert Publisher.call_args[0][0].endswith(str(settings["port"])) + def test_no_name_raises_keyerror(): """Trying to create a NoisyPublisher without a given name will raise KeyError.""" with pytest.raises(KeyError): _ = create_publisher_from_dict_config(dict()) + def test_noisypublisher_is_selected_only_name(): """Test that NoisyPublisher is selected as publisher class.""" from posttroll.publisher import NoisyPublisher @@ -272,6 +275,7 @@ def test_noisypublisher_is_selected_only_name(): pub = create_publisher_from_dict_config(settings) assert isinstance(pub, NoisyPublisher) + def test_noisypublisher_is_selected_name_and_port(): """Test that NoisyPublisher is selected as publisher class.""" from posttroll.publisher import NoisyPublisher @@ -281,6 +285,7 @@ def test_noisypublisher_is_selected_name_and_port(): pub = create_publisher_from_dict_config(settings) assert isinstance(pub, NoisyPublisher) + @mock.patch("posttroll.publisher.NoisyPublisher") def test_noisypublisher_all_arguments(NoisyPublisher): """Test that only valid arguments are passed to NoisyPublisher.""" @@ -293,6 +298,7 @@ def test_noisypublisher_all_arguments(NoisyPublisher): _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"]) assert NoisyPublisher.call_args[0][0] == settings["name"] + def test_publish_is_not_noisy(): """Test that Publisher is selected with the context manager when it should be.""" from posttroll.publisher import Publish @@ -300,6 +306,7 @@ def test_publish_is_not_noisy(): with Publish("service_name", port=40000, nameservers=False) as pub: assert isinstance(pub, Publisher) + def test_publish_is_noisy_only_name(): """Test that NoisyPublisher is selected with the context manager when only name is given.""" from posttroll.publisher import NoisyPublisher, Publish @@ -307,6 +314,7 @@ def test_publish_is_noisy_only_name(): with Publish("service_name") as pub: assert isinstance(pub, NoisyPublisher) + def test_publish_is_noisy_with_port(): """Test that NoisyPublisher is selected with the context manager when port is given.""" from posttroll.publisher import NoisyPublisher, Publish @@ -314,6 +322,7 @@ def test_publish_is_noisy_with_port(): with Publish("service_name", port=40001) as pub: assert isinstance(pub, NoisyPublisher) + def test_publish_is_noisy_with_nameservers(): """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" from posttroll.publisher import NoisyPublisher, Publish @@ -412,6 +421,7 @@ def _tcp_keepalive_settings(monkeypatch): with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1): yield + @contextmanager def reset_config_for_tests(): """Reset the config for testing.""" diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 74619c9..f467fe2 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -26,10 +26,10 @@ def create_keys(tmp_path): secret_keys_dir.mkdir() # create new keys in certificates dir - server_public_file, server_secret_file = zmq.auth.create_certificates( + _server_public_file, _server_secret_file = zmq.auth.create_certificates( keys_dir, "server" ) - client_public_file, client_secret_file = zmq.auth.create_certificates( + _client_public_file, _client_secret_file = zmq.auth.create_certificates( keys_dir, "client" ) @@ -66,8 +66,8 @@ def test_ipc_pubsub_with_sec(tmp_path): pub = Publisher(ipc_address) - pub.start() + def delayed_send(msg): time.sleep(.2) from posttroll.message import Message @@ -111,7 +111,7 @@ def test_switch_to_secure_zmq_backend(tmp_path): def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): """Test pub-sub on a secure ipc socket.""" - #create_keys(tmp_path) + # create_keys(tmp_path) server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") @@ -131,6 +131,7 @@ def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): pub = create_publisher_from_dict_config(pub_settings) pub.start() + def delayed_send(msg): time.sleep(.2) from posttroll.message import Message @@ -148,6 +149,7 @@ def delayed_send(msg): thr.join() pub.stop() + def test_switch_to_secure_backend_for_nameserver(tmp_path): """Test switching backend for nameserver.""" server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") @@ -163,7 +165,6 @@ def test_switch_to_secure_backend_for_nameserver(tmp_path): assert res == "" - def test_create_certificates_cli(tmp_path): """Test the certificate creation cli.""" from posttroll.backends.zmq import generate_keys diff --git a/posttroll/tests/test_unsecure_zmq_backend.py b/posttroll/tests/test_unsecure_zmq_backend.py index 66dbd6e..1b2b469 100644 --- a/posttroll/tests/test_unsecure_zmq_backend.py +++ b/posttroll/tests/test_unsecure_zmq_backend.py @@ -18,6 +18,7 @@ def test_ipc_pubsub(tmp_path): sub = create_subscriber_from_dict_config(subscriber_settings) pub = Publisher(ipc_address) pub.start() + def delayed_send(msg): time.sleep(.2) from posttroll.message import Message From 7a2633ec56192633f3114113e1f835e794677264 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 22 May 2024 11:21:02 +0200 Subject: [PATCH 43/45] Change sphinx theme --- doc/source/conf.py | 2 +- posttroll/tests/test_secure_zmq_backend.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index d45d3d6..08ac172 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -26,5 +26,5 @@ # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = "alabaster" +html_theme = "sphinx_rtd_theme" html_static_path = ["_static"] diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index f467fe2..36542d4 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -4,12 +4,15 @@ import os import shutil import time +from threading import Thread import zmq.auth from posttroll import config +from posttroll.backends.zmq import generate_keys +from posttroll.message import Message from posttroll.ns import get_pub_address -from posttroll.publisher import Publisher +from posttroll.publisher import Publisher, create_publisher_from_dict_config from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config from posttroll.tests.test_nameserver import create_nameserver_instance @@ -62,7 +65,6 @@ def test_ipc_pubsub_with_sec(tmp_path): server_secret_key_file=server_secret_key_file): subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) - from posttroll.publisher import Publisher pub = Publisher(ipc_address) @@ -70,10 +72,8 @@ def test_ipc_pubsub_with_sec(tmp_path): def delayed_send(msg): time.sleep(.2) - from posttroll.message import Message msg = Message(subject="/hi", atype="string", data=msg) pub.send(str(msg)) - from threading import Thread thr = Thread(target=delayed_send, args=["very sensitive message"]) thr.start() try: @@ -125,7 +125,6 @@ def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): server_secret_key_file=server_secret_key_file): subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) sub = create_subscriber_from_dict_config(subscriber_settings) - from posttroll.publisher import create_publisher_from_dict_config pub_settings = dict(address=ipc_address, nameservers=False, port=1789) pub = create_publisher_from_dict_config(pub_settings) @@ -133,11 +132,8 @@ def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): pub.start() def delayed_send(msg): - time.sleep(.2) - from posttroll.message import Message msg = Message(subject="/hi", atype="string", data=msg) pub.send(str(msg)) - from threading import Thread thr = Thread(target=delayed_send, args=["very sensitive message"]) thr.start() try: @@ -167,7 +163,6 @@ def test_switch_to_secure_backend_for_nameserver(tmp_path): def test_create_certificates_cli(tmp_path): """Test the certificate creation cli.""" - from posttroll.backends.zmq import generate_keys name = "server" args = [name, "-d", str(tmp_path)] generate_keys(args) From ca576b5eed7fe3d6a04b8298d42cdfcf3666d875 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Wed, 22 May 2024 11:26:34 +0200 Subject: [PATCH 44/45] Fix sphinx theme --- doc/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/requirements.txt b/doc/requirements.txt index 9c558e3..91f8e5d 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1 +1,2 @@ +sphinx-rtd-theme . From c657b1ccbf3fbecb5819c1b41ddd2d31b592729b Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Mon, 3 Jun 2024 12:36:07 +0200 Subject: [PATCH 45/45] Fix tests --- posttroll/tests/test_secure_zmq_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index 36542d4..f912125 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -123,7 +123,7 @@ def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): clients_public_keys_directory=os.path.dirname(client_public_key_file), server_public_key_file=server_public_key_file, server_secret_key_file=server_secret_key_file): - subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10203) sub = create_subscriber_from_dict_config(subscriber_settings) pub_settings = dict(address=ipc_address, nameservers=False, port=1789) @@ -132,6 +132,7 @@ def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): pub.start() def delayed_send(msg): + time.sleep(.2) msg = Message(subject="/hi", atype="string", data=msg) pub.send(str(msg)) thr = Thread(target=delayed_send, args=["very sensitive message"])