diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 0b794a5..924a954 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -21,7 +21,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
- pip install -U pytest pytest-cov pyzmq netifaces donfig
+ pip install -U pytest pytest-cov pyzmq netifaces-plus donfig pytest-reraise
- name: Install posttroll
run: |
pip install --no-deps -e .
diff --git a/bin/nameserver b/bin/nameserver
index d2d3370..8aa89e6 100755
--- a/bin/nameserver
+++ b/bin/nameserver
@@ -20,19 +20,17 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-"""The nameserver. Port 5555 (hardcoded) is used for communications.
-"""
+"""The nameserver. Port 5555 (hardcoded) is used for communications."""
# TODO: make port configurable.
-from posttroll.ns import NameServer
-
import logging
-import _strptime
+
+from posttroll.ns import NameServer
logger = logging.getLogger(__name__)
-if __name__ == '__main__':
+if __name__ == "__main__":
import argparse
@@ -58,14 +56,14 @@ if __name__ == '__main__':
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("[%(levelname)s: %(asctime)s :"
" %(name)s] %(message)s",
- '%Y-%m-%d %H:%M:%S'))
+ "%Y-%m-%d %H:%M:%S"))
if opts.verbose:
loglevel = logging.DEBUG
else:
loglevel = logging.INFO
handler.setLevel(loglevel)
- logging.getLogger('').setLevel(loglevel)
- logging.getLogger('').addHandler(handler)
+ logging.getLogger("").setLevel(loglevel)
+ logging.getLogger("").addHandler(handler)
logger = logging.getLogger("nameserver")
multicast_enabled = (opts.no_multicast == False)
diff --git a/doc/Makefile b/doc/Makefile
index bd4a009..d0c3cbf 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -1,130 +1,20 @@
-# Makefile for Sphinx documentation
+# Minimal makefile for Sphinx documentation
#
-# You can set these variables from the command line.
-SPHINXOPTS =
-SPHINXBUILD = sphinx-build
-PAPER =
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
BUILDDIR = build
-# Internal variables.
-PAPEROPT_a4 = -D latex_paper_size=a4
-PAPEROPT_letter = -D latex_paper_size=letter
-ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source
-
-.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest
-
+# Put it first so that "make" without argument is like "make help".
help:
- @echo "Please use \`make ' where is one of"
- @echo " html to make standalone HTML files"
- @echo " dirhtml to make HTML files named index.html in directories"
- @echo " singlehtml to make a single large HTML file"
- @echo " pickle to make pickle files"
- @echo " json to make JSON files"
- @echo " htmlhelp to make HTML files and a HTML help project"
- @echo " qthelp to make HTML files and a qthelp project"
- @echo " devhelp to make HTML files and a Devhelp project"
- @echo " epub to make an epub"
- @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
- @echo " latexpdf to make LaTeX files and run them through pdflatex"
- @echo " text to make text files"
- @echo " man to make manual pages"
- @echo " changes to make an overview of all changed/added/deprecated items"
- @echo " linkcheck to check all external links for integrity"
- @echo " doctest to run all doctests embedded in the documentation (if enabled)"
-
-clean:
- -rm -rf $(BUILDDIR)/*
-
-html:
- $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
- @echo
- @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
-
-dirhtml:
- $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
- @echo
- @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
-
-singlehtml:
- $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
- @echo
- @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
-
-pickle:
- $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
- @echo
- @echo "Build finished; now you can process the pickle files."
-
-json:
- $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
- @echo
- @echo "Build finished; now you can process the JSON files."
-
-htmlhelp:
- $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
- @echo
- @echo "Build finished; now you can run HTML Help Workshop with the" \
- ".hhp project file in $(BUILDDIR)/htmlhelp."
-
-qthelp:
- $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
- @echo
- @echo "Build finished; now you can run "qcollectiongenerator" with the" \
- ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
- @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PostTroll.qhcp"
- @echo "To view the help file:"
- @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PostTroll.qhc"
-
-devhelp:
- $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
- @echo
- @echo "Build finished."
- @echo "To view the help file:"
- @echo "# mkdir -p $$HOME/.local/share/devhelp/PostTroll"
- @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PostTroll"
- @echo "# devhelp"
-
-epub:
- $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
- @echo
- @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
-
-latex:
- $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
- @echo
- @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
- @echo "Run \`make' in that directory to run these through (pdf)latex" \
- "(use \`make latexpdf' here to do that automatically)."
-
-latexpdf:
- $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
- @echo "Running LaTeX files through pdflatex..."
- make -C $(BUILDDIR)/latex all-pdf
- @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
-
-text:
- $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
- @echo
- @echo "Build finished. The text files are in $(BUILDDIR)/text."
-
-man:
- $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
- @echo
- @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
-
-changes:
- $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
- @echo
- @echo "The overview file is in $(BUILDDIR)/changes."
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-linkcheck:
- $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
- @echo
- @echo "Link check complete; look for any errors in the above output " \
- "or in $(BUILDDIR)/linkcheck/output.txt."
+.PHONY: help Makefile
-doctest:
- $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
- @echo "Testing of doctests in the sources finished, look at the " \
- "results in $(BUILDDIR)/doctest/output.txt."
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/doc/requirements.txt b/doc/requirements.txt
index 9c558e3..91f8e5d 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -1 +1,2 @@
+sphinx-rtd-theme
.
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 5593311..08ac172 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -1,247 +1,30 @@
-# -*- coding: utf-8 -*-
+# Configuration file for the Sphinx documentation builder.
#
-# PostTroll documentation build configuration file, created by
-# sphinx-quickstart on Tue Sep 11 12:58:14 2012.
-#
-# This file is execfile()d with the current directory set to its containing dir.
-#
-# Note that not all possible configuration values are present in this
-# autogenerated file.
-#
-# All configuration values have a default; values that are commented out
-# serve to show the default.
-
-import sys
-import os
-from posttroll import __version__
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#sys.path.insert(0, os.path.abspath('.'))
-sys.path.insert(0, os.path.abspath('../../'))
-sys.path.insert(0, os.path.abspath('../../posttroll'))
-
-
-
-class Mock(object):
- """A mocking class
- """
- def __init__(self, *args, **kwargs):
- pass
-
- def __call__(self, *args, **kwargs):
- return Mock()
-
- @classmethod
- def __getattr__(cls, name):
- if name in ('__file__', '__path__'):
- return '/dev/null'
- elif name[0] == name[0].upper():
- mock_type = type(name, (), {})
- mock_type.__module__ = __name__
- return mock_type
- else:
- return Mock()
-
-MOCK_MODULES = ['zmq']
-for mod_name in MOCK_MODULES:
- sys.modules[mod_name] = Mock()
-
-# -- General configuration -----------------------------------------------------
-
-# If your documentation needs a minimal Sphinx version, state it here.
-#needs_sphinx = '1.0'
-
-# Add any Sphinx extension module names here, as strings. They can be extensions
-# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest']
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['sphinx_templates']
-
-# The suffix of source filenames.
-source_suffix = '.rst'
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
-# The encoding of source files.
-#source_encoding = 'utf-8-sig'
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+from posttroll.version import version
-# The master toctree document.
-master_doc = 'index'
+project = "Posttroll"
+copyright = "2012, Pytroll Crew"
+author = "Pytroll Crew"
+release = version
-# General information about the project.
-project = u'PostTroll'
-copyright = u'2012-2014, Pytroll crew'
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
-# The version info for the project you're documenting, acts as replacement for
-# |version| and |release|, also used in various other places throughout the
-# built documents.
-#
-
-
-# The full version, including alpha/beta/rc tags.
-release = __version__
-# The short X.Y version.
-version = ".".join(release.split(".")[:2])
-
-# The language for content autogenerated by Sphinx. Refer to documentation
-# for a list of supported languages.
-#language = None
+extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"]
+autodoc_mock_imports = ["zmq"]
-# There are two options for replacing |today|: either, you set today to some
-# non-false value, then it is used:
-#today = ''
-# Else, today_fmt is used as the format for a strftime call.
-#today_fmt = '%B %d, %Y'
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
+templates_path = ["_templates"]
exclude_patterns = []
-# The reST default role (used for this markup: `text`) to use for all documents.
-#default_role = None
-
-# If true, '()' will be appended to :func: etc. cross-reference text.
-#add_function_parentheses = True
-
-# If true, the current module name will be prepended to all description
-# unit titles (such as .. function::).
-#add_module_names = True
-
-# If true, sectionauthor and moduleauthor directives will be shown in the
-# output. They are ignored by default.
-#show_authors = False
-
-# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
-
-# A list of ignored prefixes for module index sorting.
-#modindex_common_prefix = []
-
-
-# -- Options for HTML output ---------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-html_theme = 'default'
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-#html_theme_options = {}
-
-# Add any paths that contain custom themes here, relative to this directory.
-#html_theme_path = []
-
-# The name for this set of Sphinx documents. If None, it defaults to
-# " v documentation".
-#html_title = None
-
-# A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = None
-
-# The name of an image file (relative to this directory) to place at the top
-# of the sidebar.
-#html_logo = None
-
-# The name of an image file (within the static path) to use as favicon of the
-# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
-# pixels large.
-#html_favicon = None
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['sphinx_static']
-
-# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
-# using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
-
-# If true, SmartyPants will be used to convert quotes and dashes to
-# typographically correct entities.
-#html_use_smartypants = True
-
-# Custom sidebar templates, maps document names to template names.
-#html_sidebars = {}
-
-# Additional templates that should be rendered to pages, maps page names to
-# template names.
-#html_additional_pages = {}
-
-# If false, no module index is generated.
-#html_domain_indices = True
-
-# If false, no index is generated.
-#html_use_index = True
-
-# If true, the index is split into individual pages for each letter.
-#html_split_index = False
-
-# If true, links to the reST sources are added to the pages.
-#html_show_sourcelink = True
-
-# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-#html_show_sphinx = True
-
-# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-#html_show_copyright = True
-
-# If true, an OpenSearch description file will be output, and all pages will
-# contain a tag referring to it. The value of this option must be the
-# base URL from which the finished HTML is served.
-#html_use_opensearch = ''
-
-# This is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = None
-
-# Output file base name for HTML help builder.
-htmlhelp_basename = 'PostTrolldoc'
-
-
-# -- Options for LaTeX output --------------------------------------------------
-
-# The paper size ('letter' or 'a4').
-#latex_paper_size = 'letter'
-
-# The font size ('10pt', '11pt' or '12pt').
-#latex_font_size = '10pt'
-
-# Grouping the document tree into LaTeX files. List of tuples
-# (source start file, target name, title, author, documentclass [howto/manual]).
-latex_documents = [
- ('index', 'PostTroll.tex', u'PostTroll Documentation',
- u'Pytroll crew', 'manual'),
-]
-
-# The name of an image file (relative to this directory) to place at the top of
-# the title page.
-#latex_logo = None
-
-# For "manual" documents, if this is true, then toplevel headings are parts,
-# not chapters.
-#latex_use_parts = False
-
-# If true, show page references after internal links.
-#latex_show_pagerefs = False
-
-# If true, show URL addresses after external links.
-#latex_show_urls = False
-
-# Additional stuff for the LaTeX preamble.
-#latex_preamble = ''
-
-# Documents to append as an appendix to all manuals.
-#latex_appendices = []
-
-# If false, no module index is generated.
-#latex_domain_indices = True
-# -- Options for manual page output --------------------------------------------
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-# One entry per manual page. List of tuples
-# (source start file, name, description, authors, manual section).
-man_pages = [
- ('index', 'posttroll', u'PostTroll Documentation',
- [u'Pytroll crew'], 1)
-]
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]
diff --git a/doc/source/index.rst b/doc/source/index.rst
index e66bba3..7ff0e12 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -115,6 +115,17 @@ to specify the nameserver(s) explicitly in the publishing code::
.. seealso:: :class:`posttroll.publisher.Publish`
and :class:`posttroll.subscriber.Subscribe`
+Configuration parameters
+------------------------
+
+The following global configuration variables are available through posttroll's Donfig configuration object:
+
+- tcp_keepalive
+- tcp_keepalive_cnt
+- tcp_keepalive_idle
+- tcp_keepalive_intvl
+- multicast_interface
+- mc_group
+
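+These variables can be set programmatically through the config object, or through correspondingly named
+environment variables prefixed with POSTTROLL_. A minimal sketch (the interface address is purely illustrative)::
+
+    >>> from posttroll import config
+    >>> with config.set(multicast_interface="192.168.1.10"):
+    ...     ...  # code using multicast on that specific interface
+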
Setting TCP keep-alive
----------------------
@@ -139,6 +150,55 @@ relevant socket options.
.. _zmq_setsockopts: http://api.zeromq.org/master:zmq-setsockopt
+Using secure ZeroMQ backend
+---------------------------
+
+To use securely authenticated sockets with posttroll (based on ZMQ's CURVE authentication), the backend needs to be
+selected through the posttroll config system, for example using an environment variable::
+
+ POSTTROLL_BACKEND=secure_zmq
+
+On the server side (for example a publisher), we need to define the server's secret key file and the directory
+containing the accepted clients' public keys::
+
+ POSTTROLL_SERVER_SECRET_KEY_FILE=/path/to/server.key_secret
+    POSTTROLL_CLIENTS_PUBLIC_KEYS_DIRECTORY=/path/to/client_public_keys/
+
+On the client side (for example a subscriber), we need to define the server's public key file and the client's secret
+key file::
+
+ POSTTROLL_CLIENT_SECRET_KEY_FILE=/path/to/client.key_secret
+ POSTTROLL_SERVER_PUBLIC_KEY_FILE=/path/to/server.key
+
+These settings can also be set using the posttroll config object, for example::
+
+ >>> from posttroll import config
+    >>> with config.set(backend="secure_zmq", server_public_key_file="..."):
+ ...
+
+The posttroll configuration uses donfig; for more information, see https://donfig.readthedocs.io/en/latest/.
+
+
+Generating the public and secret key pairs
+******************************************
+
+In order for the secure ZMQ backend to work, public/secret key pairs need to be generated, one for the client side and
+one for the server side. A command-line script is provided for this purpose::
+
+ > posttroll-generate-keys -h
+ usage: posttroll-generate-keys [-h] [-d DIRECTORY] name
+
+ Create a public/secret key pair for the secure zmq backend. This will create two files (in the current directory if not otherwise specified) with the suffixes '.key' and '.key_secret'. The name of the files will be the one provided.
+
+ positional arguments:
+ name Name of the file.
+
+ options:
+ -h, --help show this help message and exit
+ -d DIRECTORY, --directory DIRECTORY
+ Directory to place the keys in.
+
+
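+Once generated, the key files are referenced through the configuration described above. A minimal sketch for the
+server side (the paths are purely illustrative)::
+
+    >>> from posttroll import config
+    >>> with config.set(backend="secure_zmq",
+    ...                 server_secret_key_file="/path/to/keys/server.key_secret",
+    ...                 clients_public_keys_directory="/path/to/keys/clients"):
+    ...     ...  # start publishers here
+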
Converting from older posttroll versions
----------------------------------------
@@ -229,15 +289,6 @@ Multicast code
:members:
:undoc-members:
-
-Connections
-~~~~~~~~~~~
-
-.. automodule:: posttroll.connections
- :members:
- :undoc-members:
-
-
Misc
~~~~
diff --git a/posttroll/__init__.py b/posttroll/__init__.py
index d812970..df053e3 100644
--- a/posttroll/__init__.py
+++ b/posttroll/__init__.py
@@ -26,16 +26,12 @@
import datetime as dt
import logging
-import os
import sys
-import zmq
from donfig import Config
-from .version import get_versions
-
-config = Config('posttroll')
-context = {}
+config = Config("posttroll", defaults=[dict(backend="unsecure_zmq")])
+# context = {}
logger = logging.getLogger(__name__)
@@ -44,11 +40,12 @@ def get_context():
This function takes care of creating new contexts in case of forks.
"""
- pid = os.getpid()
- if pid not in context:
- context[pid] = zmq.Context()
- logger.debug('renewed context for PID %d', pid)
- return context[pid]
+ backend = config["backend"]
+ if "zmq" in backend:
+ from posttroll.backends.zmq import get_context
+ return get_context()
+ else:
+ raise NotImplementedError(f"No support for backend {backend} implemented (yet?).")
def strp_isoformat(strg):
@@ -62,30 +59,14 @@ def strp_isoformat(strg):
return strg
if len(strg) < 19 or len(strg) > 26:
if len(strg) > 30:
- strg = strg[:30] + '...'
+ strg = strg[:30] + "..."
raise ValueError("Invalid ISO formatted time string '%s'" % strg)
if strg.find(".") == -1:
- strg += '.000000'
- if sys.version[0:3] >= '2.6':
+ strg += ".000000"
+ if sys.version[0:3] >= "2.6":
return dt.datetime.strptime(strg, "%Y-%m-%dT%H:%M:%S.%f")
else:
dat, mis = strg.split(".")
dat = dt.datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S")
- mis = int(float('.' + mis)*1000000)
+ mis = int(float("." + mis) * 1000000)
return dat.replace(microsecond=mis)
-
-
-def _set_tcp_keepalive(socket):
- _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None))
- _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None))
- _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None))
- _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None))
-
-
-def _set_int_sockopt(socket, param, value):
- if value is not None:
- socket.setsockopt(param, int(value))
-
-
-__version__ = get_versions()['version']
-del get_versions
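With this change, get_context() dispatches to the backend selected in the posttroll configuration instead of
managing a global zmq context itself. A minimal sketch of the expected behaviour, assuming the default
configuration ("some_other_backend" is a made-up name):

    from posttroll import config, get_context

    # The default backend is "unsecure_zmq", so this returns the per-process
    # zmq.Context managed by posttroll.backends.zmq.
    ctx = get_context()

    # Backends without "zmq" in their name are rejected.
    with config.set(backend="some_other_backend"):
        try:
            get_context()
        except NotImplementedError as err:
            print(err)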
diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py
index dc8076b..42c0d4c 100644
--- a/posttroll/address_receiver.py
+++ b/posttroll/address_receiver.py
@@ -35,27 +35,29 @@
import time
import netifaces
-from zmq import REP, LINGER
+from zmq import ZMQError
-from posttroll.bbmcast import MulticastReceiver, SocketTimeout
+from posttroll import config
+from posttroll.bbmcast import MulticastReceiver, get_configured_broadcast_port
from posttroll.message import Message
from posttroll.publisher import Publish
-from posttroll import get_context
-
-__all__ = ('AddressReceiver', 'getaddress')
+__all__ = ("AddressReceiver", "getaddress")
LOGGER = logging.getLogger(__name__)
-debug = os.environ.get('DEBUG', False)
-broadcast_port = 21200
+debug = os.environ.get("DEBUG", False)
-default_publish_port = 16543
+DEFAULT_ADDRESS_PUBLISH_PORT = 16543
ten_minutes = dt.timedelta(minutes=10)
zero_seconds = dt.timedelta(seconds=0)
+def get_configured_address_port():
+    """Get the configured address publish port."""
+    return config.get("address_publish_port", DEFAULT_ADDRESS_PUBLISH_PORT)
+
+
def get_local_ips():
"""Get local IP addresses."""
inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET)
@@ -64,7 +66,7 @@ def get_local_ips():
for addr in inet_addrs:
if addr is not None:
for add in addr:
- ips.append(add['addr'])
+ ips.append(add["addr"])
return ips
# -----------------------------------------------------------------------------
@@ -74,17 +76,17 @@ def get_local_ips():
# -----------------------------------------------------------------------------
-class AddressReceiver(object):
+class AddressReceiver:
"""General thread to receive broadcast addresses."""
def __init__(self, max_age=ten_minutes, port=None,
do_heartbeat=True, multicast_enabled=True, restrict_to_localhost=False):
- """Initialize addres receiver."""
+ """Set up the address receiver."""
self._max_age = max_age
- self._port = port or default_publish_port
+ self._port = port or get_configured_address_port()
self._address_lock = threading.Lock()
self._addresses = {}
- self._subject = '/address'
+ self._subject = "/address"
self._do_heartbeat = do_heartbeat
self._multicast_enabled = multicast_enabled
self._last_age_check = dt.datetime(1900, 1, 1)
@@ -121,7 +123,7 @@ def get(self, name=""):
mda = copy.copy(metadata)
mda["receive_time"] = mda["receive_time"].isoformat()
addrs.append(mda)
- LOGGER.debug('return address %s', str(addrs))
+ LOGGER.debug("return address %s", str(addrs))
return addrs
def _check_age(self, pub, min_interval=zero_seconds):
@@ -137,41 +139,20 @@ def _check_age(self, pub, min_interval=zero_seconds):
for addr, metadata in self._addresses.items():
atime = metadata["receive_time"]
if now - atime > self._max_age:
- mda = {'status': False,
- 'URI': addr,
- 'service': metadata['service']}
- msg = Message('/address/' + metadata['name'], 'info', mda)
+ mda = {"status": False,
+ "URI": addr,
+ "service": metadata["service"]}
+ msg = Message("/address/" + metadata["name"], "info", mda)
to_del.append(addr)
- LOGGER.info("publish remove '%s'", str(msg))
- pub.send(msg.encode())
+ LOGGER.info(f"publish remove '{msg}'")
+ pub.send(str(msg.encode()))
for addr in to_del:
del self._addresses[addr]
def _run(self):
"""Run the receiver."""
- port = broadcast_port
- nameservers = []
- if self._multicast_enabled:
- while True:
- try:
- recv = MulticastReceiver(port)
- except IOError as err:
- if err.errno == errno.ENODEV:
- LOGGER.error("Receiver initialization failed "
- "(no such device). "
- "Trying again in %d s",
- 10)
- time.sleep(10)
- else:
- raise
- else:
- recv.settimeout(tout=2.0)
- LOGGER.info("Receiver initialized.")
- break
-
- else:
- recv = _SimpleReceiver(port)
- nameservers = ["localhost"]
+ port = get_configured_broadcast_port()
+ nameservers, recv = self.set_up_address_receiver(port)
self._is_running = True
with Publish("address_receiver", self._port, ["addresses"],
@@ -180,30 +161,35 @@ def _run(self):
while self._do_run:
try:
data, fromaddr = recv()
- if self._multicast_enabled:
- ip_, port = fromaddr
- if self._restrict_to_localhost and ip_ not in self._local_ips:
- # discard external message
- LOGGER.debug('Discard external message')
- continue
- LOGGER.debug("data %s", data)
- except SocketTimeout:
- if self._multicast_enabled:
- LOGGER.debug("Multicast socket timed out on recv!")
+ except TimeoutError:
+ if self._do_run:
+ if self._multicast_enabled:
+ LOGGER.debug("Multicast socket timed out on recv!")
continue
+ else:
+ raise
+ except ZMQError:
+ return
finally:
self._check_age(pub, min_interval=self._max_age / 20)
if self._do_heartbeat:
pub.heartbeat(min_interval=29)
+ if self._multicast_enabled:
+ ip_, port = fromaddr
+ if self._restrict_to_localhost and ip_ not in self._local_ips:
+ # discard external message
+ LOGGER.debug("Discard external message")
+ continue
+ LOGGER.debug("data %s", data)
msg = Message.decode(data)
name = msg.subject.split("/")[1]
- if msg.type == 'info' and msg.subject.lower().startswith(self._subject):
+ if msg.type == "info" and msg.subject.lower().startswith(self._subject):
addr = msg.data["URI"]
- msg.data['status'] = True
+ msg.data["status"] = True
metadata = copy.copy(msg.data)
metadata["name"] = name
- LOGGER.debug('receiving address %s %s %s', str(addr),
+ LOGGER.debug("receiving address %s %s %s", str(addr),
str(name), str(metadata))
if addr not in self._addresses:
LOGGER.info("nameserver: publish add '%s'",
@@ -214,6 +200,35 @@ def _run(self):
self._is_running = False
recv.close()
+ def set_up_address_receiver(self, port):
+ """Set up the address receiver depending on if it is multicast or not."""
+ nameservers = False
+ if self._multicast_enabled:
+ while True:
+ try:
+ recv = MulticastReceiver(port)
+ except IOError as err:
+ if err.errno == errno.ENODEV:
+ LOGGER.error("Receiver initialization failed "
+ "(no such device). "
+ "Trying again in %d s",
+ 10)
+ time.sleep(10)
+ else:
+ raise
+ else:
+ recv.settimeout(tout=2.0)
+ LOGGER.info("Receiver initialized.")
+ break
+
+ else:
+ if config["backend"] not in ["unsecure_zmq", "secure_zmq"]:
+ raise NotImplementedError
+ from posttroll.backends.zmq.address_receiver import SimpleReceiver
+ recv = SimpleReceiver(port, timeout=2)
+ nameservers = ["localhost"]
+ return nameservers, recv
+
def _add(self, adr, metadata):
"""Add an address."""
with self._address_lock:
@@ -221,27 +236,6 @@ def _add(self, adr, metadata):
self._addresses[adr] = metadata
-class _SimpleReceiver(object):
- """Simple listing on port for address messages."""
-
- def __init__(self, port=None):
- """Initialize receiver."""
- self._port = port or default_publish_port
- self._socket = get_context().socket(REP)
- self._socket.bind("tcp://*:" + str(port))
-
- def __call__(self):
- """Receive and return a message."""
- message = self._socket.recv_string()
- self._socket.send_string("ok")
- return message, None
-
- def close(self):
- """Close the receiver."""
- self._socket.setsockopt(LINGER, 1)
- self._socket.close()
-
-
# -----------------------------------------------------------------------------
# default
getaddress = AddressReceiver
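The default address publishing port (16543) is kept, but it can now be overridden through the posttroll
configuration. A small sketch, assuming no address_publish_port is configured elsewhere (the override value is
just an example):

    from posttroll import config
    from posttroll.address_receiver import get_configured_address_port

    assert get_configured_address_port() == 16543  # DEFAULT_ADDRESS_PUBLISH_PORT
    with config.set(address_publish_port=17543):
        assert get_configured_address_port() == 17543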
diff --git a/posttroll/backends/__init__.py b/posttroll/backends/__init__.py
new file mode 100644
index 0000000..982ba70
--- /dev/null
+++ b/posttroll/backends/__init__.py
@@ -0,0 +1 @@
+"""Init file for the backends."""
diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py
new file mode 100644
index 0000000..c943737
--- /dev/null
+++ b/posttroll/backends/zmq/__init__.py
@@ -0,0 +1,70 @@
+"""Main module for the zmq backend."""
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import zmq
+from zmq.auth.certs import create_certificates
+
+from posttroll import config
+
+logger = logging.getLogger(__name__)
+context = {}
+
+
+def get_context():
+ """Provide the context to use.
+
+ This function takes care of creating new contexts in case of forks.
+ """
+ pid = os.getpid()
+ if pid not in context:
+ context[pid] = zmq.Context()
+ logger.debug("renewed context for PID %d", pid)
+ return context[pid]
+
+
+def destroy_context(linger=None):
+ """Destroy the context."""
+ pid = os.getpid()
+ context.pop(pid).destroy(linger)
+
+
+def _set_tcp_keepalive(socket):
+ """Set the tcp keepalive parameters on *socket*."""
+ keepalive_options = get_tcp_keepalive_options()
+ for param, value in keepalive_options.items():
+ socket.setsockopt(param, value)
+
+
+def get_tcp_keepalive_options():
+ """Get the tcp_keepalive options from config."""
+ keepalive_options = dict()
+ for opt in ("tcp_keepalive",
+ "tcp_keepalive_cnt",
+ "tcp_keepalive_idle",
+ "tcp_keepalive_intvl"):
+ try:
+ value = int(config[opt])
+ except (KeyError, TypeError):
+ continue
+ param = getattr(zmq, opt.upper())
+ keepalive_options[param] = value
+ return keepalive_options
+
+
+def generate_keys(args=None):
+ """Generate a public/secret key pair."""
+ parser = argparse.ArgumentParser(
+ prog="posttroll-generate-keys",
+ description=("Create a public/secret key pair for the secure zmq backend. This will create two "
+ "files (in the current directory if not otherwise specified) with the suffixes '.key'"
+ " and '.key_secret'. The name of the files will be the one provided."))
+
+ parser.add_argument("name", type=str, help="Name of the file.")
+ parser.add_argument("-d", "--directory", help="Directory to place the keys in.", default=".", type=Path)
+
+ parsed = parser.parse_args(args)
+
+ create_certificates(parsed.directory, parsed.name)
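The keepalive helpers above translate posttroll configuration values into ZMQ socket options. A minimal sketch of
the mapping, assuming no keepalive settings are inherited from the environment (the values are illustrative):

    import zmq

    from posttroll import config
    from posttroll.backends.zmq import get_tcp_keepalive_options

    with config.set(tcp_keepalive=1, tcp_keepalive_idle=60):
        options = get_tcp_keepalive_options()

    # Only the configured options end up in the mapping.
    assert options == {zmq.TCP_KEEPALIVE: 1, zmq.TCP_KEEPALIVE_IDLE: 60}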
diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py
new file mode 100644
index 0000000..ef58dfa
--- /dev/null
+++ b/posttroll/backends/zmq/address_receiver.py
@@ -0,0 +1,36 @@
+"""ZMQ implementation of the the simple receiver."""
+
+from zmq import REP
+
+from posttroll.address_receiver import get_configured_address_port
+from posttroll.backends.zmq.socket import close_socket, set_up_server_socket
+
+
+class SimpleReceiver:
+    """Simple listener on a port for address messages."""
+
+ def __init__(self, port=None, timeout=2):
+ """Set up the receiver."""
+ self._port = port or get_configured_address_port()
+ address = "tcp://*:" + str(port)
+ self._socket, _, self._authenticator = set_up_server_socket(REP, address)
+ self._running = True
+ self.timeout = timeout
+
+ def __call__(self):
+ """Receive a message."""
+ while self._running:
+ try:
+ message = self._socket.recv_string(self.timeout)
+ except TimeoutError:
+ continue
+ else:
+ self._socket.send_string("ok")
+ return message, None
+
+ def close(self):
+ """Close the receiver."""
+ self._running = False
+ close_socket(self._socket)
+ if self._authenticator:
+ self._authenticator.stop()
diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py
new file mode 100644
index 0000000..238d9eb
--- /dev/null
+++ b/posttroll/backends/zmq/message_broadcaster.py
@@ -0,0 +1,54 @@
+"""Message broadcaster implementation using zmq."""
+
+import logging
+import threading
+
+from zmq import LINGER, NOBLOCK, REQ, ZMQError
+
+from posttroll.backends.zmq.socket import close_socket, set_up_client_socket
+
+logger = logging.getLogger(__name__)
+
+
+class ZMQDesignatedReceiversSender:
+    """Send messages to multiple *receivers*."""
+
+ def __init__(self, default_port, receivers):
+ """Set up the sender."""
+ self.default_port = default_port
+ self.receivers = receivers
+ self._shutdown_event = threading.Event()
+
+ def __call__(self, data):
+ """Send data."""
+ for receiver in self.receivers:
+ self._send_to_address(receiver, data)
+
+ def _send_to_address(self, address, data, timeout=10):
+ """Send data to *address* and *port* without verification of response."""
+ # Socket to talk to server
+ if address.find(":") == -1:
+ full_address = "tcp://%s:%d" % (address, self.default_port)
+ else:
+ full_address = "tcp://%s" % address
+ options = {LINGER: int(timeout * 1000)}
+ socket = set_up_client_socket(REQ, full_address, options)
+        try:
+ socket.send_string(data)
+ while not self._shutdown_event.is_set():
+ try:
+ message = socket.recv_string(NOBLOCK)
+ except ZMQError:
+ self._shutdown_event.wait(.1)
+ continue
+ if message != "ok":
+ logger.warning("invalid acknowledge received: %s" % message)
+ break
+
+ finally:
+ close_socket(socket)
+
+ def close(self):
+ """Close the sender."""
+ self._shutdown_event.set()
diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py
new file mode 100644
index 0000000..dc0fcfb
--- /dev/null
+++ b/posttroll/backends/zmq/ns.py
@@ -0,0 +1,108 @@
+"""ZMQ implexentation of ns."""
+
+import logging
+from contextlib import suppress
+from threading import Lock
+
+from zmq import LINGER, REP, REQ
+
+from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket, set_up_server_socket
+from posttroll.message import Message
+from posttroll.ns import get_active_address, get_configured_nameserver_port
+
+logger = logging.getLogger("__name__")
+
+nslock = Lock()
+
+
+def zmq_get_pub_address(name, timeout=10, nameserver="localhost"):
+ """Get the address of the publisher.
+
+ For a given publisher *name* from the nameserver on *nameserver* (localhost by default).
+ """
+ nameserver_address = create_nameserver_address(nameserver)
+ # Socket to talk to server
+ logger.debug(f"Connecting to {nameserver_address}")
+ socket = create_req_socket(timeout, nameserver_address)
+ return _fetch_address_using_socket(socket, name, timeout)
+
+
+def create_nameserver_address(nameserver):
+ """Create the nameserver address."""
+ port = get_configured_nameserver_port()
+ nameserver_address = "tcp://" + nameserver + ":" + str(port)
+ return nameserver_address
+
+
+def _fetch_address_using_socket(socket, name, timeout):
+ try:
+ socket_receiver = SocketReceiver()
+ socket_receiver.register(socket)
+
+ message = Message("/oper/ns", "request", {"service": name})
+ socket.send_string(str(message))
+
+ # Get the reply.
+ for message, _ in socket_receiver.receive(socket, timeout=timeout):
+ return message.data
+ except TimeoutError:
+ raise TimeoutError("Didn't get an address after %d seconds."
+ % timeout)
+ finally:
+ socket_receiver.unregister(socket)
+ close_socket(socket)
+
+
+def create_req_socket(timeout, nameserver_address):
+ """Create a REQ socket."""
+ options = {LINGER: int(timeout * 1000)}
+ socket = set_up_client_socket(REQ, nameserver_address, options)
+ return socket
+
+
+class ZMQNameServer:
+ """The name server."""
+
+ def __init__(self):
+ """Set up the nameserver."""
+ self.running = True
+ self.listener = None
+
+ def run(self, address_receiver):
+ """Run the listener and answer to requests."""
+ port = get_configured_nameserver_port()
+
+ try:
+ # stop was called before we could start running, exit
+ if not self.running:
+ return
+ address = "tcp://*:" + str(port)
+ self.listener, _, self._authenticator = set_up_server_socket(REP, address)
+ logger.debug(f"Nameserver listening on port {port}")
+ socket_receiver = SocketReceiver()
+ socket_receiver.register(self.listener)
+ while self.running:
+ try:
+ for msg, _ in socket_receiver.receive(self.listener, timeout=1):
+ logger.debug("Replying to request: " + str(msg))
+ active_address = get_active_address(msg.data["service"], address_receiver)
+ self.listener.send_unicode(str(active_address))
+ except TimeoutError:
+ continue
+ except KeyboardInterrupt:
+ # Needed to stop the nameserver.
+ pass
+ finally:
+ socket_receiver.unregister(self.listener)
+ self.close_sockets_and_threads()
+
+ def close_sockets_and_threads(self):
+ """Close all sockets and threads."""
+ with suppress(AttributeError):
+ close_socket(self.listener)
+ with suppress(AttributeError):
+ self._authenticator.stop()
+
+ def stop(self):
+ """Stop the name server."""
+ self.running = False
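A usage sketch for the address lookup helper (the service name is made up, and a nameserver is assumed to be
running on localhost):

    from posttroll.backends.zmq.ns import zmq_get_pub_address

    try:
        # Ask the local nameserver where the "my_service" publisher can be
        # reached; gives up after 2 seconds.
        addresses = zmq_get_pub_address("my_service", timeout=2)
    except TimeoutError as err:
        print(err)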
diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py
new file mode 100644
index 0000000..8d2bec5
--- /dev/null
+++ b/posttroll/backends/zmq/publisher.py
@@ -0,0 +1,59 @@
+"""ZMQ implementation of the publisher."""
+
+import logging
+from contextlib import suppress
+from threading import Lock
+
+import zmq
+
+from posttroll.backends.zmq import get_tcp_keepalive_options
+from posttroll.backends.zmq.socket import close_socket, set_up_server_socket
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ZMQPublisher:
+ """Unsecure ZMQ implementation of the publisher class."""
+
+ def __init__(self, address, name="", min_port=None, max_port=None):
+ """Set up the publisher.
+
+ Args:
+ address: the address to connect to.
+ name: the name of this publishing service.
+ min_port: the minimal port number to use.
+ max_port: the maximal port number to use.
+
+ """
+ self.name = name
+ self.destination = address
+ self.publish_socket = None
+ self.min_port = min_port
+ self.max_port = max_port
+ self.port_number = None
+ self._pub_lock = Lock()
+ self._authenticator = None
+
+ def start(self):
+ """Start the publisher."""
+ self._create_socket()
+ LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.")
+
+ return self
+
+ def _create_socket(self):
+ options = get_tcp_keepalive_options()
+ self.publish_socket, port, self._authenticator = set_up_server_socket(zmq.PUB, self.destination, options,
+ (self.min_port, self.max_port))
+ self.port_number = port
+
+ def send(self, msg):
+ """Send the given message."""
+ with self._pub_lock:
+ self.publish_socket.send_string(msg)
+
+ def stop(self):
+ """Stop the publisher."""
+ close_socket(self.publish_socket)
+ with suppress(AttributeError):
+ self._authenticator.stop()
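A minimal usage sketch for ZMQPublisher (the port and payload are illustrative; applications would normally go
through the higher-level posttroll.publisher classes):

    from posttroll.backends.zmq.publisher import ZMQPublisher
    from posttroll.message import Message

    pub = ZMQPublisher("tcp://*:9000", name="example").start()
    try:
        message = Message("/example/topic", "info", {"status": "ok"})
        pub.send(str(message))
    finally:
        pub.stop()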
diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py
new file mode 100644
index 0000000..7adb295
--- /dev/null
+++ b/posttroll/backends/zmq/socket.py
@@ -0,0 +1,155 @@
+"""ZMQ socket handling functions."""
+
+from urllib.parse import urlsplit, urlunsplit
+
+import zmq
+from zmq.auth.thread import ThreadAuthenticator
+
+from posttroll import config, get_context
+from posttroll.message import Message
+
+
+def close_socket(sock):
+ """Close a zmq socket."""
+ sock.setsockopt(zmq.LINGER, 1)
+ sock.close()
+
+
+def set_up_client_socket(socket_type, address, options=None):
+ """Set up a client (connecting) zmq socket."""
+ backend = config["backend"]
+ if backend == "unsecure_zmq":
+ sock = create_unsecure_client_socket(socket_type)
+ elif backend == "secure_zmq":
+ sock = create_secure_client_socket(socket_type)
+ add_options(sock, options)
+ sock.connect(address)
+ return sock
+
+
+def create_unsecure_client_socket(socket_type):
+ """Create an unsecure client socket."""
+ return get_context().socket(socket_type)
+
+
+def add_options(sock, options=None):
+ """Add options to a socket."""
+ if not options:
+ return
+ for param, val in options.items():
+ sock.setsockopt(param, val)
+
+
+def create_secure_client_socket(socket_type):
+ """Create a secure client socket."""
+ subscriber = get_context().socket(socket_type)
+
+ client_secret_key_file = config["client_secret_key_file"]
+ server_public_key_file = config["server_public_key_file"]
+ client_public, client_secret = zmq.auth.load_certificate(client_secret_key_file)
+ subscriber.curve_secretkey = client_secret
+ subscriber.curve_publickey = client_public
+
+ server_public, _ = zmq.auth.load_certificate(server_public_key_file)
+ # The client must know the server's public key to make a CURVE connection.
+ subscriber.curve_serverkey = server_public
+ return subscriber
+
+
+def set_up_server_socket(socket_type, destination, options=None, port_interval=(None, None)):
+ """Set up a server (binding) socket."""
+ if options is None:
+ options = {}
+ backend = config["backend"]
+ if backend == "unsecure_zmq":
+ sock = create_unsecure_server_socket(socket_type)
+ authenticator = None
+ elif backend == "secure_zmq":
+ sock, authenticator = create_secure_server_socket(socket_type)
+
+ add_options(sock, options)
+
+ port = bind(sock, destination, port_interval)
+ return sock, port, authenticator
+
+
+def create_unsecure_server_socket(socket_type):
+ """Create an unsecure server socket."""
+ return get_context().socket(socket_type)
+
+
+def bind(sock, destination, port_interval):
+ """Bind the socket to a destination.
+
+ If a random port is to be chosen, the port_interval is used.
+ """
+ # Check for port 0 (random port)
+ min_port, max_port = port_interval
+ u__ = urlsplit(destination)
+ port = u__.port
+ if port == 0:
+ dest = urlunsplit((u__.scheme, u__.hostname,
+ u__.path, u__.query, u__.fragment))
+ port_number = sock.bind_to_random_port(dest,
+ min_port=min_port,
+ max_port=max_port)
+ netloc = u__.hostname + ":" + str(port_number)
+ destination = urlunsplit((u__.scheme, netloc, u__.path,
+ u__.query, u__.fragment))
+ else:
+ sock.bind(destination)
+ port_number = port
+ return port_number
+
+
+def create_secure_server_socket(socket_type):
+ """Create a secure server socket."""
+ server_secret_key = config["server_secret_key_file"]
+ clients_public_keys_directory = config["clients_public_keys_directory"]
+ authorized_sub_addresses = config.get("authorized_client_addresses", [])
+
+ ctx = get_context()
+
+ # Start an authenticator for this context.
+ authenticator_thread = ThreadAuthenticator(ctx)
+ authenticator_thread.start()
+ authenticator_thread.allow(*authorized_sub_addresses)
+ # Tell authenticator to use the certificate in a directory
+ authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory)
+
+ server_socket = ctx.socket(socket_type)
+
+ server_public, server_secret = zmq.auth.load_certificate(server_secret_key)
+ server_socket.curve_secretkey = server_secret
+ server_socket.curve_publickey = server_public
+ server_socket.curve_server = True
+ return server_socket, authenticator_thread
+
+
+class SocketReceiver:
+ """A receiver for mulitple sockets."""
+
+ def __init__(self):
+ """Set up the receiver."""
+ self._poller = zmq.Poller()
+
+ def register(self, socket):
+ """Register the socket."""
+ self._poller.register(socket, zmq.POLLIN)
+
+ def unregister(self, socket):
+ """Unregister the socket."""
+ self._poller.unregister(socket)
+
+ def receive(self, *sockets, timeout=None):
+ """Timeout is in seconds."""
+ if timeout:
+ timeout *= 1000
+ socks = dict(self._poller.poll(timeout=timeout))
+ if socks:
+ for sock in sockets:
+ if socks.get(sock) == zmq.POLLIN:
+ received = sock.recv_string(zmq.NOBLOCK)
+ yield Message.decode(received), sock
+ else:
+ raise TimeoutError("Did not receive anything on sockets.")
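A sketch of how the client-socket helper and SocketReceiver combine, assuming a posttroll publisher is bound on
the example port:

    from zmq import SUB, SUBSCRIBE

    from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket

    sock = set_up_client_socket(SUB, "tcp://localhost:9000")
    sock.setsockopt_string(SUBSCRIBE, "")  # subscribe to everything
    receiver = SocketReceiver()
    receiver.register(sock)
    try:
        for message, _ in receiver.receive(sock, timeout=2):
            print(message)
    except TimeoutError:
        pass  # nothing arrived within 2 seconds
    finally:
        receiver.unregister(sock)
        close_socket(sock)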
diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py
new file mode 100644
index 0000000..836f590
--- /dev/null
+++ b/posttroll/backends/zmq/subscriber.py
@@ -0,0 +1,209 @@
+"""ZMQ implementation of the subscriber."""
+
+import logging
+from threading import Lock
+from time import sleep
+from urllib.parse import urlsplit
+
+from zmq import PULL, SUB, SUBSCRIBE, ZMQError
+
+from posttroll.backends.zmq import get_tcp_keepalive_options
+from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ZMQSubscriber:
+ """A ZMQ subscriber class."""
+
+ def __init__(self, addresses, topics="", message_filter=None, translate=False):
+ """Initialize the subscriber."""
+ self._topics = topics
+ self._filter = message_filter
+ self._translate = translate
+
+ self.sub_addr = {}
+ self.addr_sub = {}
+
+ self._hooks = []
+ self._hooks_cb = {}
+
+ self._sock_receiver = SocketReceiver()
+ self._lock = Lock()
+
+ self.update(addresses)
+
+ self._loop = None
+
+ @property
+ def running(self):
+ """Check if suscriber is running."""
+ return self._loop
+
+ def add(self, address, topics=None):
+ """Add *address* to the subscribing list for *topics*.
+
+        If *topics* is None, we will subscribe to the already specified topics.
+ """
+ with self._lock:
+ if address in self.addresses:
+ return
+
+ topics = topics or self._topics
+ LOGGER.info("Subscriber adding address %s with topics %s",
+ str(address), str(topics))
+ subscriber = self._add_sub_socket(address, topics)
+ self.sub_addr[subscriber] = address
+ self.addr_sub[address] = subscriber
+
+ def remove(self, address):
+ """Remove *address* from the subscribing list for *topics*."""
+ with self._lock:
+ try:
+ subscriber = self.addr_sub[address]
+ except KeyError:
+ return
+ LOGGER.info("Subscriber removing address %s", str(address))
+ del self.addr_sub[address]
+ del self.sub_addr[subscriber]
+ self._remove_sub_socket(subscriber)
+
+ def _remove_sub_socket(self, subscriber):
+ if self._sock_receiver:
+ self._sock_receiver.unregister(subscriber)
+ subscriber.close()
+
+ def update(self, addresses):
+ """Update with a set of addresses."""
+ if isinstance(addresses, str):
+ addresses = [addresses, ]
+ current_addresses, new_addresses = set(self.addresses), set(addresses)
+ addresses_to_remove = current_addresses.difference(new_addresses)
+ addresses_to_add = new_addresses.difference(current_addresses)
+ for addr in addresses_to_remove:
+ self.remove(addr)
+ for addr in addresses_to_add:
+ self.add(addr)
+ return bool(addresses_to_remove or addresses_to_add)
+
+ def add_hook_sub(self, address, topics, callback):
+ """Specify a SUB *callback* in the same stream (thread) as the main receive loop.
+
+ The callback will be called with the received messages from the
+ specified subscription.
+
+        Good for operations that need to be done in the same thread as
+        the main receive loop (e.g. operations on the underlying sockets).
+ """
+ LOGGER.info("Subscriber adding SUB hook %s for topics %s",
+ str(address), str(topics))
+ socket = self._add_sub_socket(address, topics)
+ self._add_hook(socket, callback)
+
+ def add_hook_pull(self, address, callback):
+ """Specify a PULL *callback* in the same stream (thread) as the main receive loop.
+
+ The callback will be called with the received messages from the
+ specified subscription. Good for pushed 'inproc' messages from another thread.
+ """
+ LOGGER.info("Subscriber adding PULL hook %s", str(address))
+ socket = self._create_socket(PULL, address)
+ if self._sock_receiver:
+ self._sock_receiver.register(socket)
+ self._add_hook(socket, callback)
+
+ def _add_hook(self, socket, callback):
+ """Add a generic hook. The passed socket has to be "receive only"."""
+ self._hooks.append(socket)
+ self._hooks_cb[socket] = callback
+
+ @property
+ def addresses(self):
+ """Get the addresses."""
+ return self.sub_addr.values()
+
+ @property
+ def subscribers(self):
+ """Get the subscribers."""
+ return self.sub_addr.keys()
+
+ def recv(self, timeout=None):
+ """Receive, optionally with *timeout* in seconds."""
+ for sub in list(self.subscribers) + self._hooks:
+ self._sock_receiver.register(sub)
+ self._loop = True
+ try:
+ while self._loop:
+ sleep(0)
+ yield from self._new_messages(timeout)
+ finally:
+ for sub in list(self.subscribers) + self._hooks:
+ self._sock_receiver.unregister(sub)
+ # self.poller.unregister(sub)
+
+ def _new_messages(self, timeout):
+ """Check for new messages to yield and pass to the callbacks."""
+ all_subs = list(self.subscribers) + self._hooks
+ try:
+ for m__, sock in self._sock_receiver.receive(*all_subs, timeout=timeout):
+ if sock in self.subscribers:
+ if not self._filter or self._filter(m__):
+ if self._translate:
+ url = urlsplit(self.sub_addr[sock])
+ host = url[1].split(":")[0]
+ m__.sender = (m__.sender.split("@")[0]
+ + "@" + host)
+ yield m__
+ elif sock in self._hooks:
+ self._hooks_cb[sock](m__)
+ except TimeoutError:
+ yield None
+ except ZMQError as err:
+ if self._loop:
+ LOGGER.exception("Receive failed: %s", str(err))
+
+ def __call__(self, **kwargs):
+ """Handle calls with class instance."""
+ return self.recv(**kwargs)
+
+ def stop(self):
+ """Stop the subscriber."""
+ self._loop = False
+
+ def close(self):
+ """Close the subscriber: stop it and close the local subscribers."""
+ self.stop()
+ for sub in list(self.subscribers) + self._hooks:
+ try:
+ close_socket(sub)
+ except ZMQError:
+ pass
+
+ def __del__(self):
+ """Clean up after the instance is deleted."""
+ for sub in list(self.subscribers) + self._hooks:
+ try:
+ close_socket(sub)
+ except Exception: # noqa: E722
+ pass
+
+    def _add_sub_socket(self, address, topics):
+        options = get_tcp_keepalive_options()
+        subscriber = self._create_socket(SUB, address, options)
+        add_subscriptions(subscriber, topics)
+        if self._sock_receiver:
+            self._sock_receiver.register(subscriber)
+        return subscriber
+
+    def _create_socket(self, socket_type, address, options=None):
+ return set_up_client_socket(socket_type, address, options)
+
+
+def add_subscriptions(socket, topics):
+ """Add subscriptions to a socket."""
+ for t__ in topics:
+ socket.setsockopt_string(SUBSCRIBE, str(t__))
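A minimal receive-loop sketch for ZMQSubscriber (the address and topic are illustrative; applications would
normally use posttroll.subscriber.Subscribe instead):

    from posttroll.backends.zmq.subscriber import ZMQSubscriber

    sub = ZMQSubscriber(["tcp://localhost:9000"], topics=["pytroll://"])
    try:
        for msg in sub.recv(timeout=2):
            if msg is None:  # timed out, nothing received
                continue
            print(msg.subject, msg.data)
            sub.stop()  # ends the receive loop after the first message
    finally:
        sub.close()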
diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py
index 3aa6b70..d9f3ae0 100644
--- a/posttroll/bbmcast.py
+++ b/posttroll/bbmcast.py
@@ -31,24 +31,51 @@
import logging
import os
import struct
-from socket import (AF_INET, INADDR_ANY, IP_ADD_MEMBERSHIP, IP_MULTICAST_LOOP,
- IP_MULTICAST_TTL, IPPROTO_IP, SO_BROADCAST, SO_REUSEADDR,
- SOCK_DGRAM, SOL_IP, SOL_SOCKET, gethostbyname, socket,
- timeout, SO_LINGER)
-
-__all__ = ('MulticastSender', 'MulticastReceiver', 'mcast_sender',
- 'mcast_receiver', 'SocketTimeout')
+import warnings
+from contextlib import suppress
+from socket import (
+ AF_INET,
+ INADDR_ANY,
+ IP_ADD_MEMBERSHIP,
+ IP_MULTICAST_IF,
+ IP_MULTICAST_LOOP,
+ IP_MULTICAST_TTL,
+ IPPROTO_IP,
+ SO_BROADCAST,
+ SO_LINGER,
+ SO_REUSEADDR,
+ SOCK_DGRAM,
+ SOL_IP,
+ SOL_SOCKET,
+ gethostbyname,
+ inet_aton,
+ socket,
+ timeout,
+)
+
+from posttroll import config
+
+__all__ = ("MulticastSender", "MulticastReceiver", "mcast_sender",
+ "mcast_receiver", "SocketTimeout")
# 224.0.0.0 through 224.0.0.255 is reserved administrative tasks
-MC_GROUP = os.environ.get('PYTROLL_MC_GROUP', '225.0.0.212')
+DEFAULT_MC_GROUP = "225.0.0.212"
# local network multicast (<32)
-TTL_LOCALNET = int(os.environ.get('PYTROLL_MC_TTL', 31))
+TTL_LOCALNET = int(os.environ.get("PYTROLL_MC_TTL", 31))
logger = logging.getLogger(__name__)
SocketTimeout = timeout # for easy access to socket.timeout
+DEFAULT_BROADCAST_PORT = 21200
+
+
+def get_configured_broadcast_port():
+ """Get the configured nameserver port."""
+ return config.get("broadcast_port", DEFAULT_BROADCAST_PORT)
+
+
# -----------------------------------------------------------------------------
#
# Sender.
@@ -56,15 +83,15 @@
# -----------------------------------------------------------------------------
-class MulticastSender(object):
+class MulticastSender:
"""Multicast sender on *port* and *mcgroup*."""
- def __init__(self, port, mcgroup=MC_GROUP):
- """Initialize multicast sending."""
+ def __init__(self, port, mcgroup=None):
+ """Set up the multicast sender."""
self.port = port
self.group = mcgroup
self.socket, self.group = mcast_sender(mcgroup)
- logger.debug('Started multicast group %s', mcgroup)
+ logger.debug("Started multicast group %s", self.group)
def __call__(self, data):
"""Send data to a socket."""
@@ -77,25 +104,43 @@ def close(self):
# Allow non-object interface
-def mcast_sender(mcgroup=MC_GROUP):
+def mcast_sender(mcgroup=None):
"""Non-object interface for sending multicast messages."""
+ if mcgroup is None:
+ mcgroup = get_mc_group()
sock = socket(AF_INET, SOCK_DGRAM)
try:
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
if _is_broadcast_group(mcgroup):
- group = ''
+ group = ""
sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1)
- elif int(mcgroup.split(".")[0]) > 239 or int(mcgroup.split(".")[0]) < 224:
- raise IOError("Invalid multicast address.")
+ elif ((int(mcgroup.split(".")[0]) > 239) or
+ (int(mcgroup.split(".")[0]) < 224)):
+ raise IOError(f"Invalid multicast address {mcgroup}")
else:
group = mcgroup
- ttl = struct.pack('b', TTL_LOCALNET) # Time-to-live
+ ttl = struct.pack("b", TTL_LOCALNET) # Time-to-live
sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl)
+
+ with suppress(KeyError):
+ multicast_interface = config.get("multicast_interface")
+ sock.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, inet_aton(multicast_interface))
except Exception:
sock.close()
raise
return sock, group
+
+def get_mc_group():
+    """Get the multicast group to use, from the config or the PYTROLL_MC_GROUP environment variable."""
+ try:
+ mcgroup = os.environ["PYTROLL_MC_GROUP"]
+ warnings.warn("PYTROLL_MC_GROUP is pending deprecation, please use POSTTROLL_MC_GROUP instead.",
+ PendingDeprecationWarning)
+ except KeyError:
+ mcgroup = DEFAULT_MC_GROUP
+ mcgroup = config.get("mc_group", mcgroup)
+ return mcgroup
+
# -----------------------------------------------------------------------------
#
# Receiver.
@@ -103,13 +148,13 @@ def mcast_sender(mcgroup=MC_GROUP):
# -----------------------------------------------------------------------------
-class MulticastReceiver(object):
+class MulticastReceiver:
"""Multicast receiver on *port* for an *mcgroup*."""
BUFSIZE = 1024
- def __init__(self, port, mcgroup=MC_GROUP):
- """Initialize multicast receiver."""
+ def __init__(self, port, mcgroup=None):
+ """Set up the multicast receiver."""
# Note: a multicast receiver will also receive broadcast on same port.
self.port = port
self.socket, self.group = mcast_receiver(port, mcgroup)
@@ -129,14 +174,16 @@ def __call__(self):
def close(self):
"""Close the receiver."""
- self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack('ii', 1, 1))
+ self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 1))
self.socket.close()
# Allow non-object interface
-def mcast_receiver(port, mcgroup=MC_GROUP):
+def mcast_receiver(port, mcgroup=None):
"""Open a UDP socket, bind it to a port and select a multicast group."""
+ if mcgroup is None:
+ mcgroup = get_mc_group()
if _is_broadcast_group(mcgroup):
group = None
else:
@@ -154,22 +201,21 @@ def mcast_receiver(port, mcgroup=MC_GROUP):
sock.setsockopt(SOL_IP, IP_MULTICAST_LOOP, 1) # default
# Bind it to the port
- sock.bind(('', port))
+ sock.bind(("", port))
# Look up multicast group address in name server
# (doesn't hurt if it is already in ddd.ddd.ddd.ddd format)
if group:
group = gethostbyname(group)
- # Construct binary group address
- bytes_ = [int(b) for b in group.split(".")]
- grpaddr = 0
- for byte in bytes_:
- grpaddr = (grpaddr << 8) | byte
-
- # Construct struct mreq from grpaddr and ifaddr
- ifaddr = INADDR_ANY
- mreq = struct.pack('!LL', grpaddr, ifaddr)
+ # Construct struct mreq
+ try:
+ multicast_interface = config.get("multicast_interface")
+ ifaddr = inet_aton(multicast_interface)
+ mreq = struct.pack("=4s4s", inet_aton(group), ifaddr)
+ except KeyError:
+ ifaddr = INADDR_ANY
+ mreq = struct.pack("=4sl", inet_aton(group), ifaddr)
# Add group membership
sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq)
@@ -177,7 +223,7 @@ def mcast_receiver(port, mcgroup=MC_GROUP):
sock.close()
raise
- return sock, group or ''
+ return sock, group or ""
# -----------------------------------------------------------------------------
#
@@ -188,6 +234,6 @@ def mcast_receiver(port, mcgroup=MC_GROUP):
def _is_broadcast_group(group):
"""Check if *group* is a valid multicasting group."""
- if not group or gethostbyname(group) in ('0.0.0.0', '255.255.255.255'):
+ if not group or gethostbyname(group) in ("0.0.0.0", "255.255.255.255"):
return True
return False
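A small sketch of how the multicast group and broadcast port are now resolved, assuming PYTROLL_MC_GROUP and the
corresponding posttroll settings are not set in the environment (the override values are illustrative):

    from posttroll import config
    from posttroll.bbmcast import get_configured_broadcast_port, get_mc_group

    assert get_mc_group() == "225.0.0.212"  # DEFAULT_MC_GROUP
    with config.set(mc_group="226.0.0.13", broadcast_port=21300):
        assert get_mc_group() == "226.0.0.13"
        assert get_configured_broadcast_port() == 21300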
diff --git a/posttroll/listener.py b/posttroll/listener.py
index 608d134..e89c486 100644
--- a/posttroll/listener.py
+++ b/posttroll/listener.py
@@ -23,11 +23,12 @@
"""Listener module."""
-from posttroll.subscriber import create_subscriber_from_dict_config
+import logging
+import time
from queue import Queue
from threading import Thread
-import time
-import logging
+
+from posttroll.subscriber import create_subscriber_from_dict_config
class ListenerContainer:
@@ -106,11 +107,11 @@ def create_subscriber(self):
def _get_subscriber_config(self):
config = {
- 'services': self.services,
- 'topics': self.topics,
- 'addr_listener': True,
- 'addresses': self.addresses,
- 'nameserver': self.nameserver,
+ "services": self.services,
+ "topics": self.topics,
+ "addr_listener": True,
+ "addresses": self.addresses,
+ "nameserver": self.nameserver,
}
return config
diff --git a/posttroll/logger.py b/posttroll/logger.py
index 31fd152..2155a76 100644
--- a/posttroll/logger.py
+++ b/posttroll/logger.py
@@ -25,14 +25,14 @@
# TODO: remove old hanging subscriptions
-from posttroll.subscriber import Subscribe
-from posttroll.publisher import NoisyPublisher
-from posttroll.message import Message
-from threading import Thread
-
import copy
import logging
import logging.handlers
+from threading import Thread
+
+from posttroll.message import Message
+from posttroll.publisher import NoisyPublisher
+from posttroll.subscriber import Subscribe
LOGGER = logging.getLogger(__name__)
@@ -75,11 +75,11 @@ def close(self):
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
COLORS = {
- 'WARNING': YELLOW,
- 'INFO': GREEN,
- 'DEBUG': BLUE,
- 'CRITICAL': MAGENTA,
- 'ERROR': RED
+ "WARNING": YELLOW,
+ "INFO": GREEN,
+ "DEBUG": BLUE,
+ "CRITICAL": MAGENTA,
+ "ERROR": RED
}
COLOR_SEQ = "\033[1;%dm"
@@ -201,8 +201,8 @@ def run():
time.sleep(1)
except KeyboardInterrupt:
tlogger.stop()
- print("Thanks for using pytroll/logger. See you soon on www.pytroll.org!")
+ print("Thanks for using pytroll/logger. See you soon on www.pytroll.org!") # noqa
-if __name__ == '__main__':
+if __name__ == "__main__":
run()
diff --git a/posttroll/message.py b/posttroll/message.py
index 6c10c0a..ab68484 100644
--- a/posttroll/message.py
+++ b/posttroll/message.py
@@ -48,8 +48,8 @@
from posttroll import strp_isoformat
-_MAGICK = 'pytroll:/'
-_VERSION = 'v1.01'
+_MAGICK = "pytroll:/"
+_VERSION = "v1.01"
class MessageError(Exception):
@@ -117,17 +117,17 @@ class Message(object):
- It will make a Message pickleable.
"""
- def __init__(self, subject='', atype='', data='', binary=False, rawstr=None):
+ def __init__(self, subject="", atype="", data="", binary=False, rawstr=None):
"""Initialize a Message from a subject, type and data, or from a raw string."""
if rawstr:
self.__dict__ = _decode(rawstr)
else:
try:
- self.subject = subject.decode('utf-8')
+ self.subject = subject.decode("utf-8")
except AttributeError:
self.subject = subject
try:
- self.type = atype.decode('utf-8')
+ self.type = atype.decode("utf-8")
except AttributeError:
self.type = atype
self.type = atype
@@ -142,17 +142,17 @@ def __init__(self, subject='', atype='', data='', binary=False, rawstr=None):
def user(self):
"""Try to return a user from a sender."""
try:
- return self.sender[:self.sender.index('@')]
+ return self.sender[:self.sender.index("@")]
except ValueError:
- return ''
+ return ""
@property
def host(self):
"""Try to return a host from a sender."""
try:
- return self.sender[self.sender.index('@') + 1:]
+ return self.sender[self.sender.index("@") + 1:]
except ValueError:
- return ''
+ return ""
@property
def head(self):
@@ -181,7 +181,7 @@ def __unicode__(self):
def __str__(self):
"""Return the human readable representation of the Message."""
try:
- return unicode(self).encode('utf-8')
+ return unicode(self).encode("utf-8")
except NameError:
return self.encode()
@@ -243,36 +243,17 @@ def datetime_decoder(dct):
def _decode(rawstr):
"""Convert a raw string to a Message."""
- # Check for the magick word.
- try:
- rawstr = rawstr.decode('utf-8')
- except (AttributeError, UnicodeEncodeError):
- pass
- except (UnicodeDecodeError):
- try:
- rawstr = rawstr.decode('iso-8859-1')
- except (UnicodeDecodeError):
- rawstr = rawstr.decode('utf-8', 'ignore')
- if not rawstr.startswith(_MAGICK):
- raise MessageError("This is not a '%s' message (wrong magick word)"
- % _MAGICK)
- rawstr = rawstr[len(_MAGICK):]
+ rawstr = _check_for_magic_word(rawstr)
- # Check for element count and version
- raw = re.split(r"\s+", rawstr, maxsplit=6)
- if len(raw) < 5:
- raise MessageError("Could node decode raw string: '%s ...'"
- % str(rawstr[:36]))
- version = raw[4][:len(_VERSION)]
- if not _is_valid_version(version):
- raise MessageError("Invalid Message version: '%s'" % str(version))
+ raw = _check_for_element_count(rawstr)
+ version = _check_for_version(raw)
# Start to build message
- msg = dict((('subject', raw[0].strip()),
- ('type', raw[1].strip()),
- ('sender', raw[2].strip()),
- ('time', strp_isoformat(raw[3].strip())),
- ('version', version)))
+ msg = dict((("subject", raw[0].strip()),
+ ("type", raw[1].strip()),
+ ("sender", raw[2].strip()),
+ ("time", strp_isoformat(raw[3].strip())),
+ ("version", version)))
# Data part
try:
@@ -282,26 +263,59 @@ def _decode(rawstr):
mimetype = None
if mimetype is None:
- msg['data'] = ''
- msg['binary'] = False
- elif mimetype == 'application/json':
+ msg["data"] = ""
+ msg["binary"] = False
+ elif mimetype == "application/json":
try:
- msg['data'] = json.loads(raw[6], object_hook=datetime_decoder)
- msg['binary'] = False
+ msg["data"] = json.loads(raw[6], object_hook=datetime_decoder)
+ msg["binary"] = False
except ValueError:
raise MessageError("JSON decode failed on '%s ...'" % raw[6][:36])
- elif mimetype == 'text/ascii':
- msg['data'] = str(data)
- msg['binary'] = False
- elif mimetype == 'binary/octet-stream':
- msg['data'] = data
- msg['binary'] = True
+ elif mimetype == "text/ascii":
+ msg["data"] = str(data)
+ msg["binary"] = False
+ elif mimetype == "binary/octet-stream":
+ msg["data"] = data
+ msg["binary"] = True
else:
raise MessageError("Unknown mime-type '%s'" % mimetype)
return msg
+def _check_for_version(raw):
+ version = raw[4][:len(_VERSION)]
+ if not _is_valid_version(version):
+ raise MessageError("Invalid Message version: '%s'" % str(version))
+ return version
+
+
+def _check_for_element_count(rawstr):
+ raw = re.split(r"\s+", rawstr, maxsplit=6)
+ if len(raw) < 5:
+        raise MessageError("Could not decode raw string: '%s ...'"
+ % str(rawstr[:36]))
+
+ return raw
+
+
+def _check_for_magic_word(rawstr):
+ """Check for the magick word."""
+ try:
+ rawstr = rawstr.decode("utf-8")
+ except (AttributeError, UnicodeEncodeError):
+ pass
+ except (UnicodeDecodeError):
+ try:
+ rawstr = rawstr.decode("iso-8859-1")
+ except (UnicodeDecodeError):
+ rawstr = rawstr.decode("utf-8", "ignore")
+ if not rawstr.startswith(_MAGICK):
+ raise MessageError("This is not a '%s' message (wrong magick word)"
+ % _MAGICK)
+ return rawstr[len(_MAGICK):]
+
+
def datetime_encoder(obj):
"""Encode datetimes into iso format."""
try:
@@ -317,15 +331,15 @@ def _encode(msg, head=False, binary=False):
if not head and msg.data:
if not binary and isinstance(msg.data, str):
- return (rawstr + ' ' +
- 'text/ascii' + ' ' + msg.data)
+ return (rawstr + " " +
+ "text/ascii" + " " + msg.data)
elif not binary:
- return (rawstr + ' ' +
- 'application/json' + ' ' +
+ return (rawstr + " " +
+ "application/json" + " " +
json.dumps(msg.data, default=datetime_encoder))
else:
- return (rawstr + ' ' +
- 'binary/octet-stream' + ' ' + msg.data)
+ return (rawstr + " " +
+ "binary/octet-stream" + " " + msg.data)
return rawstr
diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py
index 32103c8..4990c36 100644
--- a/posttroll/message_broadcaster.py
+++ b/posttroll/message_broadcaster.py
@@ -23,58 +23,34 @@
"""Message broadcast module."""
-import time
-import threading
-import logging
import errno
+import logging
+import threading
-from posttroll import message
-from posttroll.bbmcast import MulticastSender, MC_GROUP
-from posttroll import get_context
-from zmq import REQ, LINGER
+from posttroll import config, message
+from posttroll.bbmcast import MulticastSender, get_configured_broadcast_port
-__all__ = ('MessageBroadcaster', 'AddressBroadcaster', 'sendaddress')
+__all__ = ("MessageBroadcaster", "AddressBroadcaster", "sendaddress")
LOGGER = logging.getLogger(__name__)
-broadcast_port = 21200
-
-class DesignatedReceiversSender(object):
+class DesignatedReceiversSender:
"""Sends message to multiple *receivers* on *port*."""
-
def __init__(self, default_port, receivers):
"""Set settings."""
- self.default_port = default_port
- self.receivers = receivers
+ backend = config.get("backend", "unsecure_zmq")
+ if backend == "unsecure_zmq":
+ from posttroll.backends.zmq.message_broadcaster import ZMQDesignatedReceiversSender
+ self._sender = ZMQDesignatedReceiversSender(default_port, receivers)
def __call__(self, data):
"""Send messages from all receivers."""
- for receiver in self.receivers:
- self._send_to_address(receiver, data)
-
- def _send_to_address(self, address, data, timeout=10):
- """Send data to *address* and *port* without verification of response."""
- # Socket to talk to server
- socket = get_context().socket(REQ)
- try:
- socket.setsockopt(LINGER, timeout * 1000)
- if address.find(":") == -1:
- socket.connect("tcp://%s:%d" % (address, self.default_port))
- else:
- socket.connect("tcp://%s" % address)
- socket.send_string(data)
- message = socket.recv_string()
- if message != "ok":
- LOGGER.warn("invalid acknowledge received: %s" % message)
-
- finally:
- socket.close()
+ return self._sender(data)
def close(self):
"""Close the sender."""
- pass
-
+ return self._sender.close()
# ----------------------------------------------------------------------------
#
@@ -83,51 +59,48 @@ def close(self):
# ----------------------------------------------------------------------------
-class MessageBroadcaster(object):
+class MessageBroadcaster:
"""Class to broadcast stuff.
If *interval* is 0 or negative, no broadcasting is done.
"""
def __init__(self, msg, port, interval, designated_receivers=None):
- """Initialize message broadcaster."""
+ """Set up the message broadcaster."""
if designated_receivers:
self._sender = DesignatedReceiversSender(port,
designated_receivers)
else:
- # mcgroup = None or '' is broadcast
- # mcgroup = MC_GROUP is default multicast group
- self._sender = MulticastSender(port, mcgroup=MC_GROUP)
+ self._sender = MulticastSender(port)
self._interval = interval
self._message = msg
- self._do_run = False
- self._is_running = False
+ self._shutdown_event = threading.Event()
self._thread = threading.Thread(target=self._run)
def start(self):
"""Start the broadcasting."""
if self._interval > 0:
- if not self._is_running:
- self._do_run = True
+ if not self._thread.is_alive():
self._thread.start()
return self
def is_running(self):
"""Are we running."""
- return self._is_running
+ return self._thread.is_alive()
def stop(self):
"""Stop the broadcasting."""
- self._do_run = False
+ self._shutdown_event.set()
+ self._sender.close()
+ self._thread.join()
return self
def _run(self):
"""Broadcasts forever."""
- self._is_running = True
network_fail = False
try:
- while self._do_run:
+ while not self._shutdown_event.is_set():
try:
if network_fail is True:
LOGGER.info("Network connection re-established!")
@@ -141,9 +114,8 @@ def _run(self):
network_fail = True
else:
raise
- time.sleep(self._interval)
+ self._shutdown_event.wait(self._interval)
finally:
- self._is_running = False
self._sender.close()
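
The flag-plus-sleep pattern removed above is replaced by a threading.Event, which doubles as an interruptible sleep; a self-contained sketch of the idea (names are illustrative):

    import threading

    shutdown = threading.Event()

    def run(interval=2.0):
        while not shutdown.is_set():
            # ... send one broadcast here ...
            shutdown.wait(interval)  # returns early once shutdown.set() is called

    thread = threading.Thread(target=run)
    thread.start()
    shutdown.set()  # stop() takes effect without waiting out a full interval
    thread.join()
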
@@ -155,13 +127,13 @@ def _run(self):
class AddressBroadcaster(MessageBroadcaster):
- """Class to broadcast stuff."""
+ """Class to broadcast addresses."""
def __init__(self, name, address, interval, nameservers):
- """Initialize address broadcasting."""
+ """Set up the Address broadcaster."""
msg = message.Message("/address/%s" % name, "info",
{"URI": "%s:%d" % address}).encode()
- MessageBroadcaster.__init__(self, msg, broadcast_port, interval,
+ MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval,
nameservers)
@@ -184,7 +156,7 @@ def __init__(self, name, address, data_type, interval=2, nameservers=None):
msg = message.Message("/address/%s" % name, "info",
{"URI": address,
"service": data_type}).encode()
- MessageBroadcaster.__init__(self, msg, broadcast_port, interval,
+ MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval,
nameservers)
diff --git a/posttroll/ns.py b/posttroll/ns.py
index 1ecff05..1bf05b4 100644
--- a/posttroll/ns.py
+++ b/posttroll/ns.py
@@ -23,35 +23,36 @@
"""Manage other's subscriptions.
-Default port is 5557, if $NAMESERVER_PORT is not defined.
+Default port is 5557, if $POSTTROLL_NAMESERVER_PORT is not defined.
"""
import datetime as dt
import logging
import os
import time
+import warnings
-from threading import Lock
-# pylint: disable=E0611
-from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller
-
-from posttroll import get_context
+from posttroll import config
from posttroll.address_receiver import AddressReceiver
from posttroll.message import Message
# pylint: enable=E0611
-PORT = int(os.environ.get("NAMESERVER_PORT", 5557))
+DEFAULT_NAMESERVER_PORT = 5557
logger = logging.getLogger(__name__)
-nslock = Lock()
-
-class TimeoutError(BaseException):
- """A timeout."""
+def get_configured_nameserver_port():
+ """Get the configured nameserver port."""
+ try:
+ port = int(os.environ["NAMESERVER_PORT"])
+ warnings.warn("NAMESERVER_PORT is pending deprecation, please use POSTTROLL_NAMESERVER_PORT instead.",
+ PendingDeprecationWarning)
+ except KeyError:
+ port = DEFAULT_NAMESERVER_PORT
+ return config.get("nameserver_port", port)
- pass
# Client functions.
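
A hedged sketch of the new port lookup (the override value is arbitrary): the posttroll configuration key takes precedence, the legacy NAMESERVER_PORT variable is still honoured with a PendingDeprecationWarning, and 5557 remains the default.

    from posttroll import config
    from posttroll.ns import get_configured_nameserver_port

    assert get_configured_nameserver_port() == 5557  # nothing configured

    with config.set(nameserver_port=9999):
        assert get_configured_nameserver_port() == 9999
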
@@ -79,34 +80,16 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"):
def get_pub_address(name, timeout=10, nameserver="localhost"):
"""Get the address of the named publisher.
- Kwargs:
- - name: name of the publishers
- - nameserver: nameserver address to query the publishers from (default: localhost).
+ Args:
+ name: name of the publishers
+ timeout: how long to wait for an address, in seconds.
+ nameserver: nameserver address to query the publishers from (default: localhost).
"""
- # Socket to talk to server
- socket = get_context().socket(REQ)
- try:
- socket.setsockopt(LINGER, int(timeout * 1000))
- socket.connect("tcp://" + nameserver + ":" + str(PORT))
- logger.debug('Connecting to %s',
- "tcp://" + nameserver + ":" + str(PORT))
- poller = Poller()
- poller.register(socket, POLLIN)
-
- message = Message("/oper/ns", "request", {"service": name})
- socket.send_string(str(message))
-
- # Get the reply.
- sock = poller.poll(timeout=timeout * 1000)
- if sock:
- if sock[0][0] == socket:
- message = Message.decode(socket.recv_string(NOBLOCK))
- return message.data
- else:
- raise TimeoutError("Didn't get an address after %d seconds."
- % timeout)
- finally:
- socket.close()
+ if config["backend"] not in ["unsecure_zmq", "secure_zmq"]:
+ raise NotImplementedError(f"Did not recognize backend: {config['backend']}")
+ from posttroll.backends.zmq.ns import zmq_get_pub_address
+ return zmq_get_pub_address(name, timeout, nameserver)
+
# Server part.
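
A short usage sketch of the client side after the refactor (the service name is hypothetical); the ZMQ request/reply details now live in posttroll.backends.zmq.ns:

    from posttroll.ns import get_pub_address

    # Ask the local nameserver where "my_service" publishes, waiting up to 2 seconds.
    addresses = get_pub_address("my_service", timeout=2, nameserver="localhost")
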
@@ -130,6 +113,11 @@ def __init__(self, max_age=None, multicast_enabled=True, restrict_to_localhost=F
self._max_age = max_age or dt.timedelta(minutes=10)
self._multicast_enabled = multicast_enabled
self._restrict_to_localhost = restrict_to_localhost
+ backend = config["backend"]
+ if backend not in ["unsecure_zmq", "secure_zmq"]:
+ raise NotImplementedError(f"Did not recognize backend: {backend}")
+ from posttroll.backends.zmq.ns import ZMQNameServer
+ self._ns = ZMQNameServer()
def run(self, *args):
"""Run the listener and answer to requests."""
@@ -139,37 +127,11 @@ def run(self, *args):
multicast_enabled=self._multicast_enabled,
restrict_to_localhost=self._restrict_to_localhost)
arec.start()
- port = PORT
-
try:
- with nslock:
- self.listener = get_context().socket(REP)
- self.listener.bind("tcp://*:" + str(port))
- logger.debug('Listening on port %s', str(port))
- poller = Poller()
- poller.register(self.listener, POLLIN)
- while self.loop:
- with nslock:
- socks = dict(poller.poll(1000))
- if socks:
- if socks.get(self.listener) == POLLIN:
- msg = self.listener.recv_string()
- else:
- continue
- logger.debug("Replying to request: " + str(msg))
- msg = Message.decode(msg)
- active_address = get_active_address(msg.data["service"], arec)
- self.listener.send_unicode(str(active_address))
- except KeyboardInterrupt:
- # Needed to stop the nameserver.
- pass
+ return self._ns.run(arec)
finally:
arec.stop()
- self.stop()
def stop(self):
- """Stop the name server."""
- self.listener.setsockopt(LINGER, 1)
- self.loop = False
- with nslock:
- self.listener.close()
+ """Stop the nameserver."""
+ return self._ns.stop()
diff --git a/posttroll/publisher.py b/posttroll/publisher.py
index 4a9bdec..dee85cc 100644
--- a/posttroll/publisher.py
+++ b/posttroll/publisher.py
@@ -26,15 +26,10 @@
import datetime as dt
import logging
import socket
-from threading import Lock
-from urllib.parse import urlsplit, urlunsplit
-import zmq
-from posttroll import get_context
-from posttroll import _set_tcp_keepalive
+from posttroll import config
from posttroll.message import Message
from posttroll.message_broadcaster import sendaddressservice
-from posttroll import config
LOGGER = logging.getLogger(__name__)
@@ -92,56 +87,31 @@ class Publisher:
def __init__(self, address, name="", min_port=None, max_port=None):
"""Bind the publisher class to a port."""
- self.name = name
- self.destination = address
- self.publish_socket = None
# Limit port range or use the defaults when no port is defined
# by the user
- self.min_port = min_port or int(config.get('pub_min_port', 49152))
- self.max_port = max_port or int(config.get('pub_max_port', 65536))
- self.port_number = None
-
+ min_port = min_port or int(config.get("pub_min_port", 49152))
+ max_port = max_port or int(config.get("pub_max_port", 65536))
# Initialize no heartbeat
self._heartbeat = None
- self._pub_lock = Lock()
+
+ backend = config.get("backend", "unsecure_zmq")
+ if backend not in ["unsecure_zmq", "secure_zmq"]:
+ raise NotImplementedError(f"No support for backend {backend} implemented (yet?).")
+ from posttroll.backends.zmq.publisher import ZMQPublisher
+ self._publisher = ZMQPublisher(address, name=name, min_port=min_port, max_port=max_port)
def start(self):
"""Start the publisher."""
- self.publish_socket = get_context().socket(zmq.PUB)
- _set_tcp_keepalive(self.publish_socket)
-
- self.bind()
- LOGGER.info("publisher started on port %s", str(self.port_number))
+ self._publisher.start()
return self
- def bind(self):
- """Bind the port."""
- # Check for port 0 (random port)
- u__ = urlsplit(self.destination)
- port = u__.port
- if port == 0:
- dest = urlunsplit((u__.scheme, u__.hostname,
- u__.path, u__.query, u__.fragment))
- self.port_number = self.publish_socket.bind_to_random_port(
- dest,
- min_port=self.min_port,
- max_port=self.max_port)
- netloc = u__.hostname + ":" + str(self.port_number)
- self.destination = urlunsplit((u__.scheme, netloc, u__.path,
- u__.query, u__.fragment))
- else:
- self.publish_socket.bind(self.destination)
- self.port_number = port
-
def send(self, msg):
"""Send the given message."""
- with self._pub_lock:
- self.publish_socket.send_string(msg)
+ return self._publisher.send(msg)
def stop(self):
"""Stop the publisher."""
- self.publish_socket.setsockopt(zmq.LINGER, 1)
- self.publish_socket.close()
+ return self._publisher.stop()
def close(self):
"""Alias for stop."""
@@ -153,13 +123,23 @@ def heartbeat(self, min_interval=0):
self._heartbeat = _PublisherHeartbeat(self)
self._heartbeat(min_interval)
+ @property
+ def name(self):
+ """Get the name of the publisher."""
+ return self._publisher.name
+
+ @property
+ def port_number(self):
+ """Get the port number from the actual publisher."""
+ return self._publisher.port_number
+
class _PublisherHeartbeat:
"""Publisher for heartbeat."""
def __init__(self, publisher):
self.publisher = publisher
- self.subject = '/heartbeat/' + publisher.name
+ self.subject = "/heartbeat/" + publisher.name
self.lastbeat = dt.datetime(1900, 1, 1)
def __call__(self, min_interval=0):
@@ -211,17 +191,18 @@ def __init__(self, name, port=0, aliases=None, broadcast_interval=2,
def start(self):
"""Start the publisher."""
- pub_addr = _get_publish_address(self._port)
- self._publisher = self._publisher_class(pub_addr, self._name,
+ pub_addr = _create_tcp_publish_address(self._port)
+ self._publisher = self._publisher_class(pub_addr, name=self._name,
min_port=self.min_port,
- max_port=self.max_port).start()
- LOGGER.debug("entering publish %s", str(self._publisher.destination))
- addr = _get_publish_address(self._publisher.port_number, str(get_own_ip()))
+ max_port=self.max_port)
+ self._publisher.start()
+ addr = _create_tcp_publish_address(self._publisher.port_number, str(get_own_ip()))
self._broadcaster = sendaddressservice(self._name, addr,
self._aliases,
self._broadcast_interval,
- self._nameservers).start()
- return self._publisher
+ self._nameservers)
+ self._broadcaster.start()
+ return self
def send(self, msg):
"""Send a *msg*."""
@@ -241,12 +222,17 @@ def close(self):
"""Alias for stop."""
self.stop()
+ @property
+ def port_number(self):
+ """Get the port number."""
+ return self._publisher.port_number
+
def heartbeat(self, min_interval=0):
- """Publish a heartbeat."""
- self._publisher.heartbeat(min_interval=min_interval)
+ """Send a heartbeat ... but only if *min_interval* seconds has passed since last beat."""
+ self._publisher.heartbeat(min_interval)
-def _get_publish_address(port, ip_address="*"):
+def _create_tcp_publish_address(port, ip_address="*"):
return "tcp://" + ip_address + ":" + str(port)
@@ -255,7 +241,7 @@ class Publish:
See :class:`Publisher` and :class:`NoisyPublisher` for more information on the arguments.
- The publisher is selected based on the arguments, see :function:`create_publisher_from_dict_config` for
+ The publisher is selected based on the arguments, see :func:`create_publisher_from_dict_config` for
information how the selection is done.
Example on how to use the :class:`Publish` context::
@@ -281,9 +267,9 @@ class Publish:
def __init__(self, name, port=0, aliases=None, broadcast_interval=2, nameservers=None,
min_port=None, max_port=None):
"""Initialize the class."""
- settings = {'name': name, 'port': port, 'min_port': min_port, 'max_port': max_port,
- 'aliases': aliases, 'broadcast_interval': broadcast_interval,
- 'nameservers': nameservers}
+ settings = {"name": name, "port": port, "min_port": min_port, "max_port": max_port,
+ "aliases": aliases, "broadcast_interval": broadcast_interval,
+ "nameservers": nameservers}
self.publisher = create_publisher_from_dict_config(settings)
def __enter__(self):
@@ -315,18 +301,21 @@ def create_publisher_from_dict_config(settings):
described in the docstrings of the respective classes, namely :class:`~posttroll.publisher.Publisher` and
:class:`~posttroll.publisher.NoisyPublisher`.
"""
- if settings.get('port') and settings.get('nameservers') is False:
+ if (settings.get("port") or settings.get("address")) and settings.get("nameservers") is False:
return _get_publisher_instance(settings)
return _get_noisypublisher_instance(settings)
def _get_publisher_instance(settings):
- publisher_address = _get_publish_address(settings['port'])
- publisher_name = settings.get("name", "")
- min_port = settings.get("min_port")
- max_port = settings.get("max_port")
-
- return Publisher(publisher_address, name=publisher_name, min_port=min_port, max_port=max_port)
+ settings = settings.copy()
+ publisher_address = settings.pop("address", None)
+ port = settings.pop("port", None)
+ if not publisher_address:
+ publisher_address = _create_tcp_publish_address(port)
+ settings.pop("nameservers", None)
+ settings.pop("aliases", None)
+ settings.pop("broadcast_interval", None)
+ return Publisher(publisher_address, **settings)
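
A hedged example of the selection logic above (port and name are illustrative): an explicit port or address combined with nameservers=False yields a plain Publisher, anything else goes through NoisyPublisher.

    from posttroll.message import Message
    from posttroll.publisher import create_publisher_from_dict_config

    settings = {"port": 40000, "nameservers": False, "name": "my_publisher"}
    pub = create_publisher_from_dict_config(settings)  # a plain Publisher
    pub.start()
    pub.send(str(Message("/test", "info", "hello")))
    pub.stop()
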
def _get_noisypublisher_instance(settings):
diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py
index d966548..fc3a8c1 100644
--- a/posttroll/subscriber.py
+++ b/posttroll/subscriber.py
@@ -27,17 +27,10 @@
import datetime as dt
import logging
import time
-from time import sleep
-from threading import Lock
-from urllib.parse import urlsplit
-# pylint: disable=E0611
-from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError
-
-# pylint: enable=E0611
-from posttroll import get_context
-from posttroll import _set_tcp_keepalive
-from posttroll.message import _MAGICK, Message
+from posttroll import config
+from posttroll.address_receiver import get_configured_address_port
+from posttroll.message import _MAGICK
from posttroll.ns import get_pub_address
LOGGER = logging.getLogger(__name__)
@@ -67,81 +60,31 @@ class Subscriber:
"""
- def __init__(self, addresses, topics='', message_filter=None, translate=False):
+ def __init__(self, addresses, topics="", message_filter=None, translate=False):
"""Initialize the subscriber."""
- self._topics = self._magickfy_topics(topics)
- self._filter = message_filter
- self._translate = translate
-
- self.sub_addr = {}
- self.addr_sub = {}
-
- self._hooks = []
- self._hooks_cb = {}
-
- self.poller = Poller()
- self._lock = Lock()
-
- self.update(addresses)
+ topics = self._magickfy_topics(topics)
+ backend = config.get("backend", "unsecure_zmq")
+ if backend not in ["unsecure_zmq", "secure_zmq"]:
+ raise NotImplementedError(f"No support for backend {backend} implemented (yet?).")
- self._loop = None
+ from posttroll.backends.zmq.subscriber import ZMQSubscriber
+ self._subscriber = ZMQSubscriber(addresses, topics=topics,
+ message_filter=message_filter, translate=translate)
def add(self, address, topics=None):
"""Add *address* to the subscribing list for *topics*.
- It topics is None we will subscibe to already specified topics.
+        If topics is None we will subscribe to the already specified topics.
"""
- with self._lock:
- if address in self.addresses:
- return
-
- topics = self._magickfy_topics(topics) or self._topics
- LOGGER.info("Subscriber adding address %s with topics %s",
- str(address), str(topics))
- subscriber = self._add_sub_socket(address, topics)
- self.sub_addr[subscriber] = address
- self.addr_sub[address] = subscriber
-
- def _add_sub_socket(self, address, topics):
- subscriber = get_context().socket(SUB)
- _set_tcp_keepalive(subscriber)
- for t__ in topics:
- subscriber.setsockopt_string(SUBSCRIBE, str(t__))
- subscriber.connect(address)
-
- if self.poller:
- self.poller.register(subscriber, POLLIN)
- return subscriber
+ return self._subscriber.add(address, self._magickfy_topics(topics))
def remove(self, address):
"""Remove *address* from the subscribing list for *topics*."""
- with self._lock:
- try:
- subscriber = self.addr_sub[address]
- except KeyError:
- return
- LOGGER.info("Subscriber removing address %s", str(address))
- del self.addr_sub[address]
- del self.sub_addr[subscriber]
- self._remove_sub_socket(subscriber)
-
- def _remove_sub_socket(self, subscriber):
- if self.poller:
- self.poller.unregister(subscriber)
- subscriber.close()
+ return self._subscriber.remove(address)
def update(self, addresses):
"""Update with a set of addresses."""
- if isinstance(addresses, str):
- addresses = [addresses, ]
- current_addresses, new_addresses = set(self.addresses), set(addresses)
- addresses_to_remove = current_addresses.difference(new_addresses)
- addresses_to_add = new_addresses.difference(current_addresses)
- for addr in addresses_to_remove:
- self.remove(addr)
- for addr in addresses_to_add:
- self.add(addr)
- return bool(addresses_to_remove or addresses_to_add)
+ return self._subscriber.update(addresses)
def add_hook_sub(self, address, topics, callback):
"""Specify a SUB *callback* in the same stream (thread) as the main receive loop.
@@ -152,11 +95,7 @@ def add_hook_sub(self, address, topics, callback):
Good for operations, which is required to be done in the same thread as
the main recieve loop (e.q operations on the underlying sockets).
"""
- topics = self._magickfy_topics(topics)
- LOGGER.info("Subscriber adding SUB hook %s for topics %s",
- str(address), str(topics))
- socket = self._add_sub_socket(address, topics)
- self._add_hook(socket, callback)
+ return self._subscriber.add_hook_sub(address, self._magickfy_topics(topics), callback)
def add_hook_pull(self, address, callback):
"""Specify a PULL *callback* in the same stream (thread) as the main receive loop.
@@ -164,84 +103,38 @@ def add_hook_pull(self, address, callback):
The callback will be called with the received messages from the
specified subscription. Good for pushed 'inproc' messages from another thread.
"""
- LOGGER.info("Subscriber adding PULL hook %s", str(address))
- socket = get_context().socket(PULL)
- socket.connect(address)
- if self.poller:
- self.poller.register(socket, POLLIN)
- self._add_hook(socket, callback)
-
- def _add_hook(self, socket, callback):
- """Add a generic hook. The passed socket has to be "receive only"."""
- self._hooks.append(socket)
- self._hooks_cb[socket] = callback
+ return self._subscriber.add_hook_pull(address, callback)
@property
def addresses(self):
"""Get the addresses."""
- return self.sub_addr.values()
+ return self._subscriber.addresses
@property
def subscribers(self):
"""Get the subscribers."""
- return self.sub_addr.keys()
+ return self._subscriber.subscribers
def recv(self, timeout=None):
"""Receive, optionally with *timeout* in seconds."""
- if timeout:
- timeout *= 1000.
-
- for sub in list(self.subscribers) + self._hooks:
- self.poller.register(sub, POLLIN)
- self._loop = True
- try:
- while self._loop:
- sleep(0)
- try:
- socks = dict(self.poller.poll(timeout=timeout))
- if socks:
- for sub in self.subscribers:
- if sub in socks and socks[sub] == POLLIN:
- m__ = Message.decode(sub.recv_string(NOBLOCK))
- if not self._filter or self._filter(m__):
- if self._translate:
- url = urlsplit(self.sub_addr[sub])
- host = url[1].split(":")[0]
- m__.sender = (m__.sender.split("@")[0]
- + "@" + host)
- yield m__
-
- for sub in self._hooks:
- if sub in socks and socks[sub] == POLLIN:
- m__ = Message.decode(sub.recv_string(NOBLOCK))
- self._hooks_cb[sub](m__)
- else:
- # timeout
- yield None
- except ZMQError as err:
- if self._loop:
- LOGGER.exception("Receive failed: %s", str(err))
- finally:
- for sub in list(self.subscribers) + self._hooks:
- self.poller.unregister(sub)
+ return self._subscriber.recv(timeout)
def __call__(self, **kwargs):
"""Handle calls with class instance."""
- return self.recv(**kwargs)
+ return self._subscriber(**kwargs)
def stop(self):
"""Stop the subscriber."""
- self._loop = False
+ return self._subscriber.stop()
def close(self):
"""Close the subscriber: stop it and close the local subscribers."""
- self.stop()
- for sub in list(self.subscribers) + self._hooks:
- try:
- sub.setsockopt(LINGER, 1)
- sub.close()
- except ZMQError:
- pass
+ return self._subscriber.close()
+
+ @property
+ def running(self):
+ """Check if suscriber is running."""
+ return self._subscriber.running
@staticmethod
def _magickfy_topics(topics):
@@ -255,21 +148,13 @@ def _magickfy_topics(topics):
ts_ = []
for t__ in topics:
if not t__.startswith(_MAGICK):
- if t__ and t__[0] == '/':
+ if t__ and t__[0] == "/":
t__ = _MAGICK + t__
else:
- t__ = _MAGICK + '/' + t__
+ t__ = _MAGICK + "/" + t__
ts_.append(t__)
return ts_
- def __del__(self):
- """Clean up after the instance is deleted."""
- for sub in list(self.subscribers) + self._hooks:
- try:
- sub.close()
- except Exception: # noqa: E722
- pass
-
class NSSubscriber:
"""Automatically subscribe to *services*.
@@ -296,9 +181,9 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False,
Default is to listen to all available services.
"""
- self._services = _to_array(services)
- self._topics = _to_array(topics)
- self._addresses = _to_array(addresses)
+ self._services = _to_list(services)
+ self._topics = _to_list(topics)
+ self._addresses = _to_list(addresses)
self._timeout = timeout
self._translate = translate
@@ -313,7 +198,7 @@ def _get_addr_loop(service, timeout):
"""Try to get the address of *service* until for *timeout* seconds."""
then = dt.datetime.now() + dt.timedelta(seconds=timeout)
while dt.datetime.now() < then:
- addrs = get_pub_address(service, nameserver=self._nameserver)
+ addrs = get_pub_address(service, self._timeout, nameserver=self._nameserver)
if addrs:
return [addr["URI"] for addr in addrs]
time.sleep(1)
@@ -360,7 +245,7 @@ class Subscribe:
See :class:`NSSubscriber` and :class:`Subscriber` for initialization parameters.
- The subscriber is selected based on the arguments, see :function:`create_subscriber_from_dict_config` for
+ The subscriber is selected based on the arguments, see :func:`create_subscriber_from_dict_config` for
information how the selection is done.
Example::
@@ -379,14 +264,14 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False,
message_filter=None):
"""Initialize the class."""
settings = {
- 'services': services,
- 'topics': topics,
- 'message_filter': message_filter,
- 'translate': translate,
- 'addr_listener': addr_listener,
- 'addresses': addresses,
- 'timeout': timeout,
- 'nameserver': nameserver,
+ "services": services,
+ "topics": topics,
+ "message_filter": message_filter,
+ "translate": translate,
+ "addr_listener": addr_listener,
+ "addresses": addresses,
+ "timeout": timeout,
+ "nameserver": nameserver,
}
self.subscriber = create_subscriber_from_dict_config(settings)
@@ -399,7 +284,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.subscriber.stop()
-def _to_array(obj):
+def _to_list(obj):
"""Convert *obj* to list if not already one."""
if isinstance(obj, str):
return [obj, ]
@@ -417,16 +302,17 @@ def __init__(self, subscriber, services="", nameserver="localhost"):
services = [services, ]
self.services = services
self.subscriber = subscriber
- self.subscriber.add_hook_sub("tcp://" + nameserver + ":16543",
+ address_publish_port = get_configured_address_port()
+ self.subscriber.add_hook_sub("tcp://" + nameserver + ":" + str(address_publish_port),
["pytroll://address", ],
self.handle_msg)
def handle_msg(self, msg):
"""Handle the message *msg*."""
addr_ = msg.data["URI"]
- status = msg.data.get('status', True)
+ status = msg.data.get("status", True)
if status:
- msg_services = msg.data.get('service')
+ msg_services = msg.data.get("service")
for service in self.services:
if not service or service in msg_services:
LOGGER.debug("Adding address %s %s", str(addr_),
@@ -455,28 +341,29 @@ def create_subscriber_from_dict_config(settings):
:class:`~posttroll.subscriber.Subscriber` and :class:`~posttroll.subscriber.NSSubscriber`.
"""
- if settings.get('addresses') and settings.get('nameserver') is False:
+ if settings.get("addresses") and settings.get("nameserver") is False:
return _get_subscriber_instance(settings)
return _get_nssubscriber_instance(settings).start()
def _get_subscriber_instance(settings):
- addresses = settings['addresses']
- topics = settings.get('topics', '')
- message_filter = settings.get('message_filter', None)
- translate = settings.get('translate', False)
+ _ = settings.pop("nameserver", None)
+ _ = settings.pop("port", None)
+ _ = settings.pop("services", None)
+ _ = settings.pop("addr_listener", None),
+ _ = settings.pop("timeout", None)
- return Subscriber(addresses, topics=topics, message_filter=message_filter, translate=translate)
+ return Subscriber(**settings)
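
Correspondingly, a subscriber sketch for the no-nameserver path (address and topic are illustrative): explicit addresses together with nameserver=False produce a plain Subscriber, everything else is routed through an NSSubscriber.

    from posttroll.subscriber import create_subscriber_from_dict_config

    settings = {"addresses": ["tcp://127.0.0.1:40000"],
                "topics": ["/test"],
                "nameserver": False}
    sub = create_subscriber_from_dict_config(settings)  # a plain Subscriber
    for msg in sub.recv(timeout=1):
        if msg is None:  # timeout reached
            break
    sub.close()
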
def _get_nssubscriber_instance(settings):
- services = settings.get('services', '')
- topics = settings.get('topics', _MAGICK)
- addr_listener = settings.get('addr_listener', False)
- addresses = settings.get('addresses', None)
- timeout = settings.get('timeout', 10)
- translate = settings.get('translate', False)
- nameserver = settings.get('nameserver', 'localhost') or 'localhost'
+ services = settings.get("services", "")
+ topics = settings.get("topics", _MAGICK)
+ addr_listener = settings.get("addr_listener", False)
+ addresses = settings.get("addresses", None)
+ timeout = settings.get("timeout", 10)
+ translate = settings.get("translate", False)
+ nameserver = settings.get("nameserver", "localhost") or "localhost"
return NSSubscriber(
services=services,
diff --git a/posttroll/testing.py b/posttroll/testing.py
index d9cbc60..ceb79dc 100644
--- a/posttroll/testing.py
+++ b/posttroll/testing.py
@@ -10,7 +10,7 @@ def patched_subscriber_recv(messages):
def interuptible_recv(self):
"""Yield message until the subscriber is closed."""
for msg in messages:
- if self._loop is False:
+ if self.running is False:
break
yield msg
@@ -20,7 +20,7 @@ def interuptible_recv(self):
@contextmanager
def patched_publisher():
- """Patch the Subscriber object to return given messages."""
+ """Patch the Publisher object to return given messages."""
from unittest import mock
published = []
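
For context, a hedged sketch of how these test helpers are meant to be used (assuming patched_publisher yields the `published` list above; the message content is made up):

    from posttroll.publisher import create_publisher_from_dict_config
    from posttroll.testing import patched_publisher

    with patched_publisher() as published_messages:
        pub = create_publisher_from_dict_config({"port": 40000, "nameservers": False})
        pub.start()
        pub.send("dummy message")
        pub.stop()
    assert "dummy message" in published_messages
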
diff --git a/posttroll/tests/__init__.py b/posttroll/tests/__init__.py
index 88e1689..5e53b2b 100644
--- a/posttroll/tests/__init__.py
+++ b/posttroll/tests/__init__.py
@@ -20,25 +20,4 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-"""Tests package.
-"""
-
-from posttroll.tests import test_bbmcast, test_message, test_pubsub
-import unittest
-
-
-def suite():
- """The global test suite.
- """
- mysuite = unittest.TestSuite()
- # Test the documentation strings
- # Use the unittests also
- mysuite.addTests(test_bbmcast.suite())
- mysuite.addTests(test_message.suite())
- mysuite.addTests(test_pubsub.suite())
-
- return mysuite
-
-
-def load_tests(loader, tests, pattern):
- return suite()
+"""Tests package."""
diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py
index a7bea8d..b61b26c 100644
--- a/posttroll/tests/test_bbmcast.py
+++ b/posttroll/tests/test_bbmcast.py
@@ -20,100 +20,177 @@
# You should have received a copy of the GNU General Public License along with
# pytroll. If not, see .
-import unittest
-import random
-from socket import SOL_SOCKET, SO_BROADCAST, error
-
-from posttroll import bbmcast
-
-
-class TestBB(unittest.TestCase):
-
- """Test class.
- """
-
- def test_mcast_sender(self):
- """Unit test for mcast_sender.
- """
- mcgroup = (str(random.randint(224, 239)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)))
- socket, group = bbmcast.mcast_sender(mcgroup)
- if mcgroup in ("0.0.0.0", "255.255.255.255"):
- self.assertEqual(group, "")
- self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1)
- else:
- self.assertEqual(group, mcgroup)
- self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 0)
-
- socket.close()
+"""Test multicasting and broadcasting."""
- mcgroup = "0.0.0.0"
- socket, group = bbmcast.mcast_sender(mcgroup)
- self.assertEqual(group, "")
- self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1)
- socket.close()
+import os
+import random
+from socket import SO_BROADCAST, SOL_SOCKET, error
+from threading import Thread
- mcgroup = "255.255.255.255"
- socket, group = bbmcast.mcast_sender(mcgroup)
- self.assertEqual(group, "")
- self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1)
- socket.close()
+import pytest
- mcgroup = (str(random.randint(0, 223)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)))
- self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup)
-
- mcgroup = (str(random.randint(240, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)))
- self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup)
-
- def test_mcast_receiver(self):
- """Unit test for mcast_receiver.
- """
- mcport = random.randint(1025, 65535)
- mcgroup = "0.0.0.0"
- socket, group = bbmcast.mcast_receiver(mcport, mcgroup)
- self.assertEqual(group, "")
- socket.close()
+from posttroll import bbmcast
- mcgroup = "255.255.255.255"
- socket, group = bbmcast.mcast_receiver(mcport, mcgroup)
- self.assertEqual(group, "")
- socket.close()
- # Valid multicast range is 224.0.0.0 to 239.255.255.255
- mcgroup = (str(random.randint(224, 239)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)))
- socket, group = bbmcast.mcast_receiver(mcport, mcgroup)
- self.assertEqual(group, mcgroup)
+def test_mcast_sender_works_with_valid_addresses():
+ """Unit test for mcast_sender."""
+ mcgroup = (str(random.randint(224, 239)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)))
+ socket, group = bbmcast.mcast_sender(mcgroup)
+ if mcgroup in ("0.0.0.0", "255.255.255.255"):
+ assert group == ""
+ assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1
+ else:
+ assert group == mcgroup
+ assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 0
+
+ socket.close()
+
+
+def test_mcast_sender_uses_broadcast_for_0s():
+ """Test mcast_sender uses broadcast for 0.0.0.0."""
+ mcgroup = "0.0.0.0"
+ socket, group = bbmcast.mcast_sender(mcgroup)
+ assert group == ""
+ assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1
+ socket.close()
+
+
+def test_mcast_sender_uses_broadcast_for_255s():
+ """Test mcast_sender uses broadcast for 255.255.255.255."""
+ mcgroup = "255.255.255.255"
+ socket, group = bbmcast.mcast_sender(mcgroup)
+ assert group == ""
+ assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1
+ socket.close()
+
+
+def test_mcast_sender_raises_for_invalid_addresses():
+    """Test that mcast_sender raises OSError for non-multicast addresses."""
+ mcgroup = (str(random.randint(0, 223)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)))
+ with pytest.raises(OSError, match="Invalid multicast address .*"):
+ bbmcast.mcast_sender(mcgroup)
+
+ mcgroup = (str(random.randint(240, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)))
+ with pytest.raises(OSError, match="Invalid multicast address .*"):
+ bbmcast.mcast_sender(mcgroup)
+
+
+def test_mcast_receiver_works_with_valid_addresses():
+ """Unit test for mcast_receiver."""
+ mcport = random.randint(1025, 65535)
+ mcgroup = "0.0.0.0"
+ socket, group = bbmcast.mcast_receiver(mcport, mcgroup)
+ assert group == ""
+ socket.close()
+
+ mcgroup = "255.255.255.255"
+ socket, group = bbmcast.mcast_receiver(mcport, mcgroup)
+ assert group == ""
+ socket.close()
+
+ # Valid multicast range is 224.0.0.0 to 239.255.255.255
+ mcgroup = (str(random.randint(224, 239)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)))
+ socket, group = bbmcast.mcast_receiver(mcport, mcgroup)
+ assert group == mcgroup
+ socket.close()
+
+ mcgroup = (str(random.randint(0, 223)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)))
+ with pytest.raises(error, match=".*Invalid argument.*"):
+ bbmcast.mcast_receiver(mcport, mcgroup)
+
+ mcgroup = (str(random.randint(240, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)) + "." +
+ str(random.randint(0, 255)))
+ with pytest.raises(error, match=".*Invalid argument.*"):
+ bbmcast.mcast_receiver(mcport, mcgroup)
+
+
+@pytest.mark.skipif(
+ os.getenv("DISABLED_MULTICAST"),
+ reason="Multicast tests disabled.",
+)
+def test_multicast_roundtrip(reraise):
+ """Test sending and receiving a multicast message."""
+ mcgroup = bbmcast.DEFAULT_MC_GROUP
+ mcport = 5555
+ rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup)
+ rec_socket.settimeout(.1)
+
+ message = "Ho Ho Ho!"
+
+ def check_message(sock, message):
+ with reraise:
+ data, _ = sock.recvfrom(1024)
+ assert data.decode() == message
+
+ snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup)
+
+ thr = Thread(target=check_message, args=(rec_socket, message))
+ thr.start()
+
+ snd_socket.sendto(message.encode(), (mcgroup, mcport))
+
+ thr.join()
+ rec_socket.close()
+ snd_socket.close()
+
+
+def test_broadcast_roundtrip(reraise):
+ """Test sending and receiving a broadcast message."""
+ mcgroup = "0.0.0.0"
+ mcport = 5555
+ rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup)
+
+ message = "Ho Ho Ho!"
+
+ def check_message(sock, message):
+ with reraise:
+ data, _ = sock.recvfrom(1024)
+ assert data.decode() == message
+
+ snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup)
+
+ thr = Thread(target=check_message, args=(rec_socket, message))
+ thr.start()
+
+ snd_socket.sendto(message.encode(), (mcgroup, mcport))
+
+ thr.join()
+ rec_socket.close()
+ snd_socket.close()
+
+
+def test_posttroll_mc_group_is_used():
+ """Test that configured mc_group is used."""
+ from posttroll import config
+ other_group = "226.0.0.13"
+ with config.set(mc_group=other_group):
+ socket, group = bbmcast.mcast_sender()
socket.close()
+ assert group == "226.0.0.13"
- mcgroup = (str(random.randint(0, 223)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)))
- self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup)
-
- mcgroup = (str(random.randint(240, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)) + "." +
- str(random.randint(0, 255)))
- self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup)
-
-
-def suite():
- """The suite for test_bbmcast.
- """
- loader = unittest.TestLoader()
- mysuite = unittest.TestSuite()
- mysuite.addTest(loader.loadTestsFromTestCase(TestBB))
- return mysuite
+def test_pytroll_mc_group_is_deprecated(monkeypatch):
+ """Test that PYTROLL_MC_GROUP is used but pending deprecation."""
+ other_group = "226.0.0.13"
+ monkeypatch.setenv("PYTROLL_MC_GROUP", other_group)
+ with pytest.deprecated_call():
+ socket, group = bbmcast.mcast_sender()
+ socket.close()
+ assert group == "226.0.0.13"
diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py
index 3a4ee6e..af97236 100644
--- a/posttroll/tests/test_message.py
+++ b/posttroll/tests/test_message.py
@@ -23,26 +23,25 @@
"""Test module for the message class."""
+import copy
import os
import sys
import unittest
-import copy
from datetime import datetime
-from posttroll.message import Message, _MAGICK
-
+from posttroll.message import _MAGICK, Message
-HOME = os.path.dirname(__file__) or '.'
-sys.path = [os.path.abspath(HOME + '/../..'), ] + sys.path
+HOME = os.path.dirname(__file__) or "."
+sys.path = [os.path.abspath(HOME + "/../.."), ] + sys.path
-DATADIR = HOME + '/data'
-SOME_METADATA = {'timestamp': datetime(2010, 12, 3, 16, 28, 39),
- 'satellite': 'metop2',
- 'uri': 'file://data/my/path/to/hrpt/files/myfile',
- 'orbit': 1222,
- 'format': 'hrpt',
- 'afloat': 1.2345}
+DATADIR = HOME + "/data"
+SOME_METADATA = {"timestamp": datetime(2010, 12, 3, 16, 28, 39),
+ "satellite": "metop2",
+ "uri": "file://data/my/path/to/hrpt/files/myfile",
+ "orbit": 1222,
+ "format": "hrpt",
+ "afloat": 1.2345}
class Test(unittest.TestCase):
@@ -50,41 +49,32 @@ class Test(unittest.TestCase):
def test_encode_decode(self):
"""Test the encoding/decoding of the message class."""
- msg1 = Message('/test/whatup/doc', 'info', data='not much to say')
+ msg1 = Message("/test/whatup/doc", "info", data="not much to say")
- sender = '%s@%s' % (msg1.user, msg1.host)
- self.assertTrue(sender == msg1.sender,
- msg='Messaging, decoding user, host from sender failed')
+ sender = "%s@%s" % (msg1.user, msg1.host)
+ assert sender == msg1.sender, "Messaging, decoding user, host from sender failed"
msg2 = Message.decode(msg1.encode())
- self.assertTrue(str(msg2) == str(msg1),
- msg='Messaging, encoding, decoding failed')
+ assert str(msg2) == str(msg1), "Messaging, encoding, decoding failed"
def test_decode(self):
"""Test the decoding of a message."""
rawstr = (_MAGICK +
- r'/test/1/2/3 info ras@hawaii 2008-04-11T22:13:22.123000 v1.01' +
+ r"/test/1/2/3 info ras@hawaii 2008-04-11T22:13:22.123000 v1.01" +
r' text/ascii "what' + r"'" + r's up doc"')
msg = Message.decode(rawstr)
- self.assertTrue(str(msg) == rawstr,
- msg='Messaging, decoding of message failed')
+ assert str(msg) == rawstr, "Messaging, decoding of message failed"
def test_encode(self):
"""Test the encoding of a message."""
- subject = '/test/whatup/doc'
+ subject = "/test/whatup/doc"
atype = "info"
- data = 'not much to say'
+ data = "not much to say"
msg1 = Message(subject, atype, data=data)
- sender = '%s@%s' % (msg1.user, msg1.host)
- self.assertEqual(_MAGICK +
- subject + " " +
- atype + " " +
- sender + " " +
- str(msg1.time.isoformat()) + " " +
- msg1.version + " " +
- 'text/ascii' + " " +
- data,
- msg1.encode())
+ sender = "%s@%s" % (msg1.user, msg1.host)
+ full_message = (_MAGICK + subject + " " + atype + " " + sender + " " +
+ str(msg1.time.isoformat()) + " " + msg1.version + " " + "text/ascii" + " " + data)
+ assert full_message == msg1.encode()
def test_unicode(self):
"""Test handling of unicode."""
@@ -92,72 +82,66 @@ def test_unicode(self):
msg = ('pytroll://PPS-monitorplot/3/norrköping/utv/polar/direct_readout/ file '
'safusr.u@lxserv1096.smhi.se 2018-11-16T12:19:29.934025 v1.01 application/json'
' {"start_time": "2018-11-16T12:02:43.700000"}')
- self.assertEqual(msg, str(Message(rawstr=msg)))
+ assert msg == str(Message(rawstr=msg))
except UnicodeDecodeError:
- self.fail('Unexpected unicode decoding error')
+ self.fail("Unexpected unicode decoding error")
try:
msg = (u'pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN 2019-01-07T12:52:19.872171'
r' v1.01 application/json {"station": "norrk\u00f6ping"}')
try:
- self.assertEqual(msg, str(Message(rawstr=msg)).decode('utf-8'))
+ assert msg == str(Message(rawstr=msg)).decode("utf-8")
except AttributeError:
- self.assertEqual(msg, str(Message(rawstr=msg)))
+ assert msg == str(Message(rawstr=msg))
except UnicodeDecodeError:
- self.fail('Unexpected unicode decoding error')
+ self.fail("Unexpected unicode decoding error")
def test_iso(self):
"""Test handling of iso-8859-1."""
msg = ('pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN '
'2019-01-07T12:52:19.872171 v1.01 application/json {"station": "norrköping"}')
try:
- iso_msg = msg.decode('utf-8').encode('iso-8859-1')
+ iso_msg = msg.decode("utf-8").encode("iso-8859-1")
except AttributeError:
- iso_msg = msg.encode('iso-8859-1')
+ iso_msg = msg.encode("iso-8859-1")
try:
Message(rawstr=iso_msg)
except UnicodeDecodeError:
- self.fail('Unexpected iso decoding error')
+ self.fail("Unexpected iso decoding error")
def test_pickle(self):
"""Test pickling."""
import pickle
- msg1 = Message('/test/whatup/doc', 'info', data='not much to say')
+ msg1 = Message("/test/whatup/doc", "info", data="not much to say")
try:
- fp_ = open("pickle.message", 'wb')
+ fp_ = open("pickle.message", "wb")
pickle.dump(msg1, fp_)
fp_.close()
- fp_ = open("pickle.message", 'rb')
+ fp_ = open("pickle.message", "rb")
msg2 = pickle.load(fp_)
fp_.close()
- self.assertTrue(str(msg1) == str(msg2),
- msg='Messaging, pickle failed')
+ assert str(msg1) == str(msg2), "Messaging, pickle failed"
finally:
try:
- os.remove('pickle.message')
+ os.remove("pickle.message")
except OSError:
pass
def test_metadata(self):
"""Test metadata encoding/decoding."""
metadata = copy.copy(SOME_METADATA)
- msg = Message.decode(Message('/sat/polar/smb/level1', 'file',
+ msg = Message.decode(Message("/sat/polar/smb/level1", "file",
data=metadata).encode())
- self.assertTrue(msg.data == metadata,
- msg='Messaging, metadata decoding / encoding failed')
+ assert msg.data == metadata, "Messaging, metadata decoding / encoding failed"
def test_serialization(self):
"""Test json serialization."""
- compare_file = '/message_metadata.dumps'
- try:
- import json
- except ImportError:
- import simplejson as json
- compare_file += '.simplejson'
+ compare_file = "/message_metadata.dumps"
+ import json
metadata = copy.copy(SOME_METADATA)
- metadata['timestamp'] = metadata['timestamp'].isoformat()
+ metadata["timestamp"] = metadata["timestamp"].isoformat()
fp_ = open(DATADIR + compare_file)
dump = fp_.read()
fp_.close()
@@ -165,21 +149,8 @@ def test_serialization(self):
msg = json.loads(dump)
for key, val in msg.items():
- self.assertEqual(val, metadata.get(key))
+ assert val == metadata.get(key)
msg = json.loads(local_dump)
for key, val in msg.items():
- self.assertEqual(val, metadata.get(key))
-
-
-def suite():
- """Create the suite for test_message."""
- loader = unittest.TestLoader()
- mysuite = unittest.TestSuite()
- mysuite.addTest(loader.loadTestsFromTestCase(Test))
-
- return mysuite
-
-
-if __name__ == '__main__':
- unittest.main()
+ assert val == metadata.get(key)
diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py
new file mode 100644
index 0000000..f4fe81d
--- /dev/null
+++ b/posttroll/tests/test_nameserver.py
@@ -0,0 +1,257 @@
+"""Tests for communication involving the nameserver for service discovery."""
+
+import os
+import time
+import unittest
+from contextlib import contextmanager
+from datetime import timedelta
+from threading import Thread
+from unittest import mock
+
+import pytest
+
+from posttroll import config
+from posttroll.message import Message
+from posttroll.ns import NameServer, get_pub_address
+from posttroll.publisher import Publish
+from posttroll.subscriber import Subscribe
+
+
+def free_port():
+ """Get a free port.
+
+ From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0
+
+ Returns a factory that finds the next free port that is available on the OS
+ This is a bit of a hack, it does this by creating a new socket, and calling
+ bind with the 0 port. The operating system will assign a brand new port,
+ which we can find out using getsockname(). Once we have the new port
+ information we close the socket thereby returning it to the free pool.
+ This means it is technically possible for this function to return the same
+ port twice (for example if run in very quick succession), however operating
+ systems return a random port number in the default range (1024 - 65535),
+ and it is highly unlikely for two processes to get the same port number.
+ In other words, it is possible to flake, but incredibly unlikely.
+ """
+ import socket
+
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind(("0.0.0.0", 0))
+ portnum = s.getsockname()[1]
+ s.close()
+
+ return portnum
+
+
+@contextmanager
+def create_nameserver_instance(max_age=3, multicast_enabled=True):
+ """Create a nameserver instance."""
+ config.set(nameserver_port=free_port())
+ config.set(address_publish_port=free_port())
+ ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled)
+ thr = Thread(target=ns.run)
+ thr.start()
+
+ try:
+ yield
+ finally:
+ ns.stop()
+ thr.join()
+
+
+class TestAddressReceiver(unittest.TestCase):
+ """Test the AddressReceiver."""
+
+ @mock.patch("posttroll.address_receiver.Message")
+ @mock.patch("posttroll.address_receiver.Publish")
+ @mock.patch("posttroll.address_receiver.MulticastReceiver")
+ def test_localhost_restriction(self, mcrec, pub, msg):
+ """Test address receiver restricted only to localhost."""
+ mocked_publish_instance = mock.Mock()
+ pub.return_value.__enter__.return_value = mocked_publish_instance
+ mcr_instance = mock.Mock()
+ mcrec.return_value = mcr_instance
+ mcr_instance.return_value = "blabla", ("255.255.255.255", 12)
+
+ from posttroll.address_receiver import AddressReceiver
+ adr = AddressReceiver(restrict_to_localhost=True)
+ adr.start()
+ time.sleep(3)
+ try:
+ msg.decode.assert_not_called()
+ mocked_publish_instance.send.assert_not_called()
+ finally:
+ adr.stop()
+
+
+@pytest.mark.parametrize(
+ "multicast_enabled",
+ [True, False]
+)
+def test_pub_addresses(multicast_enabled):
+ """Test retrieving addresses."""
+ from posttroll.ns import get_pub_addresses
+ from posttroll.publisher import Publish
+
+ if multicast_enabled:
+ if os.getenv("DISABLED_MULTICAST"):
+ pytest.skip("Multicast tests disabled.")
+ nameservers = None
+ else:
+ nameservers = ["localhost"]
+ with config.set(broadcast_port=free_port()):
+ with create_nameserver_instance(multicast_enabled=multicast_enabled):
+ with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1):
+ time.sleep(.3)
+ res = get_pub_addresses(["this_data"], timeout=.5)
+ assert len(res) == 1
+ expected = {u"status": True,
+ u"service": [u"data_provider", u"this_data"],
+ u"name": u"address"}
+ for key, val in expected.items():
+ assert res[0][key] == val
+ assert "receive_time" in res[0]
+ assert "URI" in res[0]
+ res = get_pub_addresses([str("data_provider")])
+ assert len(res) == 1
+ expected = {u"status": True,
+ u"service": [u"data_provider", u"this_data"],
+ u"name": u"address"}
+ for key, val in expected.items():
+ assert res[0][key] == val
+ assert "receive_time" in res[0]
+ assert "URI" in res[0]
+
+
+@pytest.mark.parametrize(
+ "multicast_enabled",
+ [True, False]
+)
+def test_pub_sub_ctx(multicast_enabled):
+ """Test publish and subscribe."""
+ if multicast_enabled:
+ if os.getenv("DISABLED_MULTICAST"):
+ pytest.skip("Multicast tests disabled.")
+ nameservers = None
+ else:
+ nameservers = ["localhost"]
+
+ with config.set(broadcast_port=free_port()):
+ with create_nameserver_instance(multicast_enabled=multicast_enabled):
+ with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub:
+ with Subscribe("this_data", "counter") as sub:
+ tested = False
+ for counter in range(5):
+ message = Message("/counter", "info", str(counter))
+ pub.send(str(message))
+ time.sleep(.1)
+ msg = next(sub.recv(.2))
+ if msg is not None:
+ assert str(msg) == str(message)
+ tested = True
+ assert tested
+
+
+@pytest.mark.parametrize(
+ "multicast_enabled",
+ [True, False]
+)
+def test_pub_sub_add_rm(multicast_enabled):
+ """Test adding and removing publishers."""
+ if multicast_enabled:
+ if os.getenv("DISABLED_MULTICAST"):
+ pytest.skip("Multicast tests disabled.")
+ nameservers = None
+ else:
+ nameservers = ["localhost"]
+
+ max_age = 0.5
+ with config.set(broadcast_port=free_port()):
+ with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled):
+ with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub:
+ assert len(sub.addresses) == 0
+ with Publish("data_provider", 0, ["this_data"], nameservers=nameservers):
+ time.sleep(.1)
+ next(sub.recv(.1))
+ assert len(sub.addresses) == 1
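+ # give the address time to age out (max_age) before checking it was dropped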
+ time.sleep(max_age * 4)
+ for msg in sub.recv(.1):
+ if msg is None:
+ break
+ time.sleep(0.3)
+ assert len(sub.addresses) == 0
+ with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers):
+ time.sleep(.1)
+ next(sub.recv(.1))
+ assert len(sub.addresses) == 0
+
+
+@pytest.mark.skipif(
+ os.getenv("DISABLED_MULTICAST"),
+ reason="Multicast tests disabled.",
+)
+def test_listener_container():
+ """Test listener container."""
+ from posttroll.listener import ListenerContainer
+ from posttroll.message import Message
+ from posttroll.publisher import NoisyPublisher
+
+ with create_nameserver_instance():
+ pub = NoisyPublisher("test", broadcast_interval=0.1)
+ pub.start()
+ sub = ListenerContainer(topics=["/counter"])
+ time.sleep(.1)
+ for counter in range(5):
+ tested = False
+ msg_out = Message("/counter", "info", str(counter))
+ pub.send(str(msg_out))
+
+ msg_in = sub.output_queue.get(True, 1)
+ if msg_in is not None:
+ assert str(msg_in) == str(msg_out)
+ tested = True
+ assert tested
+ pub.stop()
+ sub.stop()
+
+
+@pytest.mark.skipif(
+ os.getenv("DISABLED_MULTICAST"),
+ reason="Multicast tests disabled.",
+)
+def test_noisypublisher_heartbeat():
+ """Test that the heartbeat in the NoisyPublisher works."""
+ from posttroll.publisher import NoisyPublisher
+ from posttroll.subscriber import Subscribe
+
+ min_interval = 10
+
+ try:
+ with config.set(address_publish_port=free_port(), nameserver_port=free_port()):
+ ns_ = NameServer()
+ thr = Thread(target=ns_.run)
+ thr.start()
+
+ pub = NoisyPublisher("test")
+ pub.start()
+ time.sleep(0.2)
+
+ with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub:
+ time.sleep(0.2)
+ pub.heartbeat(min_interval=min_interval)
+ msg = next(sub.recv(1))
+ assert msg.type == "beat"
+ assert msg.data == {"min_interval": min_interval}
+ finally:
+ pub.stop()
+ ns_.stop()
+ thr.join()
+
+
+def test_switch_backend_for_nameserver():
+ """Test switching backend for nameserver."""
+ with config.set(backend="spurious_backend"):
+ with pytest.raises(NotImplementedError):
+ NameServer()
+ with pytest.raises(NotImplementedError):
+ get_pub_address("some_name")
diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py
index a593bb2..c152bf8 100644
--- a/posttroll/tests/test_pubsub.py
+++ b/posttroll/tests/test_pubsub.py
@@ -23,200 +23,49 @@
"""Test the publishing and subscribing facilities."""
-import unittest
-from unittest import mock
-from datetime import timedelta
-from threading import Thread, Lock
import time
+import unittest
from contextlib import contextmanager
+from threading import Lock
+from unittest import mock
-import posttroll
import pytest
from donfig import Config
-test_lock = Lock()
-
-
-class TestNS(unittest.TestCase):
- """Test the nameserver."""
-
- def setUp(self):
- """Set up the testing class."""
- from posttroll.ns import NameServer
- test_lock.acquire()
- self.ns = NameServer(max_age=timedelta(seconds=3))
- self.thr = Thread(target=self.ns.run)
- self.thr.start()
-
- def tearDown(self):
- """Clean up after the tests have run."""
- self.ns.stop()
- self.thr.join()
- time.sleep(2)
- test_lock.release()
-
- def test_pub_addresses(self):
- """Test retrieving addresses."""
- from posttroll.ns import get_pub_addresses
- from posttroll.publisher import Publish
-
- with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1):
- time.sleep(.3)
- res = get_pub_addresses(["this_data"], timeout=.5)
- assert len(res) == 1
- expected = {u'status': True,
- u'service': [u'data_provider', u'this_data'],
- u'name': u'address'}
- for key, val in expected.items():
- assert res[0][key] == val
- assert "receive_time" in res[0]
- assert "URI" in res[0]
- res = get_pub_addresses([str("data_provider")])
- assert len(res) == 1
- expected = {u'status': True,
- u'service': [u'data_provider', u'this_data'],
- u'name': u'address'}
- for key, val in expected.items():
- assert res[0][key] == val
- assert "receive_time" in res[0]
- assert "URI" in res[0]
-
- def test_pub_sub_ctx(self):
- """Test publish and subscribe."""
- from posttroll.message import Message
- from posttroll.publisher import Publish
- from posttroll.subscriber import Subscribe
-
- with Publish("data_provider", 0, ["this_data"]) as pub:
- with Subscribe("this_data", "counter") as sub:
- for counter in range(5):
- message = Message("/counter", "info", str(counter))
- pub.send(str(message))
- time.sleep(1)
- msg = next(sub.recv(2))
- if msg is not None:
- assert str(msg) == str(message)
- tested = True
- sub.close()
- assert tested
+import posttroll
+from posttroll import config
+from posttroll.message import Message
+from posttroll.publisher import Publish, Publisher, create_publisher_from_dict_config
+from posttroll.subscriber import Subscribe, Subscriber
- def test_pub_sub_add_rm(self):
- """Test adding and removing publishers."""
- from posttroll.publisher import Publish
- from posttroll.subscriber import Subscribe
-
- time.sleep(4)
- with Subscribe("this_data", "counter", True) as sub:
- assert len(sub.sub_addr) == 0
- with Publish("data_provider", 0, ["this_data"]):
- time.sleep(4)
- next(sub.recv(2))
- assert len(sub.sub_addr) == 1
- time.sleep(3)
- for msg in sub.recv(2):
- if msg is None:
- break
- time.sleep(3)
- assert len(sub.sub_addr) == 0
- with Publish("data_provider_2", 0, ["another_data"]):
- time.sleep(4)
- next(sub.recv(2))
- assert len(sub.sub_addr) == 0
- sub.close()
-
-
-class TestNSWithoutMulticasting(unittest.TestCase):
- """Test the nameserver."""
+test_lock = Lock()
- def setUp(self):
- """Set up the testing class."""
- from posttroll.ns import NameServer
- test_lock.acquire()
- self.nameservers = ['localhost']
- self.ns = NameServer(max_age=timedelta(seconds=3),
- multicast_enabled=False)
- self.thr = Thread(target=self.ns.run)
- self.thr.start()
- def tearDown(self):
- """Clean up after the tests have run."""
- self.ns.stop()
- self.thr.join()
- time.sleep(2)
- test_lock.release()
+def free_port():
+ """Get a free port.
- def test_pub_addresses(self):
- """Test retrieving addresses."""
- from posttroll.ns import get_pub_addresses
- from posttroll.publisher import Publish
+ From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0
- with Publish("data_provider", 0, ["this_data"],
- nameservers=self.nameservers):
- time.sleep(3)
- res = get_pub_addresses(["this_data"])
- self.assertEqual(len(res), 1)
- expected = {u'status': True,
- u'service': [u'data_provider', u'this_data'],
- u'name': u'address'}
- for key, val in expected.items():
- self.assertEqual(res[0][key], val)
- self.assertTrue("receive_time" in res[0])
- self.assertTrue("URI" in res[0])
- res = get_pub_addresses(["data_provider"])
- self.assertEqual(len(res), 1)
- expected = {u'status': True,
- u'service': [u'data_provider', u'this_data'],
- u'name': u'address'}
- for key, val in expected.items():
- self.assertEqual(res[0][key], val)
- self.assertTrue("receive_time" in res[0])
- self.assertTrue("URI" in res[0])
-
- def test_pub_sub_ctx(self):
- """Test publish and subscribe."""
- from posttroll.message import Message
- from posttroll.publisher import Publish
- from posttroll.subscriber import Subscribe
+ Returns a port number that was free on the OS at the time of the call.
+ This is a bit of a hack: it creates a new socket and binds it to port 0,
+ which makes the operating system assign a brand new port. We read that
+ port back with getsockname() and then close the socket, returning the
+ port to the free pool. It is therefore technically possible for this
+ function to return the same port twice (for example when called in very
+ quick succession), but since the operating system picks from its
+ ephemeral port range, two callers are very unlikely to get the same
+ number. In other words, it is possible to flake, but incredibly unlikely.
+ """
+ import socket
- with Publish("data_provider", 0, ["this_data"],
- nameservers=self.nameservers) as pub:
- with Subscribe("this_data", "counter") as sub:
- for counter in range(5):
- message = Message("/counter", "info", str(counter))
- pub.send(str(message))
- time.sleep(1)
- msg = next(sub.recv(2))
- if msg is not None:
- self.assertEqual(str(msg), str(message))
- tested = True
- sub.close()
- self.assertTrue(tested)
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind(("0.0.0.0", 0))
+ portnum = s.getsockname()[1]
+ s.close()
- def test_pub_sub_add_rm(self):
- """Test adding and removing publishers."""
- from posttroll.publisher import Publish
- from posttroll.subscriber import Subscribe
-
- time.sleep(4)
- with Subscribe("this_data", "counter", True) as sub:
- self.assertEqual(len(sub.sub_addr), 0)
- with Publish("data_provider", 0, ["this_data"],
- nameservers=self.nameservers):
- time.sleep(4)
- next(sub.recv(2))
- self.assertEqual(len(sub.sub_addr), 1)
- time.sleep(3)
- for msg in sub.recv(2):
- if msg is None:
- break
-
- time.sleep(3)
- self.assertEqual(len(sub.sub_addr), 0)
- with Publish("data_provider_2", 0, ["another_data"],
- nameservers=self.nameservers):
- time.sleep(4)
- next(sub.recv(2))
- self.assertEqual(len(sub.sub_addr), 0)
+ return portnum
class TestPubSub(unittest.TestCase):
@@ -233,27 +82,24 @@ def tearDown(self):
def test_pub_address_timeout(self):
"""Test timeout in offline nameserver."""
from posttroll.ns import get_pub_address
- from posttroll.ns import TimeoutError
-
- self.assertRaises(TimeoutError,
- get_pub_address, ["this_data", 0.5])
+ with pytest.raises(TimeoutError):
+ get_pub_address("this_data", 0.05)
def test_pub_suber(self):
"""Test publisher and subscriber."""
- from posttroll.message import Message
- from posttroll.publisher import Publisher
from posttroll.publisher import get_own_ip
- from posttroll.subscriber import Subscriber
-
pub_address = "tcp://" + str(get_own_ip()) + ":0"
pub = Publisher(pub_address).start()
addr = pub_address[:-1] + str(pub.port_number)
- sub = Subscriber([addr], '/counter')
+ sub = Subscriber([addr], topics="/counter")
+ # wait a bit before sending the first message so that the subscriber is ready
+ time.sleep(.002)
+
tested = False
for counter in range(5):
message = Message("/counter", "info", str(counter))
pub.send(str(message))
- time.sleep(1)
+ time.sleep(.05)
msg = next(sub.recv(2))
if msg is not None:
@@ -264,22 +110,21 @@ def test_pub_suber(self):
def test_pub_sub_ctx_no_nameserver(self):
"""Test publish and subscribe."""
- from posttroll.message import Message
- from posttroll.publisher import Publish
- from posttroll.subscriber import Subscribe
-
with Publish("data_provider", 40000, nameservers=False) as pub:
with Subscribe(topics="counter", nameserver=False, addresses=["tcp://127.0.0.1:40000"]) as sub:
+ assert isinstance(sub, Subscriber)
+ # wait a bit before sending the first message so that the subscriber is ready
+ time.sleep(.002)
for counter in range(5):
message = Message("/counter", "info", str(counter))
pub.send(str(message))
- time.sleep(1)
+ time.sleep(.05)
msg = next(sub.recv(2))
if msg is not None:
- self.assertEqual(str(msg), str(message))
+ assert str(msg) == str(message)
tested = True
sub.close()
- self.assertTrue(tested)
+ assert tested
class TestPub(unittest.TestCase):
@@ -293,47 +138,50 @@ def tearDown(self):
"""Clean up after the tests have run."""
test_lock.release()
- def test_pub_unicode(self):
+ def test_pub_supports_unicode(self):
"""Test publishing messages in Unicode."""
from posttroll.message import Message
from posttroll.publisher import Publish
- message = Message("/pџтяöll", "info", 'hej')
- with Publish("a_service", 9000) as pub:
+ message = Message("/pџтяöll", "info", "hej")
+ with Publish("a_service", 0) as pub:
try:
pub.send(message.encode())
except UnicodeDecodeError:
self.fail("Sending raised UnicodeDecodeError unexpectedly!")
- def test_pub_minmax_port(self):
- """Test user defined port range."""
+ def test_pub_minmax_port_from_config(self):
+ """Test config defined port range."""
# Using environment variables to set port range
# Try over a range of ports just in case the single port is reserved
for port in range(40000, 50000):
# Set the port range to config
with posttroll.config.set(pub_min_port=str(port), pub_max_port=str(port + 1)):
- res = _get_port(min_port=None, max_port=None)
+ res = _get_port_from_publish_instance(min_port=None, max_port=None)
if res is False:
# The port wasn't free, try another one
continue
# Port was selected, make sure it's within the "range" of one
- self.assertEqual(res, port)
+ assert res == port
break
+ def test_pub_minmax_port_from_instantiation(self):
+ """Test port range defined at instantiation."""
# Using range of ports defined at instantation time, this
# should override environment variables
for port in range(50000, 60000):
- res = _get_port(min_port=port, max_port=port+1)
+ res = _get_port_from_publish_instance(min_port=port, max_port=port + 1)
if res is False:
# The port wasn't free, try again
continue
# Port was selected, make sure it's within the "range" of one
- self.assertEqual(res, port)
+ assert res == port
break
-def _get_port(min_port=None, max_port=None):
+def _get_port_from_publish_instance(min_port=None, max_port=None):
from zmq.error import ZMQError
+
from posttroll.publisher import Publish
try:
@@ -347,48 +195,6 @@ def _get_port(min_port=None, max_port=None):
return False
-class TestListenerContainer(unittest.TestCase):
- """Testing listener container."""
-
- def setUp(self):
- """Set up the testing class."""
- from posttroll.ns import NameServer
- test_lock.acquire()
- self.ns = NameServer(max_age=timedelta(seconds=3))
- self.thr = Thread(target=self.ns.run)
- self.thr.start()
-
- def tearDown(self):
- """Clean up after the tests have run."""
- self.ns.stop()
- self.thr.join()
- time.sleep(2)
- test_lock.release()
-
- def test_listener_container(self):
- """Test listener container."""
- from posttroll.message import Message
- from posttroll.publisher import NoisyPublisher
- from posttroll.listener import ListenerContainer
-
- pub = NoisyPublisher("test")
- pub.start()
- sub = ListenerContainer(topics=["/counter"])
- time.sleep(2)
- for counter in range(5):
- tested = False
- msg_out = Message("/counter", "info", str(counter))
- pub.send(str(msg_out))
-
- msg_in = sub.output_queue.get(True, 1)
- if msg_in is not None:
- self.assertEqual(str(msg_in), str(msg_out))
- tested = True
- self.assertTrue(tested)
- pub.stop()
- sub.stop()
-
-
class TestListenerContainerNoNameserver(unittest.TestCase):
"""Testing listener container without nameserver."""
@@ -402,9 +208,8 @@ def tearDown(self):
def test_listener_container(self):
"""Test listener container."""
- from posttroll.message import Message
- from posttroll.publisher import Publisher
from posttroll.listener import ListenerContainer
+ from posttroll.message import Message
pub_addr = "tcp://127.0.0.1:55000"
pub = Publisher(pub_addr, name="test")
@@ -419,144 +224,126 @@ def test_listener_container(self):
msg_in = sub.output_queue.get(True, 1)
if msg_in is not None:
- self.assertEqual(str(msg_in), str(msg_out))
+ assert str(msg_in) == str(msg_out)
tested = True
- self.assertTrue(tested)
+ assert tested
pub.stop()
sub.stop()
-class TestAddressReceiver(unittest.TestCase):
- """Test the AddressReceiver."""
+# Tests for create_publisher_from_dict_config
- @mock.patch("posttroll.address_receiver.Message")
- @mock.patch("posttroll.address_receiver.Publish")
- @mock.patch("posttroll.address_receiver.MulticastReceiver")
- def test_localhost_restriction(self, mcrec, pub, msg):
- """Test address receiver restricted only to localhost."""
- mcr_instance = mock.Mock()
- mcrec.return_value = mcr_instance
- mcr_instance.return_value = 'blabla', ('255.255.255.255', 12)
- from posttroll.address_receiver import AddressReceiver
- adr = AddressReceiver(restrict_to_localhost=True)
- adr.start()
- time.sleep(3)
- msg.decode.assert_not_called()
- adr.stop()
+def test_publisher_with_invalid_arguments_crashes():
+ """Test that only valid arguments are passed to Publisher."""
+ settings = {"address": "ipc:///tmp/test.ipc", "nameservers": False, "invalid_arg": "bar"}
+ with pytest.raises(TypeError):
+ _ = create_publisher_from_dict_config(settings)
-class TestPublisherDictConfig(unittest.TestCase):
- """Test configuring publishers with a dictionary."""
+def test_publisher_is_selected():
+ """Test that Publisher is selected as publisher class."""
+ settings = {"port": 12345, "nameservers": False}
- @mock.patch('posttroll.publisher.Publisher')
- def test_publisher_is_selected(self, Publisher):
- """Test that Publisher is selected as publisher class."""
- from posttroll.publisher import create_publisher_from_dict_config
+ pub = create_publisher_from_dict_config(settings)
+ assert isinstance(pub, Publisher)
+ assert pub is not None
- settings = {'port': 12345, 'nameservers': False}
- pub = create_publisher_from_dict_config(settings)
- Publisher.assert_called_once()
- assert pub is not None
+@mock.patch("posttroll.publisher.Publisher")
+def test_publisher_all_arguments(Publisher):
+ """Test that only valid arguments are passed to Publisher."""
+ settings = {"port": 12345, "nameservers": False, "name": "foo",
+ "min_port": 40000, "max_port": 41000}
+ _ = create_publisher_from_dict_config(settings)
+ _check_valid_settings_in_call(settings, Publisher, ignore=["port", "nameservers"])
+ assert Publisher.call_args[0][0].startswith("tcp://*:")
+ assert Publisher.call_args[0][0].endswith(str(settings["port"]))
- @mock.patch('posttroll.publisher.Publisher')
- def test_publisher_all_arguments(self, Publisher):
- """Test that only valid arguments are passed to Publisher."""
- from posttroll.publisher import create_publisher_from_dict_config
- settings = {'port': 12345, 'nameservers': False, 'name': 'foo',
- 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar'}
- _ = create_publisher_from_dict_config(settings)
- _check_valid_settings_in_call(settings, Publisher, ignore=['port', 'nameservers'])
- assert Publisher.call_args[0][0].startswith("tcp://*:")
- assert Publisher.call_args[0][0].endswith(str(settings['port']))
+def test_no_name_raises_keyerror():
+ """Trying to create a NoisyPublisher without a given name will raise KeyError."""
+ with pytest.raises(KeyError):
+ _ = create_publisher_from_dict_config(dict())
- def test_no_name_raises_keyerror(self):
- """Trying to create a NoisyPublisher without a given name will raise KeyError."""
- from posttroll.publisher import create_publisher_from_dict_config
- with self.assertRaises(KeyError):
- _ = create_publisher_from_dict_config(dict())
+def test_noisypublisher_is_selected_only_name():
+ """Test that NoisyPublisher is selected as publisher class."""
+ from posttroll.publisher import NoisyPublisher
- @mock.patch('posttroll.publisher.NoisyPublisher')
- def test_noisypublisher_is_selected_only_name(self, NoisyPublisher):
- """Test that NoisyPublisher is selected as publisher class."""
- from posttroll.publisher import create_publisher_from_dict_config
+ settings = {"name": "publisher_name"}
- settings = {'name': 'publisher_name'}
+ pub = create_publisher_from_dict_config(settings)
+ assert isinstance(pub, NoisyPublisher)
- pub = create_publisher_from_dict_config(settings)
- NoisyPublisher.assert_called_once()
- assert pub is not None
- @mock.patch('posttroll.publisher.NoisyPublisher')
- def test_noisypublisher_is_selected_name_and_port(self, NoisyPublisher):
- """Test that NoisyPublisher is selected as publisher class."""
- from posttroll.publisher import create_publisher_from_dict_config
+def test_noisypublisher_is_selected_name_and_port():
+ """Test that NoisyPublisher is selected as publisher class."""
+ from posttroll.publisher import NoisyPublisher
- settings = {'name': 'publisher_name', 'port': 40000}
+ settings = {"name": "publisher_name", "port": 40000}
- _ = create_publisher_from_dict_config(settings)
- NoisyPublisher.assert_called_once()
+ pub = create_publisher_from_dict_config(settings)
+ assert isinstance(pub, NoisyPublisher)
- @mock.patch('posttroll.publisher.NoisyPublisher')
- def test_noisypublisher_all_arguments(self, NoisyPublisher):
- """Test that only valid arguments are passed to NoisyPublisher."""
- from posttroll.publisher import create_publisher_from_dict_config
- settings = {'port': 12345, 'nameservers': ['foo'], 'name': 'foo',
- 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar',
- 'aliases': ['alias1', 'alias2'], 'broadcast_interval': 42}
- _ = create_publisher_from_dict_config(settings)
- _check_valid_settings_in_call(settings, NoisyPublisher, ignore=['name'])
- assert NoisyPublisher.call_args[0][0] == settings["name"]
+@mock.patch("posttroll.publisher.NoisyPublisher")
+def test_noisypublisher_all_arguments(NoisyPublisher):
+ """Test that only valid arguments are passed to NoisyPublisher."""
+ from posttroll.publisher import create_publisher_from_dict_config
- @mock.patch('posttroll.publisher.Publisher')
- def test_publish_is_not_noisy(self, Publisher):
- """Test that Publisher is selected with the context manager when it should be."""
- from posttroll.publisher import Publish
+ settings = {"port": 12345, "nameservers": ["foo"], "name": "foo",
+ "min_port": 40000, "max_port": 41000, "invalid_arg": "bar",
+ "aliases": ["alias1", "alias2"], "broadcast_interval": 42}
+ _ = create_publisher_from_dict_config(settings)
+ _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"])
+ assert NoisyPublisher.call_args[0][0] == settings["name"]
- with Publish("service_name", port=40000, nameservers=False):
- Publisher.assert_called_once()
- @mock.patch('posttroll.publisher.NoisyPublisher')
- def test_publish_is_noisy_only_name(self, NoisyPublisher):
- """Test that NoisyPublisher is selected with the context manager when only name is given."""
- from posttroll.publisher import Publish
+def test_publish_is_not_noisy():
+ """Test that Publisher is selected with the context manager when it should be."""
+ from posttroll.publisher import Publish
- with Publish("service_name"):
- NoisyPublisher.assert_called_once()
+ with Publish("service_name", port=40000, nameservers=False) as pub:
+ assert isinstance(pub, Publisher)
- @mock.patch('posttroll.publisher.NoisyPublisher')
- def test_publish_is_noisy_with_port(self, NoisyPublisher):
- """Test that NoisyPublisher is selected with the context manager when port is given."""
- from posttroll.publisher import Publish
- with Publish("service_name", port=40000):
- NoisyPublisher.assert_called_once()
+def test_publish_is_noisy_only_name():
+ """Test that NoisyPublisher is selected with the context manager when only name is given."""
+ from posttroll.publisher import NoisyPublisher, Publish
+
+ with Publish("service_name") as pub:
+ assert isinstance(pub, NoisyPublisher)
+
+
+def test_publish_is_noisy_with_port():
+ """Test that NoisyPublisher is selected with the context manager when port is given."""
+ from posttroll.publisher import NoisyPublisher, Publish
+
+ with Publish("service_name", port=40001) as pub:
+ assert isinstance(pub, NoisyPublisher)
- @mock.patch('posttroll.publisher.NoisyPublisher')
- def test_publish_is_noisy_with_nameservers(self, NoisyPublisher):
- """Test that NoisyPublisher is selected with the context manager when nameservers are given."""
- from posttroll.publisher import Publish
- with Publish("service_name", nameservers=['a', 'b']):
- NoisyPublisher.assert_called_once()
+def test_publish_is_noisy_with_nameservers():
+ """Test that NoisyPublisher is selected with the context manager when nameservers are given."""
+ from posttroll.publisher import NoisyPublisher, Publish
+
+ with Publish("service_name", nameservers=["a", "b"]) as pub:
+ assert isinstance(pub, NoisyPublisher)
def _check_valid_settings_in_call(settings, pub_class, ignore=None):
ignore = ignore or []
for key in settings:
- if key == 'invalid_arg':
- assert 'invalid_arg' not in pub_class.call_args[1]
+ if key == "invalid_arg":
+ assert "invalid_arg" not in pub_class.call_args[1]
continue
if key in ignore:
continue
assert pub_class.call_args[1][key] == settings[key]
-@mock.patch('posttroll.subscriber.Subscriber')
-@mock.patch('posttroll.subscriber.NSSubscriber')
+@mock.patch("posttroll.subscriber.Subscriber")
+@mock.patch("posttroll.subscriber.NSSubscriber")
def test_dict_config_minimal(NSSubscriber, Subscriber):
"""Test that without any settings NSSubscriber is created."""
from posttroll.subscriber import create_subscriber_from_dict_config
@@ -567,31 +354,31 @@ def test_dict_config_minimal(NSSubscriber, Subscriber):
Subscriber.assert_not_called()
-@mock.patch('posttroll.subscriber.Subscriber')
-@mock.patch('posttroll.subscriber.NSSubscriber')
+@mock.patch("posttroll.subscriber.Subscriber")
+@mock.patch("posttroll.subscriber.NSSubscriber")
def test_dict_config_nameserver_false(NSSubscriber, Subscriber):
"""Test that NSSubscriber is created with 'localhost' nameserver when no addresses are given."""
from posttroll.subscriber import create_subscriber_from_dict_config
- subscriber = create_subscriber_from_dict_config({'nameserver': False})
+ subscriber = create_subscriber_from_dict_config({"nameserver": False})
NSSubscriber.assert_called_once()
assert subscriber == NSSubscriber().start()
Subscriber.assert_not_called()
-@mock.patch('posttroll.subscriber.Subscriber')
-@mock.patch('posttroll.subscriber.NSSubscriber')
+@mock.patch("posttroll.subscriber.Subscriber")
+@mock.patch("posttroll.subscriber.NSSubscriber")
def test_dict_config_subscriber(NSSubscriber, Subscriber):
"""Test that Subscriber is created when nameserver is False and addresses are given."""
from posttroll.subscriber import create_subscriber_from_dict_config
- subscriber = create_subscriber_from_dict_config({'nameserver': False, 'addresses': ['addr1']})
+ subscriber = create_subscriber_from_dict_config({"nameserver": False, "addresses": ["addr1"]})
assert subscriber == Subscriber.return_value
Subscriber.assert_called_once()
NSSubscriber.assert_not_called()
-@mock.patch('posttroll.subscriber.NSSubscriber.start')
+@mock.patch("posttroll.subscriber.NSSubscriber.start")
def test_dict_config_full_nssubscriber(NSSubscriber_start):
"""Test that all NSSubscriber options are passed."""
from posttroll.subscriber import create_subscriber_from_dict_config
@@ -611,8 +398,7 @@ def test_dict_config_full_nssubscriber(NSSubscriber_start):
NSSubscriber_start.assert_called_once()
-@mock.patch('posttroll.subscriber.Subscriber.update')
-def test_dict_config_full_subscriber(Subscriber_update):
+def test_dict_config_full_subscriber():
"""Test that all Subscriber options are passed."""
from posttroll.subscriber import create_subscriber_from_dict_config
@@ -620,7 +406,7 @@ def test_dict_config_full_subscriber(Subscriber_update):
"services": "val1",
"topics": "val2",
"addr_listener": "val3",
- "addresses": "val4",
+ "addresses": "ipc://bla.ipc",
"timeout": "val5",
"translate": "val6",
"nameserver": False,
@@ -629,14 +415,10 @@ def test_dict_config_full_subscriber(Subscriber_update):
_ = create_subscriber_from_dict_config(settings)
-@pytest.fixture
-def tcp_keepalive_settings(monkeypatch):
+@pytest.fixture()
+def _tcp_keepalive_settings():
"""Set TCP Keepalive settings."""
- monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE", "1")
- monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_CNT", "10")
- monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_IDLE", "1")
- monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_INTVL", "1")
- with reset_config_for_tests():
+ with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1):
yield
@@ -649,97 +431,75 @@ def reset_config_for_tests():
posttroll.config = old_config
-@pytest.fixture
-def tcp_keepalive_no_settings(monkeypatch):
+@pytest.fixture()
+def _tcp_keepalive_no_settings():
"""Set TCP Keepalive settings."""
- monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE", raising=False)
- monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_CNT", raising=False)
- monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_IDLE", raising=False)
- monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_INTVL", raising=False)
- with reset_config_for_tests():
+ with config.set(tcp_keepalive=None, tcp_keepalive_cnt=None, tcp_keepalive_idle=None, tcp_keepalive_intvl=None):
yield
-def test_publisher_tcp_keepalive(tcp_keepalive_settings):
+@pytest.mark.usefixtures("_tcp_keepalive_settings")
+def test_publisher_tcp_keepalive():
"""Test that TCP Keepalive is set for Publisher if the environment variables are present."""
- socket = mock.MagicMock()
- with mock.patch('posttroll.publisher.get_context') as get_context:
- get_context.return_value.socket.return_value = socket
- from posttroll.publisher import Publisher
-
- _ = Publisher("tcp://127.0.0.1:9000").start()
-
- _assert_tcp_keepalive(socket)
+ from posttroll.backends.zmq.publisher import ZMQPublisher
+ pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start()
+ _assert_tcp_keepalive(pub.publish_socket)
+ pub.stop()
-def test_publisher_tcp_keepalive_not_set(tcp_keepalive_no_settings):
+@pytest.mark.usefixtures("_tcp_keepalive_no_settings")
+def test_publisher_tcp_keepalive_not_set():
"""Test that TCP Keepalive is not set on by default."""
- socket = mock.MagicMock()
- with mock.patch('posttroll.publisher.get_context') as get_context:
- get_context.return_value.socket.return_value = socket
- from posttroll.publisher import Publisher
-
- _ = Publisher("tcp://127.0.0.1:9000").start()
- _assert_no_tcp_keepalive(socket)
+ from posttroll.backends.zmq.publisher import ZMQPublisher
+ pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start()
+ _assert_no_tcp_keepalive(pub.publish_socket)
+ pub.stop()
-def test_subscriber_tcp_keepalive(tcp_keepalive_settings):
+@pytest.mark.usefixtures("_tcp_keepalive_settings")
+def test_subscriber_tcp_keepalive():
"""Test that TCP Keepalive is set for Subscriber if the environment variables are present."""
- socket = mock.MagicMock()
- with mock.patch('posttroll.subscriber.get_context') as get_context:
- get_context.return_value.socket.return_value = socket
- from posttroll.subscriber import Subscriber
-
- _ = Subscriber("tcp://127.0.0.1:9000")
+ from posttroll.backends.zmq.subscriber import ZMQSubscriber
+ sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}")
+ assert len(sub.addr_sub.values()) == 1
+ _assert_tcp_keepalive(list(sub.addr_sub.values())[0])
+ sub.stop()
- _assert_tcp_keepalive(socket)
-
-def test_subscriber_tcp_keepalive_not_set(tcp_keepalive_no_settings):
+@pytest.mark.usefixtures("_tcp_keepalive_no_settings")
+def test_subscriber_tcp_keepalive_not_set():
"""Test that TCP Keepalive is not set on by default."""
- socket = mock.MagicMock()
- with mock.patch('posttroll.subscriber.get_context') as get_context:
- get_context.return_value.socket.return_value = socket
- from posttroll.subscriber import Subscriber
-
- _ = Subscriber("tcp://127.0.0.1:9000")
-
- _assert_no_tcp_keepalive(socket)
+ from posttroll.backends.zmq.subscriber import ZMQSubscriber
+ sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}")
+ assert len(sub.addr_sub.values()) == 1
+ _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0])
+ sub.close()
def _assert_tcp_keepalive(socket):
import zmq
- assert mock.call(zmq.TCP_KEEPALIVE, 1) in socket.setsockopt.mock_calls
- assert mock.call(zmq.TCP_KEEPALIVE_CNT, 10) in socket.setsockopt.mock_calls
- assert mock.call(zmq.TCP_KEEPALIVE_IDLE, 1) in socket.setsockopt.mock_calls
- assert mock.call(zmq.TCP_KEEPALIVE_INTVL, 1) in socket.setsockopt.mock_calls
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE) == 1
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == 10
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE_IDLE) == 1
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == 1
def _assert_no_tcp_keepalive(socket):
- assert "TCP_KEEPALIVE" not in str(socket.setsockopt.mock_calls)
-
+ import zmq
-def test_noisypublisher_heartbeat():
- """Test that the heartbeat in the NoisyPublisher works."""
- from posttroll.ns import NameServer
- from posttroll.publisher import NoisyPublisher
- from posttroll.subscriber import Subscribe
-
- ns_ = NameServer()
- thr = Thread(target=ns_.run)
- thr.start()
-
- pub = NoisyPublisher("test")
- pub.start()
- time.sleep(0.2)
-
- with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub:
- time.sleep(0.2)
- pub.heartbeat(min_interval=10)
- msg = next(sub.recv(1))
- assert msg.type == "beat"
- assert msg.data == {'min_interval': 10}
- pub.stop()
- ns_.stop()
- thr.join()
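+ # zmq reports -1 when a TCP keepalive option is left at its OS default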
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE) == -1
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == -1
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE_IDLE) == -1
+ assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1
+
+
+def test_switch_to_unknown_backend():
+ """Test switching to unknown backend."""
+ from posttroll.publisher import Publisher
+ from posttroll.subscriber import Subscriber
+ with config.set(backend="unsecure_and_deprecated"):
+ with pytest.raises(NotImplementedError):
+ Publisher("ipc://bla.ipc")
+ with pytest.raises(NotImplementedError):
+ Subscriber("ipc://bla.ipc")
diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py
new file mode 100644
index 0000000..f912125
--- /dev/null
+++ b/posttroll/tests/test_secure_zmq_backend.py
@@ -0,0 +1,171 @@
+
+"""Test the curve-based zmq backend."""
+
+import os
+import shutil
+import time
+from threading import Thread
+
+import zmq.auth
+
+from posttroll import config
+from posttroll.backends.zmq import generate_keys
+from posttroll.message import Message
+from posttroll.ns import get_pub_address
+from posttroll.publisher import Publisher, create_publisher_from_dict_config
+from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config
+from posttroll.tests.test_nameserver import create_nameserver_instance
+
+
+def create_keys(tmp_path):
+ """Create keys."""
+ base_dir = tmp_path
+ keys_dir = base_dir / "certificates"
+ public_keys_dir = base_dir / "public_keys"
+ secret_keys_dir = base_dir / "private_keys"
+
+ keys_dir.mkdir()
+ public_keys_dir.mkdir()
+ secret_keys_dir.mkdir()
+
+ # create new keys in certificates dir
+ _server_public_file, _server_secret_file = zmq.auth.create_certificates(
+ keys_dir, "server"
+ )
+ _client_public_file, _client_secret_file = zmq.auth.create_certificates(
+ keys_dir, "client"
+ )
+
+ # move public keys to appropriate directory
+ for key_file in os.listdir(keys_dir):
+ if key_file.endswith(".key"):
+ shutil.move(
+ os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, ".")
+ )
+
+ # move secret keys to appropriate directory
+ for key_file in os.listdir(keys_dir):
+ if key_file.endswith(".key_secret"):
+ shutil.move(
+ os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, ".")
+ )
+
+
+def test_ipc_pubsub_with_sec(tmp_path):
+ """Test pub-sub on a secure ipc socket."""
+ server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server")
+ client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client")
+
+ ipc_address = f"ipc://{str(tmp_path)}/bla.ipc"
+
+ with config.set(backend="secure_zmq",
+ client_secret_key_file=client_secret_key_file,
+ clients_public_keys_directory=os.path.dirname(client_public_key_file),
+ server_public_key_file=server_public_key_file,
+ server_secret_key_file=server_secret_key_file):
+ subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202)
+ sub = create_subscriber_from_dict_config(subscriber_settings)
+
+ pub = Publisher(ipc_address)
+
+ pub.start()
+
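+ # publish from a background thread after a short delay so the subscriber is already receiving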
+ def delayed_send(msg):
+ time.sleep(.2)
+ msg = Message(subject="/hi", atype="string", data=msg)
+ pub.send(str(msg))
+ thr = Thread(target=delayed_send, args=["very sensitive message"])
+ thr.start()
+ try:
+ for msg in sub.recv():
+ assert msg.data == "very sensitive message"
+ break
+ finally:
+ sub.stop()
+ thr.join()
+ pub.stop()
+
+
+def test_switch_to_secure_zmq_backend(tmp_path):
+ """Test switching to the secure_zmq backend."""
+ create_keys(tmp_path)
+
+ base_dir = tmp_path
+ public_keys_dir = base_dir / "public_keys"
+ secret_keys_dir = base_dir / "private_keys"
+
+ server_secret_key = secret_keys_dir / "server.key_secret"
+ public_keys_directory = public_keys_dir
+
+ client_secret_key = secret_keys_dir / "client.key_secret"
+ server_public_key = public_keys_dir / "server.key"
+
+ with config.set(backend="secure_zmq",
+ client_secret_key_file=client_secret_key,
+ clients_public_keys_directory=public_keys_directory,
+ server_public_key_file=server_public_key,
+ server_secret_key_file=server_secret_key):
+ Publisher("ipc://bla.ipc")
+ Subscriber("ipc://bla.ipc")
+
+
+def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path):
+ """Test pub-sub on a secure ipc socket."""
+ # create_keys(tmp_path)
+
+ server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server")
+ client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client")
+
+ ipc_address = f"ipc://{str(tmp_path)}/bla.ipc"
+
+ with config.set(backend="secure_zmq",
+ client_secret_key_file=client_secret_key_file,
+ clients_public_keys_directory=os.path.dirname(client_public_key_file),
+ server_public_key_file=server_public_key_file,
+ server_secret_key_file=server_secret_key_file):
+ subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10203)
+ sub = create_subscriber_from_dict_config(subscriber_settings)
+ pub_settings = dict(address=ipc_address,
+ nameservers=False, port=1789)
+ pub = create_publisher_from_dict_config(pub_settings)
+
+ pub.start()
+
+ def delayed_send(msg):
+ time.sleep(.2)
+ msg = Message(subject="/hi", atype="string", data=msg)
+ pub.send(str(msg))
+ thr = Thread(target=delayed_send, args=["very sensitive message"])
+ thr.start()
+ try:
+ for msg in sub.recv():
+ assert msg.data == "very sensitive message"
+ break
+ finally:
+ sub.stop()
+ thr.join()
+ pub.stop()
+
+
+def test_switch_to_secure_backend_for_nameserver(tmp_path):
+ """Test switching backend for nameserver."""
+ server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server")
+ client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client")
+ with config.set(backend="secure_zmq",
+ client_secret_key_file=client_secret_key_file,
+ clients_public_keys_directory=os.path.dirname(client_public_key_file),
+ server_public_key_file=server_public_key_file,
+ server_secret_key_file=server_secret_key_file):
+
+ with create_nameserver_instance():
+ res = get_pub_address("some_name")
+ assert res == ""
+
+
+def test_create_certificates_cli(tmp_path):
+ """Test the certificate creation cli."""
+ name = "server"
+ args = [name, "-d", str(tmp_path)]
+ generate_keys(args)
+ assert (tmp_path / (name + ".key")).exists()
+ assert (tmp_path / (name + ".key_secret")).exists()
diff --git a/posttroll/tests/test_unsecure_zmq_backend.py b/posttroll/tests/test_unsecure_zmq_backend.py
new file mode 100644
index 0000000..1b2b469
--- /dev/null
+++ b/posttroll/tests/test_unsecure_zmq_backend.py
@@ -0,0 +1,67 @@
+"""Tests for the unsecure zmq backend."""
+
+import time
+
+import pytest
+
+from posttroll import config
+from posttroll.publisher import Publisher, create_publisher_from_dict_config
+from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config
+
+
+def test_ipc_pubsub(tmp_path):
+ """Test pub-sub on an ipc socket."""
+ ipc_address = f"ipc://{str(tmp_path)}/bla.ipc"
+
+ with config.set(backend="unsecure_zmq"):
+ subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202)
+ sub = create_subscriber_from_dict_config(subscriber_settings)
+ pub = Publisher(ipc_address)
+ pub.start()
+
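+ # send from a separate thread after a short delay so the subscriber is listening by then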
+ def delayed_send(msg):
+ time.sleep(.2)
+ from posttroll.message import Message
+ msg = Message(subject="/hi", atype="string", data=msg)
+ pub.send(str(msg))
+ pub.stop()
+ from threading import Thread
+ Thread(target=delayed_send, args=["hi"]).start()
+ for msg in sub.recv():
+ assert msg.data == "hi"
+ break
+ sub.stop()
+
+
+def test_switch_to_unsecure_zmq_backend(tmp_path):
+ """Test switching to the secure_zmq backend."""
+ ipc_address = f"ipc://{str(tmp_path)}/bla.ipc"
+
+ with config.set(backend="unsecure_zmq"):
+ Publisher(ipc_address)
+ Subscriber(ipc_address)
+
+
+def test_ipc_sub_crashes_when_passed_key_files(tmp_path):
+ """Test that creating an unsecure zmq subscriber crashes when key files are passed."""
+ ipc_address = f"ipc://{str(tmp_path)}/bla.ipc"
+
+ with config.set(backend="unsecure_zmq"):
+ subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202,
+ client_secret_key_file="my_secret_key",
+ server_public_key_file="server_public_key")
+ with pytest.raises(TypeError):
+ create_subscriber_from_dict_config(subscriber_settings)
+
+
+def test_ipc_pub_crashes_when_passed_key_files(tmp_path):
+ """Test that creating an unsecure zmq publisher crashes when key files are passed."""
+ ipc_address = f"ipc://{str(tmp_path)}/bla.ipc"
+
+ with config.set(backend="unsecure_zmq"):
+ pub_settings = dict(address=ipc_address,
+ server_secret_key="server.key_secret",
+ public_keys_directory="public_keys_dir",
+ nameservers=False, port=1789)
+ with pytest.raises(TypeError):
+ create_publisher_from_dict_config(pub_settings)
diff --git a/posttroll/version.py b/posttroll/version.py
deleted file mode 100644
index aabec7c..0000000
--- a/posttroll/version.py
+++ /dev/null
@@ -1,657 +0,0 @@
-
-# This file helps to compute a version number in source trees obtained from
-# git-archive tarball (such as those provided by githubs download-from-tag
-# feature). Distribution tarballs (built by setup.py sdist) and build
-# directories (produced by setup.py build) will contain a much shorter file
-# that just contains the computed version number.
-
-# This file is released into the public domain. Generated by
-# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)
-
-"""Git implementation of _version.py."""
-
-import errno
-import os
-import re
-import subprocess
-import sys
-from typing import Callable, Dict
-import functools
-
-
-def get_keywords():
- """Get the keywords needed to look up the version information."""
- # these strings will be replaced by git during git-archive.
- # setup.py/versioneer.py will grep for the variable names, so they must
- # each be defined on a line of their own. _version.py will just call
- # get_keywords().
- git_refnames = "$Format:%d$"
- git_full = "$Format:%H$"
- git_date = "$Format:%ci$"
- keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
- return keywords
-
-
-class VersioneerConfig:
- """Container for Versioneer configuration parameters."""
-
-
-def get_config():
- """Create, populate and return the VersioneerConfig() object."""
- # these strings are filled in when 'setup.py versioneer' creates
- # _version.py
- cfg = VersioneerConfig()
- cfg.VCS = "git"
- cfg.style = "pep440"
- cfg.tag_prefix = "v"
- cfg.parentdir_prefix = "None"
- cfg.versionfile_source = "posttroll/version.py"
- cfg.verbose = False
- return cfg
-
-
-class NotThisMethod(Exception):
- """Exception raised if a method is not valid for the current scenario."""
-
-
-LONG_VERSION_PY: Dict[str, str] = {}
-HANDLERS: Dict[str, Dict[str, Callable]] = {}
-
-
-def register_vcs_handler(vcs, method): # decorator
- """Create decorator to mark a method as the handler of a VCS."""
- def decorate(f):
- """Store f in HANDLERS[vcs][method]."""
- if vcs not in HANDLERS:
- HANDLERS[vcs] = {}
- HANDLERS[vcs][method] = f
- return f
- return decorate
-
-
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
- """Call the given command(s)."""
- assert isinstance(commands, list)
- process = None
-
- popen_kwargs = {}
- if sys.platform == "win32":
- # This hides the console window if pythonw.exe is used
- startupinfo = subprocess.STARTUPINFO()
- startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
- popen_kwargs["startupinfo"] = startupinfo
-
- for command in commands:
- try:
- dispcmd = str([command] + args)
- # remember shell=False, so use git.cmd on windows, not just git
- process = subprocess.Popen([command] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None), **popen_kwargs)
- break
- except OSError:
- e = sys.exc_info()[1]
- if e.errno == errno.ENOENT:
- continue
- if verbose:
- print("unable to run %s" % dispcmd)
- print(e)
- return None, None
- else:
- if verbose:
- print("unable to find command, tried %s" % (commands,))
- return None, None
- stdout = process.communicate()[0].strip().decode()
- if process.returncode != 0:
- if verbose:
- print("unable to run %s (error)" % dispcmd)
- print("stdout was %s" % stdout)
- return None, process.returncode
- return stdout, process.returncode
-
-
-def versions_from_parentdir(parentdir_prefix, root, verbose):
- """Try to determine the version from the parent directory name.
-
- Source tarballs conventionally unpack into a directory that includes both
- the project name and a version string. We will also support searching up
- two directory levels for an appropriately named parent directory
- """
- rootdirs = []
-
- for _ in range(3):
- dirname = os.path.basename(root)
- if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
-
- if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
- raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
-
-
-@register_vcs_handler("git", "get_keywords")
-def git_get_keywords(versionfile_abs):
- """Extract version information from the given file."""
- # the code embedded in _version.py can just fetch the value of these
- # keywords. When used from setup.py, we don't want to import _version.py,
- # so we do it with a regexp instead. This function is not used from
- # _version.py.
- keywords = {}
- try:
- with open(versionfile_abs, "r") as fobj:
- for line in fobj:
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- except OSError:
- pass
- return keywords
-
-
-@register_vcs_handler("git", "keywords")
-def git_versions_from_keywords(keywords, tag_prefix, verbose):
- """Get version information from git keywords."""
- if "refnames" not in keywords:
- raise NotThisMethod("Short version file found")
- date = keywords.get("date")
- if date is not None:
- # Use only the last line. Previous lines may contain GPG signature
- # information.
- date = date.splitlines()[-1]
-
- # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
- # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
- # -like" string, which we must then edit to make compliant), because
- # it's been around since git-1.5.3, and it's too difficult to
- # discover which version we're using, or to work around using an
- # older one.
- date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
- refnames = keywords["refnames"].strip()
- if refnames.startswith("$Format"):
- if verbose:
- print("keywords are unexpanded, not using")
- raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = {r.strip() for r in refnames.strip("()").split(",")}
- # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
- # just "foo-1.0". If we see a "tag: " prefix, prefer those.
- TAG = "tag: "
- tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
- if not tags:
- # Either we're using git < 1.8.3, or there really are no tags. We use
- # a heuristic: assume all version tags have a digit. The old git %d
- # expansion behaves like git log --decorate=short and strips out the
- # refs/heads/ and refs/tags/ prefixes that would let us distinguish
- # between branches and tags. By ignoring refnames without digits, we
- # filter out many common branch names like "release" and
- # "stabilization", as well as "HEAD" and "master".
- tags = {r for r in refs if re.search(r'\d', r)}
- if verbose:
- print("discarding '%s', no digits" % ",".join(refs - tags))
- if verbose:
- print("likely tags: %s" % ",".join(sorted(tags)))
- for ref in sorted(tags):
- # sorting will prefer e.g. "2.0" over "2.0rc1"
- if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
- # Filter out refs that exactly match prefix or that don't start
- # with a number once the prefix is stripped (mostly a concern
- # when prefix is '')
- if not re.match(r'\d', r):
- continue
- if verbose:
- print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
- # no suitable tags, so version is "0+unknown", but full hex is still there
- if verbose:
- print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
-
-
-@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
- """Get version from 'git describe' in the root of the source tree.
-
- This only gets called if the git-archive 'subst' keywords were *not*
- expanded, and _version.py hasn't already been rewritten with a short
- version string, meaning we're inside a checked out source tree.
- """
- GITS = ["git"]
- if sys.platform == "win32":
- GITS = ["git.cmd", "git.exe"]
-
- # GIT_DIR can interfere with correct operation of Versioneer.
- # It may be intended to be passed to the Versioneer-versioned project,
- # but that should not change where we get our version from.
- env = os.environ.copy()
- env.pop("GIT_DIR", None)
- runner = functools.partial(runner, env=env)
-
- _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
- if rc != 0:
- if verbose:
- print("Directory %s not under git control" % root)
- raise NotThisMethod("'git rev-parse --git-dir' returned error")
-
- # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
- # if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = runner(GITS, [
- "describe", "--tags", "--dirty", "--always", "--long",
- "--match", f"{tag_prefix}[[:digit:]]*"
- ], cwd=root)
- # --long was added in git-1.5.5
- if describe_out is None:
- raise NotThisMethod("'git describe' failed")
- describe_out = describe_out.strip()
- full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
- if full_out is None:
- raise NotThisMethod("'git rev-parse' failed")
- full_out = full_out.strip()
-
- pieces = {}
- pieces["long"] = full_out
- pieces["short"] = full_out[:7] # maybe improved later
- pieces["error"] = None
-
- branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
- cwd=root)
- # --abbrev-ref was added in git-1.6.3
- if rc != 0 or branch_name is None:
- raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
- branch_name = branch_name.strip()
-
- if branch_name == "HEAD":
- # If we aren't exactly on a branch, pick a branch which represents
- # the current commit. If all else fails, we are on a branchless
- # commit.
- branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
- # --contains was added in git-1.5.4
- if rc != 0 or branches is None:
- raise NotThisMethod("'git branch --contains' returned error")
- branches = branches.split("\n")
-
- # Remove the first line if we're running detached
- if "(" in branches[0]:
- branches.pop(0)
-
- # Strip off the leading "* " from the list of branches.
- branches = [branch[2:] for branch in branches]
- if "master" in branches:
- branch_name = "master"
- elif not branches:
- branch_name = None
- else:
- # Pick the first branch that is returned. Good or bad.
- branch_name = branches[0]
-
- pieces["branch"] = branch_name
-
- # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
- # TAG might have hyphens.
- git_describe = describe_out
-
- # look for -dirty suffix
- dirty = git_describe.endswith("-dirty")
- pieces["dirty"] = dirty
- if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
-
- # now we have TAG-NUM-gHEX or HEX
-
- if "-" in git_describe:
- # TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
- if not mo:
- # unparsable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
- return pieces
-
- # tag
- full_tag = mo.group(1)
- if not full_tag.startswith(tag_prefix):
- if verbose:
- fmt = "tag '%s' doesn't start with prefix '%s'"
- print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
- return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
-
- # distance: number of commits since tag
- pieces["distance"] = int(mo.group(2))
-
- # commit: short hex revision ID
- pieces["short"] = mo.group(3)
-
- else:
- # HEX: no tags
- pieces["closest-tag"] = None
- out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
- pieces["distance"] = len(out.split()) # total number of commits
-
- # commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
- # Use only the last line. Previous lines may contain GPG signature
- # information.
- date = date.splitlines()[-1]
- pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
-
- return pieces
-
-
-def plus_or_dot(pieces):
- """Return a + if we don't already have one, else return a ."""
- if "+" in pieces.get("closest-tag", ""):
- return "."
- return "+"
-
-
-def render_pep440(pieces):
- """Build up version string, with post-release "local version identifier".
-
- Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
- get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
-
- Exceptions:
- 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += plus_or_dot(pieces)
- rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_branch(pieces):
- """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
-
- The ".dev0" means not master branch. Note that .dev0 sorts backwards
- (a feature branch will appear "older" than the master branch).
-
- Exceptions:
- 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0"
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += "+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def pep440_split_post(ver):
- """Split pep440 version string at the post-release segment.
-
- Returns the release segments before the post-release and the
-    post-release version number (or None if no post-release segment is present).
- """
- vc = str.split(ver, ".post")
- return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
-
-
-def render_pep440_pre(pieces):
- """TAG[.postN.devDISTANCE] -- No -dirty.
-
- Exceptions:
- 1: no tags. 0.post0.devDISTANCE
- """
- if pieces["closest-tag"]:
- if pieces["distance"]:
- # update the post release segment
- tag_version, post_version = pep440_split_post(pieces["closest-tag"])
- rendered = tag_version
- if post_version is not None:
- rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
- else:
- rendered += ".post0.dev%d" % (pieces["distance"])
- else:
- # no commits, use the tag as the version
- rendered = pieces["closest-tag"]
- else:
- # exception #1
- rendered = "0.post0.dev%d" % pieces["distance"]
- return rendered
-
-
-def render_pep440_post(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX] .
-
- The ".dev0" means dirty. Note that .dev0 sorts backwards
- (a dirty tree will appear "older" than the corresponding clean one),
- but you shouldn't be releasing software with -dirty anyways.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%s" % pieces["short"]
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += "+g%s" % pieces["short"]
- return rendered
-
-
-def render_pep440_post_branch(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
-
- The ".dev0" means not master branch.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%s" % pieces["short"]
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += "+g%s" % pieces["short"]
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_old(pieces):
- """TAG[.postDISTANCE[.dev0]] .
-
- The ".dev0" means dirty.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- return rendered
-
-
-def render_git_describe(pieces):
- """TAG[-DISTANCE-gHEX][-dirty].
-
- Like 'git describe --tags --dirty --always'.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"]:
- rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render_git_describe_long(pieces):
- """TAG-DISTANCE-gHEX[-dirty].
-
-    Like 'git describe --tags --dirty --always --long'.
- The distance/hash is unconditional.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render(pieces, style):
- """Render the given version pieces into the requested style."""
- if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
-
- if not style or style == "default":
- style = "pep440" # the default
-
- if style == "pep440":
- rendered = render_pep440(pieces)
- elif style == "pep440-branch":
- rendered = render_pep440_branch(pieces)
- elif style == "pep440-pre":
- rendered = render_pep440_pre(pieces)
- elif style == "pep440-post":
- rendered = render_pep440_post(pieces)
- elif style == "pep440-post-branch":
- rendered = render_pep440_post_branch(pieces)
- elif style == "pep440-old":
- rendered = render_pep440_old(pieces)
- elif style == "git-describe":
- rendered = render_git_describe(pieces)
- elif style == "git-describe-long":
- rendered = render_git_describe_long(pieces)
- else:
- raise ValueError("unknown style '%s'" % style)
-
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
-
-
-def get_versions():
- """Get version information or return default if unable to do so."""
- # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
- # __file__, we can work backwards from there to the root. Some
- # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
- # case we can only use expanded keywords.
-
- cfg = get_config()
- verbose = cfg.verbose
-
- try:
- return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
- verbose)
- except NotThisMethod:
- pass
-
- try:
- root = os.path.realpath(__file__)
- # versionfile_source is the relative path from the top of the source
- # tree (where the .git directory might live) to this file. Invert
- # this to find the root from __file__.
- for _ in cfg.versionfile_source.split('/'):
- root = os.path.dirname(root)
- except NameError:
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to find root of source tree",
- "date": None}
-
- try:
- pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
- return render(pieces, cfg.style)
- except NotThisMethod:
- pass
-
- try:
- if cfg.parentdir_prefix:
- return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
- except NotThisMethod:
- pass
-
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to compute version", "date": None}
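
For reviewers comparing the removed machinery with its replacement, here is a minimal, self-contained sketch of what the version-rendering helpers above produce for one and the same commit. The tag, distance and hash values are invented for illustration, and the condensed function mirrors render_pep440 while omitting the plus_or_dot() special case (which only matters when the tag itself already contains a "+"):

    # Condensed restatement of render_pep440 above, for illustration only.
    def pep440(pieces):
        if pieces["closest-tag"]:
            rendered = pieces["closest-tag"]
            if pieces["distance"] or pieces["dirty"]:
                rendered += "+%d.g%s" % (pieces["distance"], pieces["short"])
        else:
            rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
        return rendered

    # Invented example: a dirty checkout four commits past the tag "1.2.3".
    pieces = {"closest-tag": "1.2.3", "distance": 4, "short": "abc1234", "dirty": True}
    print(pep440(pieces))  # -> 1.2.3+4.gabc1234.dirty
    # The same pieces rendered with other styles defined above:
    #   pep440-post:  1.2.3.post4.dev0+gabc1234
    #   git-describe: 1.2.3-4-gabc1234-dirty
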
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..09c9663
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,77 @@
+[project]
+name = "posttroll"
+dynamic = ["version"]
+description = "Messaging system for pytroll"
+authors = [
+ { name = "The Pytroll Team", email = "pytroll@googlegroups.com" }
+]
+dependencies = [
+ "pyzmq",
+ "netifaces-plus",
+ "donfig",
+]
+readme = "README.md"
+requires-python = ">=3.10"
+license = { text = "GPLv3" }
+classifiers = [
+ "Development Status :: 5 - Production/Stable",
+ "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
+ "Programming Language :: Python",
+ "Operating System :: OS Independent",
+ "Intended Audience :: Science/Research",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Communications"
+]
+
+[project.scripts]
+pytroll-logger = "posttroll.logger:run"
+posttroll-generate-keys = "posttroll.backends.zmq:generate_keys"
+
+[project.urls]
+Homepage = "https://github.com/pytroll/posttroll"
+"Bug Tracker" = "https://github.com/pytroll/posttroll/issues"
+Documentation = "https://posttroll.readthedocs.io/"
+"Source Code" = "https://github.com/pytroll/posttroll"
+Organization = "https://pytroll.github.io/"
+Slack = "https://pytroll.slack.com/"
+"Release Notes" = "https://github.com/pytroll/posttroll/blob/main/CHANGELOG.md"
+
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["posttroll"]
+
+[tool.hatch.version]
+source = "vcs"
+
+[tool.hatch.build.hooks.vcs]
+version-file = "posttroll/version.py"
+
+[tool.isort]
+sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
+skip_gitignore = true
+default_section = "THIRDPARTY"
+known_first_party = "posttroll"
+line_length = 120
+
+
+[tool.ruff]
+# See https://docs.astral.sh/ruff/rules/
+# In the future, add "B", "S", "N"
+lint.select = ["A", "D", "E", "W", "F", "I", "PT", "TID", "C90", "Q", "T10", "T20"]
+line-length = 120
+exclude = ["versioneer.py",
+ "posttroll/version.py",
+ "doc"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.mccabe]
+# Unlike Flake8, default to a complexity level of 10.
+max-complexity = 10
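
With the [tool.hatch.version] and [tool.hatch.build.hooks.vcs] settings above, the version is derived from git tags by hatch-vcs at build time and written to posttroll/version.py, replacing the Versioneer machinery removed elsewhere in this change. A rough sketch of how the installed version can be inspected at runtime, assuming the package has been installed (for example with `pip install -e .`); the printed values are illustrative:

    # Query the git-derived version of an installed posttroll. The standard
    # library's importlib.metadata reads the distribution metadata that the
    # build backend recorded, so no posttroll-specific API is needed.
    from importlib.metadata import version

    print(version("posttroll"))  # e.g. "1.11.0" on a tag, "1.11.1.dev3+g1a2b3c4" between tags
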
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 96fbf05..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,14 +0,0 @@
-[bdist_rpm]
-requires=python-daemon pyzmq
-release=1
-
-[versioneer]
-VCS = git
-style = pep440
-versionfile_source = posttroll/version.py
-versionfile_build =
-tag_prefix = v
-#parentdir_prefix = myproject-
-
-[flake8]
-max-line-length = 120
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e80d8d6..0000000
--- a/setup.py
+++ /dev/null
@@ -1,57 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright (c) 2011, 2012, 2014, 2015, 2020.
-
-# Author(s):
-
-# The pytroll team:
-# Martin Raspaud
-
-# This file is part of pytroll.
-
-# This is free software: you can redistribute it and/or modify it under the
-# terms of the GNU General Public License as published by the Free Software
-# Foundation, either version 3 of the License, or (at your option) any later
-# version.
-
-# This program is distributed in the hope that it will be useful, but WITHOUT
-# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
-# details.
-
-# You should have received a copy of the GNU General Public License along with
-# this program. If not, see .
-
-from setuptools import setup
-import versioneer
-
-
-requirements = ['pyzmq', 'netifaces', "donfig"]
-
-
-setup(name="posttroll",
- version=versioneer.get_version(),
- cmdclass=versioneer.get_cmdclass(),
- description='Messaging system for pytroll',
- author='The pytroll team',
- author_email='pytroll@googlegroups.com',
- url="http://github.com/pytroll/posttroll",
- packages=['posttroll'],
- entry_points={
- 'console_scripts': ['pytroll-logger = posttroll.logger:run', ]},
- scripts=['bin/nameserver'],
- zip_safe=False,
- license="GPLv3",
- install_requires=requirements,
- classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
- 'Programming Language :: Python',
- 'Operating System :: OS Independent',
- 'Intended Audience :: Science/Research',
- 'Topic :: Scientific/Engineering',
- 'Topic :: Communications'
- ],
- python_requires='>=3.7',
- test_suite='posttroll.tests.suite',
- )
diff --git a/versioneer.py b/versioneer.py
deleted file mode 100644
index 070e384..0000000
--- a/versioneer.py
+++ /dev/null
@@ -1,2146 +0,0 @@
-
-# Version: 0.23
-
-"""The Versioneer - like a rocketeer, but for versions.
-
-The Versioneer
-==============
-
-* like a rocketeer, but for versions!
-* https://github.com/python-versioneer/python-versioneer
-* Brian Warner
-* License: Public Domain (CC0-1.0)
-* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3
-* [![Latest Version][pypi-image]][pypi-url]
-* [![Build Status][travis-image]][travis-url]
-
-This is a tool for managing a recorded version number in distutils/setuptools-based
-python projects. The goal is to remove the tedious and error-prone "update
-the embedded version string" step from your release process. Making a new
-release should be as easy as recording a new tag in your version-control
-system, and maybe making new tarballs.
-
-
-## Quick Install
-
-* `pip install versioneer` to somewhere in your $PATH
-* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md))
-* run `versioneer install` in your source tree, commit the results
-* Verify version information with `python setup.py version`
-
-## Version Identifiers
-
-Source trees come from a variety of places:
-
-* a version-control system checkout (mostly used by developers)
-* a nightly tarball, produced by build automation
-* a snapshot tarball, produced by a web-based VCS browser, like github's
- "tarball from tag" feature
-* a release tarball, produced by "setup.py sdist", distributed through PyPI
-
-Within each source tree, the version identifier (either a string or a number,
-this tool is format-agnostic) can come from a variety of places:
-
-* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows
- about recent "tags" and an absolute revision-id
-* the name of the directory into which the tarball was unpacked
-* an expanded VCS keyword ($Id$, etc)
-* a `_version.py` created by some earlier build step
-
-For released software, the version identifier is closely related to a VCS
-tag. Some projects use tag names that include more than just the version
-string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool
-needs to strip the tag prefix to extract the version identifier. For
-unreleased software (between tags), the version identifier should provide
-enough information to help developers recreate the same tree, while also
-giving them an idea of roughly how old the tree is (after version 1.2, before
-version 1.3). Many VCS systems can report a description that captures this,
-for example `git describe --tags --dirty --always` reports things like
-"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the
-0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has
-uncommitted changes).
-
-The version identifier is used for multiple purposes:
-
-* to allow the module to self-identify its version: `myproject.__version__`
-* to choose a name and prefix for a 'setup.py sdist' tarball
-
-## Theory of Operation
-
-Versioneer works by adding a special `_version.py` file into your source
-tree, where your `__init__.py` can import it. This `_version.py` knows how to
-dynamically ask the VCS tool for version information at import time.
-
-`_version.py` also contains `$Revision$` markers, and the installation
-process marks `_version.py` to have this marker rewritten with a tag name
-during the `git archive` command. As a result, generated tarballs will
-contain enough information to get the proper version.
-
-To allow `setup.py` to compute a version too, a `versioneer.py` is added to
-the top level of your source tree, next to `setup.py` and the `setup.cfg`
-that configures it. This overrides several distutils/setuptools commands to
-compute the version when invoked, and changes `setup.py build` and `setup.py
-sdist` to replace `_version.py` with a small static file that contains just
-the generated version data.
-
-## Installation
-
-See [INSTALL.md](./INSTALL.md) for detailed installation instructions.
-
-## Version-String Flavors
-
-Code which uses Versioneer can learn about its version string at runtime by
-importing `_version` from your main `__init__.py` file and running the
-`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can
-import the top-level `versioneer.py` and run `get_versions()`.
-
-Both functions return a dictionary with different flavors of version
-information:
-
-* `['version']`: A condensed version string, rendered using the selected
- style. This is the most commonly used value for the project's version
- string. The default "pep440" style yields strings like `0.11`,
- `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section
- below for alternative styles.
-
-* `['full-revisionid']`: detailed revision identifier. For Git, this is the
- full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac".
-
-* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the
- commit date in ISO 8601 format. This will be None if the date is not
- available.
-
-* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that
- this is only accurate if run in a VCS checkout, otherwise it is likely to
- be False or None
-
-* `['error']`: if the version string could not be computed, this will be set
- to a string describing the problem, otherwise it will be None. It may be
- useful to throw an exception in setup.py if this is set, to avoid e.g.
- creating tarballs with a version string of "unknown".
-
-Some variants are more useful than others. Including `full-revisionid` in a
-bug report should allow developers to reconstruct the exact code being tested
-(or indicate the presence of local changes that should be shared with the
-developers). `version` is suitable for display in an "about" box or a CLI
-`--version` output: it can be easily compared against release notes and lists
-of bugs fixed in various releases.
-
-The installer adds the following text to your `__init__.py` to place a basic
-version in `YOURPROJECT.__version__`:
-
- from ._version import get_versions
- __version__ = get_versions()['version']
- del get_versions
-
-## Styles
-
-The setup.cfg `style=` configuration controls how the VCS information is
-rendered into a version string.
-
-The default style, "pep440", produces a PEP440-compliant string, equal to the
-un-prefixed tag name for actual releases, and containing an additional "local
-version" section with more detail for in-between builds. For Git, this is
-TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags
---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the
-tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and
-that this commit is two revisions ("+2") beyond the "0.11" tag. For released
-software (exactly equal to a known tag), the identifier will only contain the
-stripped tag, e.g. "0.11".
-
-Other styles are available. See [details.md](details.md) in the Versioneer
-source tree for descriptions.
-
-## Debugging
-
-Versioneer tries to avoid fatal errors: if something goes wrong, it will tend
-to return a version of "0+unknown". To investigate the problem, run `setup.py
-version`, which will run the version-lookup code in a verbose mode, and will
-display the full contents of `get_versions()` (including the `error` string,
-which may help identify what went wrong).
-
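
For concreteness, the dictionary displayed by that verbose lookup has the shape sketched below; the values simply reuse the example strings quoted earlier in this file and are otherwise illustrative, not real output:

    # Illustrative get_versions() result for a dirty checkout two commits past 0.11.
    example_versions = {
        "version": "0.11+2.g1076c97.dirty",
        "full-revisionid": "1076c978a8d3cfc70f408fe5974aa6c092c949ac",
        "dirty": True,
        "error": None,                        # set to a message string when lookup fails
        "date": "2015-06-29T21:35:09+0200",   # ISO-8601 commit date; invented here
    }
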
-## Known Limitations
-
-Some situations are known to cause problems for Versioneer. This details the
-most significant ones. More can be found on Github
-[issues page](https://github.com/python-versioneer/python-versioneer/issues).
-
-### Subprojects
-
-Versioneer has limited support for source trees in which `setup.py` is not in
-the root directory (e.g. `setup.py` and `.git/` are *not* siblings). There are
-two common reasons why `setup.py` might not be in the root:
-
-* Source trees which contain multiple subprojects, such as
- [Buildbot](https://github.com/buildbot/buildbot), which contains both
- "master" and "slave" subprojects, each with their own `setup.py`,
- `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI
- distributions (and upload multiple independently-installable tarballs).
-* Source trees whose main purpose is to contain a C library, but which also
- provide bindings to Python (and perhaps other languages) in subdirectories.
-
-Versioneer will look for `.git` in parent directories, and most operations
-should get the right version string. However `pip` and `setuptools` have bugs
-and implementation details which frequently cause `pip install .` from a
-subproject directory to fail to find a correct version string (so it usually
-defaults to `0+unknown`).
-
-`pip install --editable .` should work correctly. `setup.py install` might
-work too.
-
-Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in
-some later version.
-
-[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking
-this issue. The discussion in
-[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the
-issue from the Versioneer side in more detail.
-[pip PR#3176](https://github.com/pypa/pip/pull/3176) and
-[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve
-pip to let Versioneer work correctly.
-
-Versioneer-0.16 and earlier only looked for a `.git` directory next to the
-`setup.cfg`, so subprojects were completely unsupported with those releases.
-
-### Editable installs with setuptools <= 18.5
-
-`setup.py develop` and `pip install --editable .` allow you to install a
-project into a virtualenv once, then continue editing the source code (and
-test) without re-installing after every change.
-
-"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a
-convenient way to specify executable scripts that should be installed along
-with the python package.
-
-These both work as expected when using modern setuptools. When using
-setuptools-18.5 or earlier, however, certain operations will cause
-`pkg_resources.DistributionNotFound` errors when running the entrypoint
-script, which must be resolved by re-installing the package. This happens
-when the install happens with one version, then the egg_info data is
-regenerated while a different version is checked out. Many setup.py commands
-cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into
-a different virtualenv), so this can be surprising.
-
-[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes
-this one, but upgrading to a newer version of setuptools should probably
-resolve it.
-
-
-## Updating Versioneer
-
-To upgrade your project to a new release of Versioneer, do the following:
-
-* install the new Versioneer (`pip install -U versioneer` or equivalent)
-* edit `setup.cfg`, if necessary, to include any new configuration settings
- indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details.
-* re-run `versioneer install` in your source tree, to replace
- `SRC/_version.py`
-* commit any changed files
-
-## Future Directions
-
-This tool is designed to make it easily extended to other version-control
-systems: all VCS-specific components are in separate directories like
-src/git/ . The top-level `versioneer.py` script is assembled from these
-components by running make-versioneer.py . In the future, make-versioneer.py
-will take a VCS name as an argument, and will construct a version of
-`versioneer.py` that is specific to the given VCS. It might also take the
-configuration arguments that are currently provided manually during
-installation by editing setup.py . Alternatively, it might go the other
-direction and include code from all supported VCS systems, reducing the
-number of intermediate scripts.
-
-## Similar projects
-
-* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time
- dependency
-* [miniver](https://github.com/jbweston/miniver) - a lightweight reimplementation of
- versioneer
-* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools
- plugin
-
-## License
-
-To make Versioneer easier to embed, all its code is dedicated to the public
-domain. The `_version.py` that it creates is also in the public domain.
-Specifically, both are released under the Creative Commons "Public Domain
-Dedication" license (CC0-1.0), as described in
-https://creativecommons.org/publicdomain/zero/1.0/ .
-
-[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg
-[pypi-url]: https://pypi.python.org/pypi/versioneer/
-[travis-image]:
-https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg
-[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer
-
-"""
-# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring
-# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements
-# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error
-# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with
-# pylint:disable=attribute-defined-outside-init,too-many-arguments
-
-import configparser
-import errno
-import json
-import os
-import re
-import subprocess
-import sys
-from typing import Callable, Dict
-import functools
-
-
-class VersioneerConfig:
- """Container for Versioneer configuration parameters."""
-
-
-def get_root():
- """Get the project root directory.
-
- We require that all commands are run from the project root, i.e. the
- directory that contains setup.py, setup.cfg, and versioneer.py .
- """
- root = os.path.realpath(os.path.abspath(os.getcwd()))
- setup_py = os.path.join(root, "setup.py")
- versioneer_py = os.path.join(root, "versioneer.py")
- if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):
- # allow 'python path/to/setup.py COMMAND'
- root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))
- setup_py = os.path.join(root, "setup.py")
- versioneer_py = os.path.join(root, "versioneer.py")
- if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):
-        err = ("Versioneer was unable to find the project root directory. "
- "Versioneer requires setup.py to be executed from "
- "its immediate directory (like 'python setup.py COMMAND'), "
- "or in a way that lets it use sys.argv[0] to find the root "
- "(like 'python path/to/setup.py COMMAND').")
- raise VersioneerBadRootError(err)
- try:
- # Certain runtime workflows (setup.py install/develop in a setuptools
- # tree) execute all dependencies in a single python process, so
- # "versioneer" may be imported multiple times, and python's shared
- # module-import table will cache the first one. So we can't use
- # os.path.dirname(__file__), as that will find whichever
- # versioneer.py was first imported, even in later projects.
- my_path = os.path.realpath(os.path.abspath(__file__))
- me_dir = os.path.normcase(os.path.splitext(my_path)[0])
- vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])
- if me_dir != vsr_dir:
- print("Warning: build in %s is using versioneer.py from %s"
- % (os.path.dirname(my_path), versioneer_py))
- except NameError:
- pass
- return root
-
-
-def get_config_from_root(root):
- """Read the project setup.cfg file to determine Versioneer config."""
- # This might raise OSError (if setup.cfg is missing), or
- # configparser.NoSectionError (if it lacks a [versioneer] section), or
- # configparser.NoOptionError (if it lacks "VCS="). See the docstring at
- # the top of versioneer.py for instructions on writing your setup.cfg .
- setup_cfg = os.path.join(root, "setup.cfg")
- parser = configparser.ConfigParser()
- with open(setup_cfg, "r") as cfg_file:
- parser.read_file(cfg_file)
- VCS = parser.get("versioneer", "VCS") # mandatory
-
- # Dict-like interface for non-mandatory entries
- section = parser["versioneer"]
-
- cfg = VersioneerConfig()
- cfg.VCS = VCS
- cfg.style = section.get("style", "")
- cfg.versionfile_source = section.get("versionfile_source")
- cfg.versionfile_build = section.get("versionfile_build")
- cfg.tag_prefix = section.get("tag_prefix")
- if cfg.tag_prefix in ("''", '""', None):
- cfg.tag_prefix = ""
- cfg.parentdir_prefix = section.get("parentdir_prefix")
- cfg.verbose = section.get("verbose")
- return cfg
-
-
-class NotThisMethod(Exception):
- """Exception raised if a method is not valid for the current scenario."""
-
-
-# these dictionaries contain VCS-specific tools
-LONG_VERSION_PY: Dict[str, str] = {}
-HANDLERS: Dict[str, Dict[str, Callable]] = {}
-
-
-def register_vcs_handler(vcs, method): # decorator
- """Create decorator to mark a method as the handler of a VCS."""
- def decorate(f):
- """Store f in HANDLERS[vcs][method]."""
- HANDLERS.setdefault(vcs, {})[method] = f
- return f
- return decorate
-
-
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
- """Call the given command(s)."""
- assert isinstance(commands, list)
- process = None
-
- popen_kwargs = {}
- if sys.platform == "win32":
- # This hides the console window if pythonw.exe is used
- startupinfo = subprocess.STARTUPINFO()
- startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
- popen_kwargs["startupinfo"] = startupinfo
-
- for command in commands:
- try:
- dispcmd = str([command] + args)
- # remember shell=False, so use git.cmd on windows, not just git
- process = subprocess.Popen([command] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None), **popen_kwargs)
- break
- except OSError:
- e = sys.exc_info()[1]
- if e.errno == errno.ENOENT:
- continue
- if verbose:
- print("unable to run %s" % dispcmd)
- print(e)
- return None, None
- else:
- if verbose:
- print("unable to find command, tried %s" % (commands,))
- return None, None
- stdout = process.communicate()[0].strip().decode()
- if process.returncode != 0:
- if verbose:
- print("unable to run %s (error)" % dispcmd)
- print("stdout was %s" % stdout)
- return None, process.returncode
- return stdout, process.returncode
-
-
-LONG_VERSION_PY['git'] = r'''
-# This file helps to compute a version number in source trees obtained from
-# git-archive tarball (such as those provided by githubs download-from-tag
-# feature). Distribution tarballs (built by setup.py sdist) and build
-# directories (produced by setup.py build) will contain a much shorter file
-# that just contains the computed version number.
-
-# This file is released into the public domain. Generated by
-# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)
-
-"""Git implementation of _version.py."""
-
-import errno
-import os
-import re
-import subprocess
-import sys
-from typing import Callable, Dict
-import functools
-
-
-def get_keywords():
- """Get the keywords needed to look up the version information."""
- # these strings will be replaced by git during git-archive.
- # setup.py/versioneer.py will grep for the variable names, so they must
- # each be defined on a line of their own. _version.py will just call
- # get_keywords().
- git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s"
- git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s"
- git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s"
- keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
- return keywords
-
-
-class VersioneerConfig:
- """Container for Versioneer configuration parameters."""
-
-
-def get_config():
- """Create, populate and return the VersioneerConfig() object."""
- # these strings are filled in when 'setup.py versioneer' creates
- # _version.py
- cfg = VersioneerConfig()
- cfg.VCS = "git"
- cfg.style = "%(STYLE)s"
- cfg.tag_prefix = "%(TAG_PREFIX)s"
- cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s"
- cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s"
- cfg.verbose = False
- return cfg
-
-
-class NotThisMethod(Exception):
- """Exception raised if a method is not valid for the current scenario."""
-
-
-LONG_VERSION_PY: Dict[str, str] = {}
-HANDLERS: Dict[str, Dict[str, Callable]] = {}
-
-
-def register_vcs_handler(vcs, method): # decorator
- """Create decorator to mark a method as the handler of a VCS."""
- def decorate(f):
- """Store f in HANDLERS[vcs][method]."""
- if vcs not in HANDLERS:
- HANDLERS[vcs] = {}
- HANDLERS[vcs][method] = f
- return f
- return decorate
-
-
-def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
- env=None):
- """Call the given command(s)."""
- assert isinstance(commands, list)
- process = None
-
- popen_kwargs = {}
- if sys.platform == "win32":
- # This hides the console window if pythonw.exe is used
- startupinfo = subprocess.STARTUPINFO()
- startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
- popen_kwargs["startupinfo"] = startupinfo
-
- for command in commands:
- try:
- dispcmd = str([command] + args)
- # remember shell=False, so use git.cmd on windows, not just git
- process = subprocess.Popen([command] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None), **popen_kwargs)
- break
- except OSError:
- e = sys.exc_info()[1]
- if e.errno == errno.ENOENT:
- continue
- if verbose:
- print("unable to run %%s" %% dispcmd)
- print(e)
- return None, None
- else:
- if verbose:
- print("unable to find command, tried %%s" %% (commands,))
- return None, None
- stdout = process.communicate()[0].strip().decode()
- if process.returncode != 0:
- if verbose:
- print("unable to run %%s (error)" %% dispcmd)
- print("stdout was %%s" %% stdout)
- return None, process.returncode
- return stdout, process.returncode
-
-
-def versions_from_parentdir(parentdir_prefix, root, verbose):
- """Try to determine the version from the parent directory name.
-
- Source tarballs conventionally unpack into a directory that includes both
- the project name and a version string. We will also support searching up
- two directory levels for an appropriately named parent directory
- """
- rootdirs = []
-
- for _ in range(3):
- dirname = os.path.basename(root)
- if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
-
- if verbose:
- print("Tried directories %%s but none started with prefix %%s" %%
- (str(rootdirs), parentdir_prefix))
- raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
-
-
-@register_vcs_handler("git", "get_keywords")
-def git_get_keywords(versionfile_abs):
- """Extract version information from the given file."""
- # the code embedded in _version.py can just fetch the value of these
- # keywords. When used from setup.py, we don't want to import _version.py,
- # so we do it with a regexp instead. This function is not used from
- # _version.py.
- keywords = {}
- try:
- with open(versionfile_abs, "r") as fobj:
- for line in fobj:
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- except OSError:
- pass
- return keywords
-
-
-@register_vcs_handler("git", "keywords")
-def git_versions_from_keywords(keywords, tag_prefix, verbose):
- """Get version information from git keywords."""
- if "refnames" not in keywords:
- raise NotThisMethod("Short version file found")
- date = keywords.get("date")
- if date is not None:
- # Use only the last line. Previous lines may contain GPG signature
- # information.
- date = date.splitlines()[-1]
-
- # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant
- # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601
- # -like" string, which we must then edit to make compliant), because
- # it's been around since git-1.5.3, and it's too difficult to
- # discover which version we're using, or to work around using an
- # older one.
- date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
- refnames = keywords["refnames"].strip()
- if refnames.startswith("$Format"):
- if verbose:
- print("keywords are unexpanded, not using")
- raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = {r.strip() for r in refnames.strip("()").split(",")}
- # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
- # just "foo-1.0". If we see a "tag: " prefix, prefer those.
- TAG = "tag: "
- tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
- if not tags:
- # Either we're using git < 1.8.3, or there really are no tags. We use
- # a heuristic: assume all version tags have a digit. The old git %%d
- # expansion behaves like git log --decorate=short and strips out the
- # refs/heads/ and refs/tags/ prefixes that would let us distinguish
- # between branches and tags. By ignoring refnames without digits, we
- # filter out many common branch names like "release" and
- # "stabilization", as well as "HEAD" and "master".
- tags = {r for r in refs if re.search(r'\d', r)}
- if verbose:
- print("discarding '%%s', no digits" %% ",".join(refs - tags))
- if verbose:
- print("likely tags: %%s" %% ",".join(sorted(tags)))
- for ref in sorted(tags):
- # sorting will prefer e.g. "2.0" over "2.0rc1"
- if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
- # Filter out refs that exactly match prefix or that don't start
- # with a number once the prefix is stripped (mostly a concern
- # when prefix is '')
- if not re.match(r'\d', r):
- continue
- if verbose:
- print("picking %%s" %% r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
- # no suitable tags, so version is "0+unknown", but full hex is still there
- if verbose:
- print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
-
-
-@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
- """Get version from 'git describe' in the root of the source tree.
-
- This only gets called if the git-archive 'subst' keywords were *not*
- expanded, and _version.py hasn't already been rewritten with a short
- version string, meaning we're inside a checked out source tree.
- """
- GITS = ["git"]
- if sys.platform == "win32":
- GITS = ["git.cmd", "git.exe"]
-
- # GIT_DIR can interfere with correct operation of Versioneer.
- # It may be intended to be passed to the Versioneer-versioned project,
- # but that should not change where we get our version from.
- env = os.environ.copy()
- env.pop("GIT_DIR", None)
- runner = functools.partial(runner, env=env)
-
- _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
- if rc != 0:
- if verbose:
- print("Directory %%s not under git control" %% root)
- raise NotThisMethod("'git rev-parse --git-dir' returned error")
-
- # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
- # if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = runner(GITS, [
- "describe", "--tags", "--dirty", "--always", "--long",
- "--match", f"{tag_prefix}[[:digit:]]*"
- ], cwd=root)
- # --long was added in git-1.5.5
- if describe_out is None:
- raise NotThisMethod("'git describe' failed")
- describe_out = describe_out.strip()
- full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
- if full_out is None:
- raise NotThisMethod("'git rev-parse' failed")
- full_out = full_out.strip()
-
- pieces = {}
- pieces["long"] = full_out
- pieces["short"] = full_out[:7] # maybe improved later
- pieces["error"] = None
-
- branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
- cwd=root)
- # --abbrev-ref was added in git-1.6.3
- if rc != 0 or branch_name is None:
- raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
- branch_name = branch_name.strip()
-
- if branch_name == "HEAD":
- # If we aren't exactly on a branch, pick a branch which represents
- # the current commit. If all else fails, we are on a branchless
- # commit.
- branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
- # --contains was added in git-1.5.4
- if rc != 0 or branches is None:
- raise NotThisMethod("'git branch --contains' returned error")
- branches = branches.split("\n")
-
- # Remove the first line if we're running detached
- if "(" in branches[0]:
- branches.pop(0)
-
- # Strip off the leading "* " from the list of branches.
- branches = [branch[2:] for branch in branches]
- if "master" in branches:
- branch_name = "master"
- elif not branches:
- branch_name = None
- else:
- # Pick the first branch that is returned. Good or bad.
- branch_name = branches[0]
-
- pieces["branch"] = branch_name
-
- # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
- # TAG might have hyphens.
- git_describe = describe_out
-
- # look for -dirty suffix
- dirty = git_describe.endswith("-dirty")
- pieces["dirty"] = dirty
- if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
-
- # now we have TAG-NUM-gHEX or HEX
-
- if "-" in git_describe:
- # TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
- if not mo:
- # unparsable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%%s'"
- %% describe_out)
- return pieces
-
- # tag
- full_tag = mo.group(1)
- if not full_tag.startswith(tag_prefix):
- if verbose:
- fmt = "tag '%%s' doesn't start with prefix '%%s'"
- print(fmt %% (full_tag, tag_prefix))
- pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'"
- %% (full_tag, tag_prefix))
- return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
-
- # distance: number of commits since tag
- pieces["distance"] = int(mo.group(2))
-
- # commit: short hex revision ID
- pieces["short"] = mo.group(3)
-
- else:
- # HEX: no tags
- pieces["closest-tag"] = None
- out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
- pieces["distance"] = len(out.split()) # total number of commits
-
- # commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip()
- # Use only the last line. Previous lines may contain GPG signature
- # information.
- date = date.splitlines()[-1]
- pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
-
- return pieces
-
-
-def plus_or_dot(pieces):
- """Return a + if we don't already have one, else return a ."""
- if "+" in pieces.get("closest-tag", ""):
- return "."
- return "+"
-
-
-def render_pep440(pieces):
- """Build up version string, with post-release "local version identifier".
-
- Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
- get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
-
- Exceptions:
- 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += plus_or_dot(pieces)
- rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_branch(pieces):
- """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
-
- The ".dev0" means not master branch. Note that .dev0 sorts backwards
- (a feature branch will appear "older" than the master branch).
-
- Exceptions:
- 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0"
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += "+untagged.%%d.g%%s" %% (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def pep440_split_post(ver):
- """Split pep440 version string at the post-release segment.
-
- Returns the release segments before the post-release and the
-    post-release version number (or None if no post-release segment is present).
- """
- vc = str.split(ver, ".post")
- return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
-
-
-def render_pep440_pre(pieces):
- """TAG[.postN.devDISTANCE] -- No -dirty.
-
- Exceptions:
- 1: no tags. 0.post0.devDISTANCE
- """
- if pieces["closest-tag"]:
- if pieces["distance"]:
- # update the post release segment
- tag_version, post_version = pep440_split_post(pieces["closest-tag"])
- rendered = tag_version
- if post_version is not None:
- rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"])
- else:
- rendered += ".post0.dev%%d" %% (pieces["distance"])
- else:
- # no commits, use the tag as the version
- rendered = pieces["closest-tag"]
- else:
- # exception #1
- rendered = "0.post0.dev%%d" %% pieces["distance"]
- return rendered
-
-
-def render_pep440_post(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX] .
-
- The ".dev0" means dirty. Note that .dev0 sorts backwards
- (a dirty tree will appear "older" than the corresponding clean one),
- but you shouldn't be releasing software with -dirty anyways.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%%d" %% pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%%s" %% pieces["short"]
- else:
- # exception #1
- rendered = "0.post%%d" %% pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += "+g%%s" %% pieces["short"]
- return rendered
-
-
-def render_pep440_post_branch(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
-
- The ".dev0" means not master branch.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%%d" %% pieces["distance"]
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%%s" %% pieces["short"]
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0.post%%d" %% pieces["distance"]
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += "+g%%s" %% pieces["short"]
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_old(pieces):
- """TAG[.postDISTANCE[.dev0]] .
-
- The ".dev0" means dirty.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%%d" %% pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- else:
- # exception #1
- rendered = "0.post%%d" %% pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- return rendered
-
-
-def render_git_describe(pieces):
- """TAG[-DISTANCE-gHEX][-dirty].
-
- Like 'git describe --tags --dirty --always'.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"]:
- rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render_git_describe_long(pieces):
- """TAG-DISTANCE-gHEX[-dirty].
-
-    Like 'git describe --tags --dirty --always --long'.
- The distance/hash is unconditional.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render(pieces, style):
- """Render the given version pieces into the requested style."""
- if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
-
- if not style or style == "default":
- style = "pep440" # the default
-
- if style == "pep440":
- rendered = render_pep440(pieces)
- elif style == "pep440-branch":
- rendered = render_pep440_branch(pieces)
- elif style == "pep440-pre":
- rendered = render_pep440_pre(pieces)
- elif style == "pep440-post":
- rendered = render_pep440_post(pieces)
- elif style == "pep440-post-branch":
- rendered = render_pep440_post_branch(pieces)
- elif style == "pep440-old":
- rendered = render_pep440_old(pieces)
- elif style == "git-describe":
- rendered = render_git_describe(pieces)
- elif style == "git-describe-long":
- rendered = render_git_describe_long(pieces)
- else:
- raise ValueError("unknown style '%%s'" %% style)
-
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
-
-
-def get_versions():
- """Get version information or return default if unable to do so."""
- # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
- # __file__, we can work backwards from there to the root. Some
- # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
- # case we can only use expanded keywords.
-
- cfg = get_config()
- verbose = cfg.verbose
-
- try:
- return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
- verbose)
- except NotThisMethod:
- pass
-
- try:
- root = os.path.realpath(__file__)
- # versionfile_source is the relative path from the top of the source
- # tree (where the .git directory might live) to this file. Invert
- # this to find the root from __file__.
- for _ in cfg.versionfile_source.split('/'):
- root = os.path.dirname(root)
- except NameError:
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to find root of source tree",
- "date": None}
-
- try:
- pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
- return render(pieces, cfg.style)
- except NotThisMethod:
- pass
-
- try:
- if cfg.parentdir_prefix:
- return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
- except NotThisMethod:
- pass
-
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None,
- "error": "unable to compute version", "date": None}
-'''
-
-
-@register_vcs_handler("git", "get_keywords")
-def git_get_keywords(versionfile_abs):
- """Extract version information from the given file."""
- # the code embedded in _version.py can just fetch the value of these
- # keywords. When used from setup.py, we don't want to import _version.py,
- # so we do it with a regexp instead. This function is not used from
- # _version.py.
- keywords = {}
- try:
- with open(versionfile_abs, "r") as fobj:
- for line in fobj:
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- except OSError:
- pass
- return keywords
-
-
-@register_vcs_handler("git", "keywords")
-def git_versions_from_keywords(keywords, tag_prefix, verbose):
- """Get version information from git keywords."""
- if "refnames" not in keywords:
- raise NotThisMethod("Short version file found")
- date = keywords.get("date")
- if date is not None:
- # Use only the last line. Previous lines may contain GPG signature
- # information.
- date = date.splitlines()[-1]
-
- # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
- # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
- # -like" string, which we must then edit to make compliant), because
- # it's been around since git-1.5.3, and it's too difficult to
- # discover which version we're using, or to work around using an
- # older one.
- date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
- refnames = keywords["refnames"].strip()
- if refnames.startswith("$Format"):
- if verbose:
- print("keywords are unexpanded, not using")
- raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = {r.strip() for r in refnames.strip("()").split(",")}
- # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
- # just "foo-1.0". If we see a "tag: " prefix, prefer those.
- TAG = "tag: "
- tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
- if not tags:
- # Either we're using git < 1.8.3, or there really are no tags. We use
- # a heuristic: assume all version tags have a digit. The old git %d
- # expansion behaves like git log --decorate=short and strips out the
- # refs/heads/ and refs/tags/ prefixes that would let us distinguish
- # between branches and tags. By ignoring refnames without digits, we
- # filter out many common branch names like "release" and
- # "stabilization", as well as "HEAD" and "master".
- tags = {r for r in refs if re.search(r'\d', r)}
- if verbose:
- print("discarding '%s', no digits" % ",".join(refs - tags))
- if verbose:
- print("likely tags: %s" % ",".join(sorted(tags)))
- for ref in sorted(tags):
- # sorting will prefer e.g. "2.0" over "2.0rc1"
- if ref.startswith(tag_prefix):
- r = ref[len(tag_prefix):]
- # Filter out refs that exactly match prefix or that don't start
- # with a number once the prefix is stripped (mostly a concern
- # when prefix is '')
- if not re.match(r'\d', r):
- continue
- if verbose:
- print("picking %s" % r)
- return {"version": r,
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": None,
- "date": date}
- # no suitable tags, so version is "0+unknown", but full hex is still there
- if verbose:
- print("no suitable tags, using unknown + full revision id")
- return {"version": "0+unknown",
- "full-revisionid": keywords["full"].strip(),
- "dirty": False, "error": "no suitable tags", "date": None}
-
-
-@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
- """Get version from 'git describe' in the root of the source tree.
-
- This only gets called if the git-archive 'subst' keywords were *not*
- expanded, and _version.py hasn't already been rewritten with a short
- version string, meaning we're inside a checked out source tree.
- """
- GITS = ["git"]
- if sys.platform == "win32":
- GITS = ["git.cmd", "git.exe"]
-
- # GIT_DIR can interfere with correct operation of Versioneer.
- # It may be intended to be passed to the Versioneer-versioned project,
- # but that should not change where we get our version from.
- env = os.environ.copy()
- env.pop("GIT_DIR", None)
- runner = functools.partial(runner, env=env)
-
- _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
- if rc != 0:
- if verbose:
- print("Directory %s not under git control" % root)
- raise NotThisMethod("'git rev-parse --git-dir' returned error")
-
- # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
- # if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = runner(GITS, [
- "describe", "--tags", "--dirty", "--always", "--long",
- "--match", f"{tag_prefix}[[:digit:]]*"
- ], cwd=root)
- # --long was added in git-1.5.5
- if describe_out is None:
- raise NotThisMethod("'git describe' failed")
- describe_out = describe_out.strip()
- full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
- if full_out is None:
- raise NotThisMethod("'git rev-parse' failed")
- full_out = full_out.strip()
-
- pieces = {}
- pieces["long"] = full_out
- pieces["short"] = full_out[:7] # maybe improved later
- pieces["error"] = None
-
- branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
- cwd=root)
- # --abbrev-ref was added in git-1.6.3
- if rc != 0 or branch_name is None:
- raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
- branch_name = branch_name.strip()
-
- if branch_name == "HEAD":
- # If we aren't exactly on a branch, pick a branch which represents
- # the current commit. If all else fails, we are on a branchless
- # commit.
- branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
- # --contains was added in git-1.5.4
- if rc != 0 or branches is None:
- raise NotThisMethod("'git branch --contains' returned error")
- branches = branches.split("\n")
-
- # Remove the first line if we're running detached
- if "(" in branches[0]:
- branches.pop(0)
-
- # Strip off the leading "* " from the list of branches.
- branches = [branch[2:] for branch in branches]
- if "master" in branches:
- branch_name = "master"
- elif not branches:
- branch_name = None
- else:
- # Pick the first branch that is returned. Good or bad.
- branch_name = branches[0]
-
- pieces["branch"] = branch_name
-
- # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
- # TAG might have hyphens.
- git_describe = describe_out
-
- # look for -dirty suffix
- dirty = git_describe.endswith("-dirty")
- pieces["dirty"] = dirty
- if dirty:
- git_describe = git_describe[:git_describe.rindex("-dirty")]
-
- # now we have TAG-NUM-gHEX or HEX
-
- if "-" in git_describe:
- # TAG-NUM-gHEX
- mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
- if not mo:
- # unparsable. Maybe git-describe is misbehaving?
- pieces["error"] = ("unable to parse git-describe output: '%s'"
- % describe_out)
- return pieces
-
- # tag
- full_tag = mo.group(1)
- if not full_tag.startswith(tag_prefix):
- if verbose:
- fmt = "tag '%s' doesn't start with prefix '%s'"
- print(fmt % (full_tag, tag_prefix))
- pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
- % (full_tag, tag_prefix))
- return pieces
- pieces["closest-tag"] = full_tag[len(tag_prefix):]
-
- # distance: number of commits since tag
- pieces["distance"] = int(mo.group(2))
-
- # commit: short hex revision ID
- pieces["short"] = mo.group(3)
-
- else:
- # HEX: no tags
- pieces["closest-tag"] = None
- out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
- pieces["distance"] = len(out.split()) # total number of commits
-
- # commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
- # Use only the last line. Previous lines may contain GPG signature
- # information.
- date = date.splitlines()[-1]
- pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
-
- return pieces
-
-
-def do_vcs_install(versionfile_source, ipy):
- """Git-specific installation logic for Versioneer.
-
- For Git, this means creating/changing .gitattributes to mark _version.py
- for export-subst keyword substitution.
- """
- GITS = ["git"]
- if sys.platform == "win32":
- GITS = ["git.cmd", "git.exe"]
- files = [versionfile_source]
- if ipy:
- files.append(ipy)
- try:
- my_path = __file__
- if my_path.endswith(".pyc") or my_path.endswith(".pyo"):
- my_path = os.path.splitext(my_path)[0] + ".py"
- versioneer_file = os.path.relpath(my_path)
- except NameError:
- versioneer_file = "versioneer.py"
- files.append(versioneer_file)
- present = False
- try:
- with open(".gitattributes", "r") as fobj:
- for line in fobj:
- if line.strip().startswith(versionfile_source):
- if "export-subst" in line.strip().split()[1:]:
- present = True
- break
- except OSError:
- pass
- if not present:
- with open(".gitattributes", "a+") as fobj:
- fobj.write(f"{versionfile_source} export-subst\n")
- files.append(".gitattributes")
- run_command(GITS, ["add", "--"] + files)
-
-
-def versions_from_parentdir(parentdir_prefix, root, verbose):
- """Try to determine the version from the parent directory name.
-
- Source tarballs conventionally unpack into a directory that includes both
- the project name and a version string. We will also support searching up
- two directory levels for an appropriately named parent directory
- """
- rootdirs = []
-
- for _ in range(3):
- dirname = os.path.basename(root)
- if dirname.startswith(parentdir_prefix):
- return {"version": dirname[len(parentdir_prefix):],
- "full-revisionid": None,
- "dirty": False, "error": None, "date": None}
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
-
- if verbose:
- print("Tried directories %s but none started with prefix %s" %
- (str(rootdirs), parentdir_prefix))
- raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
-
-
-SHORT_VERSION_PY = """
-# This file was generated by 'versioneer.py' (0.23) from
-# revision-control system data, or from the parent directory name of an
-# unpacked source archive. Distribution tarballs contain a pre-generated copy
-# of this file.
-
-import json
-
-version_json = '''
-%s
-''' # END VERSION_JSON
-
-
-def get_versions():
- return json.loads(version_json)
-"""
-
-
-def versions_from_file(filename):
- """Try to determine the version from _version.py if present."""
- try:
- with open(filename) as f:
- contents = f.read()
- except OSError:
- raise NotThisMethod("unable to read _version.py")
- mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON",
- contents, re.M | re.S)
- if not mo:
- mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON",
- contents, re.M | re.S)
- if not mo:
- raise NotThisMethod("no version_json in _version.py")
- return json.loads(mo.group(1))
-
-
-def write_to_version_file(filename, versions):
- """Write the given version number to the given _version.py file."""
- os.unlink(filename)
- contents = json.dumps(versions, sort_keys=True,
- indent=1, separators=(",", ": "))
- with open(filename, "w") as f:
- f.write(SHORT_VERSION_PY % contents)
-
- print("set %s to '%s'" % (filename, versions["version"]))
-
-
-def plus_or_dot(pieces):
- """Return a + if we don't already have one, else return a ."""
- if "+" in pieces.get("closest-tag", ""):
- return "."
- return "+"
-
-
-def render_pep440(pieces):
- """Build up version string, with post-release "local version identifier".
-
- Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
- get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
-
- Exceptions:
- 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += plus_or_dot(pieces)
- rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_branch(pieces):
- """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
-
- The ".dev0" means not master branch. Note that .dev0 sorts backwards
- (a feature branch will appear "older" than the master branch).
-
- Exceptions:
- 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0"
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += "+untagged.%d.g%s" % (pieces["distance"],
- pieces["short"])
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def pep440_split_post(ver):
- """Split pep440 version string at the post-release segment.
-
- Returns the release segments before the post-release and the
- post-release version number (or -1 if no post-release segment is present).
- """
- vc = str.split(ver, ".post")
- return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
-
-
-def render_pep440_pre(pieces):
- """TAG[.postN.devDISTANCE] -- No -dirty.
-
- Exceptions:
- 1: no tags. 0.post0.devDISTANCE
- """
- if pieces["closest-tag"]:
- if pieces["distance"]:
- # update the post release segment
- tag_version, post_version = pep440_split_post(pieces["closest-tag"])
- rendered = tag_version
- if post_version is not None:
- rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
- else:
- rendered += ".post0.dev%d" % (pieces["distance"])
- else:
- # no commits, use the tag as the version
- rendered = pieces["closest-tag"]
- else:
- # exception #1
- rendered = "0.post0.dev%d" % pieces["distance"]
- return rendered
-
-
-def render_pep440_post(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX] .
-
- The ".dev0" means dirty. Note that .dev0 sorts backwards
- (a dirty tree will appear "older" than the corresponding clean one),
- but you shouldn't be releasing software with -dirty anyways.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%s" % pieces["short"]
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- rendered += "+g%s" % pieces["short"]
- return rendered
-
-
-def render_pep440_post_branch(pieces):
- """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
-
- The ".dev0" means not master branch.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += plus_or_dot(pieces)
- rendered += "g%s" % pieces["short"]
- if pieces["dirty"]:
- rendered += ".dirty"
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["branch"] != "master":
- rendered += ".dev0"
- rendered += "+g%s" % pieces["short"]
- if pieces["dirty"]:
- rendered += ".dirty"
- return rendered
-
-
-def render_pep440_old(pieces):
- """TAG[.postDISTANCE[.dev0]] .
-
- The ".dev0" means dirty.
-
- Exceptions:
- 1: no tags. 0.postDISTANCE[.dev0]
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"] or pieces["dirty"]:
- rendered += ".post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- else:
- # exception #1
- rendered = "0.post%d" % pieces["distance"]
- if pieces["dirty"]:
- rendered += ".dev0"
- return rendered
-
-
-def render_git_describe(pieces):
- """TAG[-DISTANCE-gHEX][-dirty].
-
- Like 'git describe --tags --dirty --always'.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- if pieces["distance"]:
- rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render_git_describe_long(pieces):
- """TAG-DISTANCE-gHEX[-dirty].
-
-    Like 'git describe --tags --dirty --always --long'.
- The distance/hash is unconditional.
-
- Exceptions:
- 1: no tags. HEX[-dirty] (note: no 'g' prefix)
- """
- if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
- rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
- else:
- # exception #1
- rendered = pieces["short"]
- if pieces["dirty"]:
- rendered += "-dirty"
- return rendered
-
-
-def render(pieces, style):
- """Render the given version pieces into the requested style."""
- if pieces["error"]:
- return {"version": "unknown",
- "full-revisionid": pieces.get("long"),
- "dirty": None,
- "error": pieces["error"],
- "date": None}
-
- if not style or style == "default":
- style = "pep440" # the default
-
- if style == "pep440":
- rendered = render_pep440(pieces)
- elif style == "pep440-branch":
- rendered = render_pep440_branch(pieces)
- elif style == "pep440-pre":
- rendered = render_pep440_pre(pieces)
- elif style == "pep440-post":
- rendered = render_pep440_post(pieces)
- elif style == "pep440-post-branch":
- rendered = render_pep440_post_branch(pieces)
- elif style == "pep440-old":
- rendered = render_pep440_old(pieces)
- elif style == "git-describe":
- rendered = render_git_describe(pieces)
- elif style == "git-describe-long":
- rendered = render_git_describe_long(pieces)
- else:
- raise ValueError("unknown style '%s'" % style)
-
- return {"version": rendered, "full-revisionid": pieces["long"],
- "dirty": pieces["dirty"], "error": None,
- "date": pieces.get("date")}
-
-
-class VersioneerBadRootError(Exception):
- """The project root directory is unknown or missing key files."""
-
-
-def get_versions(verbose=False):
- """Get the project version from whatever source is available.
-
-    Returns dict with keys 'version', 'full-revisionid', 'dirty', 'error', and 'date'.
- """
- if "versioneer" in sys.modules:
- # see the discussion in cmdclass.py:get_cmdclass()
- del sys.modules["versioneer"]
-
- root = get_root()
- cfg = get_config_from_root(root)
-
- assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg"
- handlers = HANDLERS.get(cfg.VCS)
- assert handlers, "unrecognized VCS '%s'" % cfg.VCS
- verbose = verbose or cfg.verbose
- assert cfg.versionfile_source is not None, \
- "please set versioneer.versionfile_source"
- assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix"
-
- versionfile_abs = os.path.join(root, cfg.versionfile_source)
-
- # extract version from first of: _version.py, VCS command (e.g. 'git
- # describe'), parentdir. This is meant to work for developers using a
- # source checkout, for users of a tarball created by 'setup.py sdist',
- # and for users of a tarball/zipball created by 'git archive' or github's
- # download-from-tag feature or the equivalent in other VCSes.
-
- get_keywords_f = handlers.get("get_keywords")
- from_keywords_f = handlers.get("keywords")
- if get_keywords_f and from_keywords_f:
- try:
- keywords = get_keywords_f(versionfile_abs)
- ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)
- if verbose:
- print("got version from expanded keyword %s" % ver)
- return ver
- except NotThisMethod:
- pass
-
- try:
- ver = versions_from_file(versionfile_abs)
- if verbose:
- print("got version from file %s %s" % (versionfile_abs, ver))
- return ver
- except NotThisMethod:
- pass
-
- from_vcs_f = handlers.get("pieces_from_vcs")
- if from_vcs_f:
- try:
- pieces = from_vcs_f(cfg.tag_prefix, root, verbose)
- ver = render(pieces, cfg.style)
- if verbose:
- print("got version from VCS %s" % ver)
- return ver
- except NotThisMethod:
- pass
-
- try:
- if cfg.parentdir_prefix:
- ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
- if verbose:
- print("got version from parentdir %s" % ver)
- return ver
- except NotThisMethod:
- pass
-
- if verbose:
- print("unable to compute version")
-
- return {"version": "0+unknown", "full-revisionid": None,
- "dirty": None, "error": "unable to compute version",
- "date": None}
-
-
-def get_version():
- """Get the short version string for this project."""
- return get_versions()["version"]
-
-
-def get_cmdclass(cmdclass=None):
- """Get the custom setuptools subclasses used by Versioneer.
-
- If the package uses a different cmdclass (e.g. one from numpy), it
-    should be provided as an argument.
- """
- if "versioneer" in sys.modules:
- del sys.modules["versioneer"]
- # this fixes the "python setup.py develop" case (also 'install' and
- # 'easy_install .'), in which subdependencies of the main project are
- # built (using setup.py bdist_egg) in the same python process. Assume
- # a main project A and a dependency B, which use different versions
- # of Versioneer. A's setup.py imports A's Versioneer, leaving it in
- # sys.modules by the time B's setup.py is executed, causing B to run
- # with the wrong versioneer. Setuptools wraps the sub-dep builds in a
-    # sandbox that restores sys.modules to its pre-build state, so the
- # parent is protected against the child's "import versioneer". By
- # removing ourselves from sys.modules here, before the child build
- # happens, we protect the child from the parent's versioneer too.
- # Also see https://github.com/python-versioneer/python-versioneer/issues/52
-
- cmds = {} if cmdclass is None else cmdclass.copy()
-
- # we add "version" to setuptools
- from setuptools import Command
-
- class cmd_version(Command):
- description = "report generated version string"
- user_options = []
- boolean_options = []
-
- def initialize_options(self):
- pass
-
- def finalize_options(self):
- pass
-
- def run(self):
- vers = get_versions(verbose=True)
- print("Version: %s" % vers["version"])
- print(" full-revisionid: %s" % vers.get("full-revisionid"))
- print(" dirty: %s" % vers.get("dirty"))
- print(" date: %s" % vers.get("date"))
- if vers["error"]:
- print(" error: %s" % vers["error"])
- cmds["version"] = cmd_version
-
- # we override "build_py" in setuptools
- #
- # most invocation pathways end up running build_py:
- # distutils/build -> build_py
- # distutils/install -> distutils/build ->..
- # setuptools/bdist_wheel -> distutils/install ->..
- # setuptools/bdist_egg -> distutils/install_lib -> build_py
- # setuptools/install -> bdist_egg ->..
- # setuptools/develop -> ?
- # pip install:
- # copies source tree to a tempdir before running egg_info/etc
- # if .git isn't copied too, 'git describe' will fail
- # then does setup.py bdist_wheel, or sometimes setup.py install
- # setup.py egg_info -> ?
-
- # pip install -e . and setuptool/editable_wheel will invoke build_py
- # but the build_py command is not expected to copy any files.
-
- # we override different "build_py" commands for both environments
- if 'build_py' in cmds:
- _build_py = cmds['build_py']
- else:
- from setuptools.command.build_py import build_py as _build_py
-
- class cmd_build_py(_build_py):
- def run(self):
- root = get_root()
- cfg = get_config_from_root(root)
- versions = get_versions()
- _build_py.run(self)
- if getattr(self, "editable_mode", False):
- # During editable installs `.py` and data files are
- # not copied to build_lib
- return
- # now locate _version.py in the new build/ directory and replace
- # it with an updated value
- if cfg.versionfile_build:
- target_versionfile = os.path.join(self.build_lib,
- cfg.versionfile_build)
- print("UPDATING %s" % target_versionfile)
- write_to_version_file(target_versionfile, versions)
- cmds["build_py"] = cmd_build_py
-
- if 'build_ext' in cmds:
- _build_ext = cmds['build_ext']
- else:
- from setuptools.command.build_ext import build_ext as _build_ext
-
- class cmd_build_ext(_build_ext):
- def run(self):
- root = get_root()
- cfg = get_config_from_root(root)
- versions = get_versions()
- _build_ext.run(self)
- if self.inplace:
- # build_ext --inplace will only build extensions in
- # build/lib<..> dir with no _version.py to write to.
- # As in place builds will already have a _version.py
- # in the module dir, we do not need to write one.
- return
- # now locate _version.py in the new build/ directory and replace
- # it with an updated value
- target_versionfile = os.path.join(self.build_lib,
- cfg.versionfile_build)
- if not os.path.exists(target_versionfile):
- print(f"Warning: {target_versionfile} does not exist, skipping "
- "version update. This can happen if you are running build_ext "
- "without first running build_py.")
- return
- print("UPDATING %s" % target_versionfile)
- write_to_version_file(target_versionfile, versions)
- cmds["build_ext"] = cmd_build_ext
-
- if "cx_Freeze" in sys.modules: # cx_freeze enabled?
- from cx_Freeze.dist import build_exe as _build_exe
- # nczeczulin reports that py2exe won't like the pep440-style string
- # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.
- # setup(console=[{
- # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION
- # "product_version": versioneer.get_version(),
- # ...
-
- class cmd_build_exe(_build_exe):
- def run(self):
- root = get_root()
- cfg = get_config_from_root(root)
- versions = get_versions()
- target_versionfile = cfg.versionfile_source
- print("UPDATING %s" % target_versionfile)
- write_to_version_file(target_versionfile, versions)
-
- _build_exe.run(self)
- os.unlink(target_versionfile)
- with open(cfg.versionfile_source, "w") as f:
- LONG = LONG_VERSION_PY[cfg.VCS]
- f.write(LONG %
- {"DOLLAR": "$",
- "STYLE": cfg.style,
- "TAG_PREFIX": cfg.tag_prefix,
- "PARENTDIR_PREFIX": cfg.parentdir_prefix,
- "VERSIONFILE_SOURCE": cfg.versionfile_source,
- })
- cmds["build_exe"] = cmd_build_exe
- del cmds["build_py"]
-
- if 'py2exe' in sys.modules: # py2exe enabled?
- from py2exe.distutils_buildexe import py2exe as _py2exe
-
- class cmd_py2exe(_py2exe):
- def run(self):
- root = get_root()
- cfg = get_config_from_root(root)
- versions = get_versions()
- target_versionfile = cfg.versionfile_source
- print("UPDATING %s" % target_versionfile)
- write_to_version_file(target_versionfile, versions)
-
- _py2exe.run(self)
- os.unlink(target_versionfile)
- with open(cfg.versionfile_source, "w") as f:
- LONG = LONG_VERSION_PY[cfg.VCS]
- f.write(LONG %
- {"DOLLAR": "$",
- "STYLE": cfg.style,
- "TAG_PREFIX": cfg.tag_prefix,
- "PARENTDIR_PREFIX": cfg.parentdir_prefix,
- "VERSIONFILE_SOURCE": cfg.versionfile_source,
- })
- cmds["py2exe"] = cmd_py2exe
-
- # sdist farms its file list building out to egg_info
- if 'egg_info' in cmds:
-        _egg_info = cmds['egg_info']
- else:
- from setuptools.command.egg_info import egg_info as _egg_info
-
- class cmd_egg_info(_egg_info):
- def find_sources(self):
- # egg_info.find_sources builds the manifest list and writes it
- # in one shot
- super().find_sources()
-
- # Modify the filelist and normalize it
- root = get_root()
- cfg = get_config_from_root(root)
- self.filelist.append('versioneer.py')
- if cfg.versionfile_source:
- # There are rare cases where versionfile_source might not be
- # included by default, so we must be explicit
- self.filelist.append(cfg.versionfile_source)
- self.filelist.sort()
- self.filelist.remove_duplicates()
-
- # The write method is hidden in the manifest_maker instance that
- # generated the filelist and was thrown away
- # We will instead replicate their final normalization (to unicode,
- # and POSIX-style paths)
- from setuptools import unicode_utils
- normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/')
- for f in self.filelist.files]
-
- manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt')
- with open(manifest_filename, 'w') as fobj:
- fobj.write('\n'.join(normalized))
-
- cmds['egg_info'] = cmd_egg_info
-
- # we override different "sdist" commands for both environments
- if 'sdist' in cmds:
- _sdist = cmds['sdist']
- else:
- from setuptools.command.sdist import sdist as _sdist
-
- class cmd_sdist(_sdist):
- def run(self):
- versions = get_versions()
- self._versioneer_generated_versions = versions
- # unless we update this, the command will keep using the old
- # version
- self.distribution.metadata.version = versions["version"]
- return _sdist.run(self)
-
- def make_release_tree(self, base_dir, files):
- root = get_root()
- cfg = get_config_from_root(root)
- _sdist.make_release_tree(self, base_dir, files)
- # now locate _version.py in the new base_dir directory
- # (remembering that it may be a hardlink) and replace it with an
- # updated value
- target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
- print("UPDATING %s" % target_versionfile)
- write_to_version_file(target_versionfile,
- self._versioneer_generated_versions)
- cmds["sdist"] = cmd_sdist
-
- return cmds
-
-
-CONFIG_ERROR = """
-setup.cfg is missing the necessary Versioneer configuration. You need
-a section like:
-
- [versioneer]
- VCS = git
- style = pep440
- versionfile_source = src/myproject/_version.py
- versionfile_build = myproject/_version.py
- tag_prefix =
- parentdir_prefix = myproject-
-
-You will also need to edit your setup.py to use the results:
-
- import versioneer
- setup(version=versioneer.get_version(),
- cmdclass=versioneer.get_cmdclass(), ...)
-
-Please read the docstring in ./versioneer.py for configuration instructions,
-edit setup.cfg, and re-run the installer or 'python versioneer.py setup'.
-"""
-
-SAMPLE_CONFIG = """
-# See the docstring in versioneer.py for instructions. Note that you must
-# re-run 'versioneer.py setup' after changing this section, and commit the
-# resulting files.
-
-[versioneer]
-#VCS = git
-#style = pep440
-#versionfile_source =
-#versionfile_build =
-#tag_prefix =
-#parentdir_prefix =
-
-"""
-
-OLD_SNIPPET = """
-from ._version import get_versions
-__version__ = get_versions()['version']
-del get_versions
-"""
-
-INIT_PY_SNIPPET = """
-from . import {0}
-__version__ = {0}.get_versions()['version']
-"""
-
-
-def do_setup():
- """Do main VCS-independent setup function for installing Versioneer."""
- root = get_root()
- try:
- cfg = get_config_from_root(root)
- except (OSError, configparser.NoSectionError,
- configparser.NoOptionError) as e:
- if isinstance(e, (OSError, configparser.NoSectionError)):
- print("Adding sample versioneer config to setup.cfg",
- file=sys.stderr)
- with open(os.path.join(root, "setup.cfg"), "a") as f:
- f.write(SAMPLE_CONFIG)
- print(CONFIG_ERROR, file=sys.stderr)
- return 1
-
- print(" creating %s" % cfg.versionfile_source)
- with open(cfg.versionfile_source, "w") as f:
- LONG = LONG_VERSION_PY[cfg.VCS]
- f.write(LONG % {"DOLLAR": "$",
- "STYLE": cfg.style,
- "TAG_PREFIX": cfg.tag_prefix,
- "PARENTDIR_PREFIX": cfg.parentdir_prefix,
- "VERSIONFILE_SOURCE": cfg.versionfile_source,
- })
-
- ipy = os.path.join(os.path.dirname(cfg.versionfile_source),
- "__init__.py")
- if os.path.exists(ipy):
- try:
- with open(ipy, "r") as f:
- old = f.read()
- except OSError:
- old = ""
- module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]
- snippet = INIT_PY_SNIPPET.format(module)
- if OLD_SNIPPET in old:
- print(" replacing boilerplate in %s" % ipy)
- with open(ipy, "w") as f:
- f.write(old.replace(OLD_SNIPPET, snippet))
- elif snippet not in old:
- print(" appending to %s" % ipy)
- with open(ipy, "a") as f:
- f.write(snippet)
- else:
- print(" %s unmodified" % ipy)
- else:
- print(" %s doesn't exist, ok" % ipy)
- ipy = None
-
- # Make VCS-specific changes. For git, this means creating/changing
- # .gitattributes to mark _version.py for export-subst keyword
- # substitution.
- do_vcs_install(cfg.versionfile_source, ipy)
- return 0
-
-
-def scan_setup_py():
- """Validate the contents of setup.py against Versioneer's expectations."""
- found = set()
- setters = False
- errors = 0
- with open("setup.py", "r") as f:
- for line in f.readlines():
- if "import versioneer" in line:
- found.add("import")
- if "versioneer.get_cmdclass()" in line:
- found.add("cmdclass")
- if "versioneer.get_version()" in line:
- found.add("get_version")
- if "versioneer.VCS" in line:
- setters = True
- if "versioneer.versionfile_source" in line:
- setters = True
- if len(found) != 3:
- print("")
- print("Your setup.py appears to be missing some important items")
- print("(but I might be wrong). Please make sure it has something")
- print("roughly like the following:")
- print("")
- print(" import versioneer")
- print(" setup( version=versioneer.get_version(),")
- print(" cmdclass=versioneer.get_cmdclass(), ...)")
- print("")
- errors += 1
- if setters:
- print("You should remove lines like 'versioneer.VCS = ' and")
- print("'versioneer.versionfile_source = ' . This configuration")
- print("now lives in setup.cfg, and should be removed from setup.py")
- print("")
- errors += 1
- return errors
-
-
-if __name__ == "__main__":
- cmd = sys.argv[1]
- if cmd == "setup":
- errors = do_setup()
- errors += scan_setup_py()
- if errors:
- sys.exit(1)