diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0b794a5..924a954 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install -U pytest pytest-cov pyzmq netifaces donfig + pip install -U pytest pytest-cov pyzmq netifaces-plus donfig pytest-reraise - name: Install posttroll run: | pip install --no-deps -e . diff --git a/bin/nameserver b/bin/nameserver index d2d3370..8aa89e6 100755 --- a/bin/nameserver +++ b/bin/nameserver @@ -20,19 +20,17 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""The nameserver. Port 5555 (hardcoded) is used for communications. -""" +"""The nameserver. Port 5555 (hardcoded) is used for communications.""" # TODO: make port configurable. -from posttroll.ns import NameServer - import logging -import _strptime + +from posttroll.ns import NameServer logger = logging.getLogger(__name__) -if __name__ == '__main__': +if __name__ == "__main__": import argparse @@ -58,14 +56,14 @@ if __name__ == '__main__': handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("[%(levelname)s: %(asctime)s :" " %(name)s] %(message)s", - '%Y-%m-%d %H:%M:%S')) + "%Y-%m-%d %H:%M:%S")) if opts.verbose: loglevel = logging.DEBUG else: loglevel = logging.INFO handler.setLevel(loglevel) - logging.getLogger('').setLevel(loglevel) - logging.getLogger('').addHandler(handler) + logging.getLogger("").setLevel(loglevel) + logging.getLogger("").addHandler(handler) logger = logging.getLogger("nameserver") multicast_enabled = (opts.no_multicast == False) diff --git a/doc/Makefile b/doc/Makefile index bd4a009..d0c3cbf 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -1,130 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source BUILDDIR = build -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest - +# Put it first so that "make" without argument is like "make help". help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - -clean: - -rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PostTroll.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PostTroll.qhc" - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/PostTroll" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PostTroll" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - make -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." +.PHONY: help Makefile -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/doc/requirements.txt b/doc/requirements.txt index 9c558e3..91f8e5d 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1 +1,2 @@ +sphinx-rtd-theme . diff --git a/doc/source/conf.py b/doc/source/conf.py index 5593311..08ac172 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -1,247 +1,30 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# PostTroll documentation build configuration file, created by -# sphinx-quickstart on Tue Sep 11 12:58:14 2012. -# -# This file is execfile()d with the current directory set to its containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import sys -import os -from posttroll import __version__ -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../../')) -sys.path.insert(0, os.path.abspath('../../posttroll')) - - - -class Mock(object): - """A mocking class - """ - def __init__(self, *args, **kwargs): - pass - - def __call__(self, *args, **kwargs): - return Mock() - - @classmethod - def __getattr__(cls, name): - if name in ('__file__', '__path__'): - return '/dev/null' - elif name[0] == name[0].upper(): - mock_type = type(name, (), {}) - mock_type.__module__ = __name__ - return mock_type - else: - return Mock() - -MOCK_MODULES = ['zmq'] -for mod_name in MOCK_MODULES: - sys.modules[mod_name] = Mock() - -# -- General configuration ----------------------------------------------------- - -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be extensions -# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest'] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['sphinx_templates'] - -# The suffix of source filenames. -source_suffix = '.rst' +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -# The encoding of source files. -#source_encoding = 'utf-8-sig' +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information +from posttroll.version import version -# The master toctree document. -master_doc = 'index' +project = "Posttroll" +copyright = "2012, Pytroll Crew" +author = "Pytroll Crew" +release = version -# General information about the project. -project = u'PostTroll' -copyright = u'2012-2014, Pytroll crew' +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# - - -# The full version, including alpha/beta/rc tags. -release = __version__ -# The short X.Y version. -version = ".".join(release.split(".")[:2]) - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -#language = None +extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"] +autodoc_mock_imports = ["pyzmq"] -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. +templates_path = ["_templates"] exclude_patterns = [] -# The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# -- Options for HTML output --------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -html_theme = 'default' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -#html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -#html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['sphinx_static'] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Output file base name for HTML help builder. -htmlhelp_basename = 'PostTrolldoc' - - -# -- Options for LaTeX output -------------------------------------------------- - -# The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' - -# The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, author, documentclass [howto/manual]). -latex_documents = [ - ('index', 'PostTroll.tex', u'PostTroll Documentation', - u'Pytroll crew', 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Additional stuff for the LaTeX preamble. -#latex_preamble = '' - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True -# -- Options for manual page output -------------------------------------------- +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'posttroll', u'PostTroll Documentation', - [u'Pytroll crew'], 1) -] +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] diff --git a/doc/source/index.rst b/doc/source/index.rst index e66bba3..7ff0e12 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -115,6 +115,17 @@ to specify the nameserver(s) explicitly in the publishing code:: .. seealso:: :class:`posttroll.publisher.Publish` and :class:`posttroll.subscriber.Subscribe` +Configuration parameters +------------------------ + +Global configuration variables that are available through a Donfig configuration object: +- tcp_keepalive +- tcp_keepalive_cnt +- tcp_keepalive_idle +- tcp_keepalive_intvl +- multicast_interface +- mc_group + Setting TCP keep-alive ---------------------- @@ -139,6 +150,55 @@ relevant socket options. .. _zmq_setsockopts: http://api.zeromq.org/master:zmq-setsockopt +Using secure ZeroMQ backend +--------------------------- + +To use securely authenticated sockets with posttroll (uses ZMQ's curve authentication), the backend needs to be defined +through posttroll config system, for example using an environment variable:: + + POSTTROLL_BACKEND=secure_zmq + +On the server side (for example a publisher), we need to define the server's secret key and the directory where the +accepted client keys are provided:: + + POSTTROLL_SERVER_SECRET_KEY_FILE=/path/to/server.key_secret + POSTTROLL_PUBLIC_SECRET_KEYS_DIRECTORY=/path/to/client_public_keys/ + +On the client side (for example a subscriber), we need to define the server's public key file and the client's secret +key file:: + + POSTTROLL_CLIENT_SECRET_KEY_FILE=/path/to/client.key_secret + POSTTROLL_SERVER_PUBLIC_KEY_FILE=/path/to/server.key + +These settings can also be set using the posttroll config object, for example:: + + >>> from posttroll import config + >>> with config.set(backend="secure_zmq", server_pubic_key_file="..."): + ... + +The posttroll configuration uses donfig, for more information, check https://donfig.readthedocs.io/en/latest/. + + +Generating the public and secret key pairs +****************************************** + +In order for the secure ZMQ backend to work, public/secret key pairs need to be generated, one for the client side and +one for the server side. A command-line script is provided for this purpose:: + + > posttroll-generate-keys -h + usage: posttroll-generate-keys [-h] [-d DIRECTORY] name + + Create a public/secret key pair for the secure zmq backend. This will create two files (in the current directory if not otherwise specified) with the suffixes '.key' and '.key_secret'. The name of the files will be the one provided. + + positional arguments: + name Name of the file. + + options: + -h, --help show this help message and exit + -d DIRECTORY, --directory DIRECTORY + Directory to place the keys in. + + Converting from older posttroll versions ---------------------------------------- @@ -229,15 +289,6 @@ Multicast code :members: :undoc-members: - -Connections -~~~~~~~~~~~ - -.. automodule:: posttroll.connections - :members: - :undoc-members: - - Misc ~~~~ diff --git a/posttroll/__init__.py b/posttroll/__init__.py index d812970..df053e3 100644 --- a/posttroll/__init__.py +++ b/posttroll/__init__.py @@ -26,16 +26,12 @@ import datetime as dt import logging -import os import sys -import zmq from donfig import Config -from .version import get_versions - -config = Config('posttroll') -context = {} +config = Config("posttroll", defaults=[dict(backend="unsecure_zmq")]) +# context = {} logger = logging.getLogger(__name__) @@ -44,11 +40,12 @@ def get_context(): This function takes care of creating new contexts in case of forks. """ - pid = os.getpid() - if pid not in context: - context[pid] = zmq.Context() - logger.debug('renewed context for PID %d', pid) - return context[pid] + backend = config["backend"] + if "zmq" in backend: + from posttroll.backends.zmq import get_context + return get_context() + else: + raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") def strp_isoformat(strg): @@ -62,30 +59,14 @@ def strp_isoformat(strg): return strg if len(strg) < 19 or len(strg) > 26: if len(strg) > 30: - strg = strg[:30] + '...' + strg = strg[:30] + "..." raise ValueError("Invalid ISO formatted time string '%s'" % strg) if strg.find(".") == -1: - strg += '.000000' - if sys.version[0:3] >= '2.6': + strg += ".000000" + if sys.version[0:3] >= "2.6": return dt.datetime.strptime(strg, "%Y-%m-%dT%H:%M:%S.%f") else: dat, mis = strg.split(".") dat = dt.datetime.strptime(dat, "%Y-%m-%dT%H:%M:%S") - mis = int(float('.' + mis)*1000000) + mis = int(float("." + mis) * 1000000) return dat.replace(microsecond=mis) - - -def _set_tcp_keepalive(socket): - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) - _set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None)) - - -def _set_int_sockopt(socket, param, value): - if value is not None: - socket.setsockopt(param, int(value)) - - -__version__ = get_versions()['version'] -del get_versions diff --git a/posttroll/address_receiver.py b/posttroll/address_receiver.py index dc8076b..42c0d4c 100644 --- a/posttroll/address_receiver.py +++ b/posttroll/address_receiver.py @@ -35,27 +35,29 @@ import time import netifaces -from zmq import REP, LINGER +from zmq import ZMQError -from posttroll.bbmcast import MulticastReceiver, SocketTimeout +from posttroll import config +from posttroll.bbmcast import MulticastReceiver, get_configured_broadcast_port from posttroll.message import Message from posttroll.publisher import Publish -from posttroll import get_context - -__all__ = ('AddressReceiver', 'getaddress') +__all__ = ("AddressReceiver", "getaddress") LOGGER = logging.getLogger(__name__) -debug = os.environ.get('DEBUG', False) -broadcast_port = 21200 +debug = os.environ.get("DEBUG", False) -default_publish_port = 16543 +DEFAULT_ADDRESS_PUBLISH_PORT = 16543 ten_minutes = dt.timedelta(minutes=10) zero_seconds = dt.timedelta(seconds=0) +def get_configured_address_port(): + return config.get("address_publish_port", DEFAULT_ADDRESS_PUBLISH_PORT) + + def get_local_ips(): """Get local IP addresses.""" inet_addrs = [netifaces.ifaddresses(iface).get(netifaces.AF_INET) @@ -64,7 +66,7 @@ def get_local_ips(): for addr in inet_addrs: if addr is not None: for add in addr: - ips.append(add['addr']) + ips.append(add["addr"]) return ips # ----------------------------------------------------------------------------- @@ -74,17 +76,17 @@ def get_local_ips(): # ----------------------------------------------------------------------------- -class AddressReceiver(object): +class AddressReceiver: """General thread to receive broadcast addresses.""" def __init__(self, max_age=ten_minutes, port=None, do_heartbeat=True, multicast_enabled=True, restrict_to_localhost=False): - """Initialize addres receiver.""" + """Set up the address receiver.""" self._max_age = max_age - self._port = port or default_publish_port + self._port = port or get_configured_address_port() self._address_lock = threading.Lock() self._addresses = {} - self._subject = '/address' + self._subject = "/address" self._do_heartbeat = do_heartbeat self._multicast_enabled = multicast_enabled self._last_age_check = dt.datetime(1900, 1, 1) @@ -121,7 +123,7 @@ def get(self, name=""): mda = copy.copy(metadata) mda["receive_time"] = mda["receive_time"].isoformat() addrs.append(mda) - LOGGER.debug('return address %s', str(addrs)) + LOGGER.debug("return address %s", str(addrs)) return addrs def _check_age(self, pub, min_interval=zero_seconds): @@ -137,41 +139,20 @@ def _check_age(self, pub, min_interval=zero_seconds): for addr, metadata in self._addresses.items(): atime = metadata["receive_time"] if now - atime > self._max_age: - mda = {'status': False, - 'URI': addr, - 'service': metadata['service']} - msg = Message('/address/' + metadata['name'], 'info', mda) + mda = {"status": False, + "URI": addr, + "service": metadata["service"]} + msg = Message("/address/" + metadata["name"], "info", mda) to_del.append(addr) - LOGGER.info("publish remove '%s'", str(msg)) - pub.send(msg.encode()) + LOGGER.info(f"publish remove '{msg}'") + pub.send(str(msg.encode())) for addr in to_del: del self._addresses[addr] def _run(self): """Run the receiver.""" - port = broadcast_port - nameservers = [] - if self._multicast_enabled: - while True: - try: - recv = MulticastReceiver(port) - except IOError as err: - if err.errno == errno.ENODEV: - LOGGER.error("Receiver initialization failed " - "(no such device). " - "Trying again in %d s", - 10) - time.sleep(10) - else: - raise - else: - recv.settimeout(tout=2.0) - LOGGER.info("Receiver initialized.") - break - - else: - recv = _SimpleReceiver(port) - nameservers = ["localhost"] + port = get_configured_broadcast_port() + nameservers, recv = self.set_up_address_receiver(port) self._is_running = True with Publish("address_receiver", self._port, ["addresses"], @@ -180,30 +161,35 @@ def _run(self): while self._do_run: try: data, fromaddr = recv() - if self._multicast_enabled: - ip_, port = fromaddr - if self._restrict_to_localhost and ip_ not in self._local_ips: - # discard external message - LOGGER.debug('Discard external message') - continue - LOGGER.debug("data %s", data) - except SocketTimeout: - if self._multicast_enabled: - LOGGER.debug("Multicast socket timed out on recv!") + except TimeoutError: + if self._do_run: + if self._multicast_enabled: + LOGGER.debug("Multicast socket timed out on recv!") continue + else: + raise + except ZMQError: + return finally: self._check_age(pub, min_interval=self._max_age / 20) if self._do_heartbeat: pub.heartbeat(min_interval=29) + if self._multicast_enabled: + ip_, port = fromaddr + if self._restrict_to_localhost and ip_ not in self._local_ips: + # discard external message + LOGGER.debug("Discard external message") + continue + LOGGER.debug("data %s", data) msg = Message.decode(data) name = msg.subject.split("/")[1] - if msg.type == 'info' and msg.subject.lower().startswith(self._subject): + if msg.type == "info" and msg.subject.lower().startswith(self._subject): addr = msg.data["URI"] - msg.data['status'] = True + msg.data["status"] = True metadata = copy.copy(msg.data) metadata["name"] = name - LOGGER.debug('receiving address %s %s %s', str(addr), + LOGGER.debug("receiving address %s %s %s", str(addr), str(name), str(metadata)) if addr not in self._addresses: LOGGER.info("nameserver: publish add '%s'", @@ -214,6 +200,35 @@ def _run(self): self._is_running = False recv.close() + def set_up_address_receiver(self, port): + """Set up the address receiver depending on if it is multicast or not.""" + nameservers = False + if self._multicast_enabled: + while True: + try: + recv = MulticastReceiver(port) + except IOError as err: + if err.errno == errno.ENODEV: + LOGGER.error("Receiver initialization failed " + "(no such device). " + "Trying again in %d s", + 10) + time.sleep(10) + else: + raise + else: + recv.settimeout(tout=2.0) + LOGGER.info("Receiver initialized.") + break + + else: + if config["backend"] not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError + from posttroll.backends.zmq.address_receiver import SimpleReceiver + recv = SimpleReceiver(port, timeout=2) + nameservers = ["localhost"] + return nameservers, recv + def _add(self, adr, metadata): """Add an address.""" with self._address_lock: @@ -221,27 +236,6 @@ def _add(self, adr, metadata): self._addresses[adr] = metadata -class _SimpleReceiver(object): - """Simple listing on port for address messages.""" - - def __init__(self, port=None): - """Initialize receiver.""" - self._port = port or default_publish_port - self._socket = get_context().socket(REP) - self._socket.bind("tcp://*:" + str(port)) - - def __call__(self): - """Receive and return a message.""" - message = self._socket.recv_string() - self._socket.send_string("ok") - return message, None - - def close(self): - """Close the receiver.""" - self._socket.setsockopt(LINGER, 1) - self._socket.close() - - # ----------------------------------------------------------------------------- # default getaddress = AddressReceiver diff --git a/posttroll/backends/__init__.py b/posttroll/backends/__init__.py new file mode 100644 index 0000000..982ba70 --- /dev/null +++ b/posttroll/backends/__init__.py @@ -0,0 +1 @@ +"""Init file for the backends.""" diff --git a/posttroll/backends/zmq/__init__.py b/posttroll/backends/zmq/__init__.py new file mode 100644 index 0000000..c943737 --- /dev/null +++ b/posttroll/backends/zmq/__init__.py @@ -0,0 +1,70 @@ +"""Main module for the zmq backend.""" +import argparse +import logging +import os +from pathlib import Path + +import zmq +from zmq.auth.certs import create_certificates + +from posttroll import config + +logger = logging.getLogger(__name__) +context = {} + + +def get_context(): + """Provide the context to use. + + This function takes care of creating new contexts in case of forks. + """ + pid = os.getpid() + if pid not in context: + context[pid] = zmq.Context() + logger.debug("renewed context for PID %d", pid) + return context[pid] + + +def destroy_context(linger=None): + """Destroy the context.""" + pid = os.getpid() + context.pop(pid).destroy(linger) + + +def _set_tcp_keepalive(socket): + """Set the tcp keepalive parameters on *socket*.""" + keepalive_options = get_tcp_keepalive_options() + for param, value in keepalive_options.items(): + socket.setsockopt(param, value) + + +def get_tcp_keepalive_options(): + """Get the tcp_keepalive options from config.""" + keepalive_options = dict() + for opt in ("tcp_keepalive", + "tcp_keepalive_cnt", + "tcp_keepalive_idle", + "tcp_keepalive_intvl"): + try: + value = int(config[opt]) + except (KeyError, TypeError): + continue + param = getattr(zmq, opt.upper()) + keepalive_options[param] = value + return keepalive_options + + +def generate_keys(args=None): + """Generate a public/secret key pair.""" + parser = argparse.ArgumentParser( + prog="posttroll-generate-keys", + description=("Create a public/secret key pair for the secure zmq backend. This will create two " + "files (in the current directory if not otherwise specified) with the suffixes '.key'" + " and '.key_secret'. The name of the files will be the one provided.")) + + parser.add_argument("name", type=str, help="Name of the file.") + parser.add_argument("-d", "--directory", help="Directory to place the keys in.", default=".", type=Path) + + parsed = parser.parse_args(args) + + create_certificates(parsed.directory, parsed.name) diff --git a/posttroll/backends/zmq/address_receiver.py b/posttroll/backends/zmq/address_receiver.py new file mode 100644 index 0000000..ef58dfa --- /dev/null +++ b/posttroll/backends/zmq/address_receiver.py @@ -0,0 +1,36 @@ +"""ZMQ implementation of the the simple receiver.""" + +from zmq import REP + +from posttroll.address_receiver import get_configured_address_port +from posttroll.backends.zmq.socket import close_socket, set_up_server_socket + + +class SimpleReceiver(object): + """Simple listing on port for address messages.""" + + def __init__(self, port=None, timeout=2): + """Set up the receiver.""" + self._port = port or get_configured_address_port() + address = "tcp://*:" + str(port) + self._socket, _, self._authenticator = set_up_server_socket(REP, address) + self._running = True + self.timeout = timeout + + def __call__(self): + """Receive a message.""" + while self._running: + try: + message = self._socket.recv_string(self.timeout) + except TimeoutError: + continue + else: + self._socket.send_string("ok") + return message, None + + def close(self): + """Close the receiver.""" + self._running = False + close_socket(self._socket) + if self._authenticator: + self._authenticator.stop() diff --git a/posttroll/backends/zmq/message_broadcaster.py b/posttroll/backends/zmq/message_broadcaster.py new file mode 100644 index 0000000..238d9eb --- /dev/null +++ b/posttroll/backends/zmq/message_broadcaster.py @@ -0,0 +1,54 @@ +"""Message broadcaster implementation using zmq.""" + +import logging +import threading + +from zmq import LINGER, NOBLOCK, REQ, ZMQError + +from posttroll.backends.zmq.socket import close_socket, set_up_client_socket + +logger = logging.getLogger(__name__) + + +class ZMQDesignatedReceiversSender: + """Sends message to multiple *receivers* on *port*.""" + + def __init__(self, default_port, receivers): + """Set up the sender.""" + self.default_port = default_port + self.receivers = receivers + self._shutdown_event = threading.Event() + + def __call__(self, data): + """Send data.""" + for receiver in self.receivers: + self._send_to_address(receiver, data) + + def _send_to_address(self, address, data, timeout=10): + """Send data to *address* and *port* without verification of response.""" + # Socket to talk to server + if address.find(":") == -1: + full_address = "tcp://%s:%d" % (address, self.default_port) + else: + full_address = "tcp://%s" % address + options = {LINGER: int(timeout * 1000)} + socket = set_up_client_socket(REQ, full_address, options) + try: + + socket.send_string(data) + while not self._shutdown_event.is_set(): + try: + message = socket.recv_string(NOBLOCK) + except ZMQError: + self._shutdown_event.wait(.1) + continue + if message != "ok": + logger.warning("invalid acknowledge received: %s" % message) + break + + finally: + close_socket(socket) + + def close(self): + """Close the sender.""" + self._shutdown_event.set() diff --git a/posttroll/backends/zmq/ns.py b/posttroll/backends/zmq/ns.py new file mode 100644 index 0000000..dc0fcfb --- /dev/null +++ b/posttroll/backends/zmq/ns.py @@ -0,0 +1,108 @@ +"""ZMQ implexentation of ns.""" + +import logging +from contextlib import suppress +from threading import Lock + +from zmq import LINGER, REP, REQ + +from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket, set_up_server_socket +from posttroll.message import Message +from posttroll.ns import get_active_address, get_configured_nameserver_port + +logger = logging.getLogger("__name__") + +nslock = Lock() + + +def zmq_get_pub_address(name, timeout=10, nameserver="localhost"): + """Get the address of the publisher. + + For a given publisher *name* from the nameserver on *nameserver* (localhost by default). + """ + nameserver_address = create_nameserver_address(nameserver) + # Socket to talk to server + logger.debug(f"Connecting to {nameserver_address}") + socket = create_req_socket(timeout, nameserver_address) + return _fetch_address_using_socket(socket, name, timeout) + + +def create_nameserver_address(nameserver): + """Create the nameserver address.""" + port = get_configured_nameserver_port() + nameserver_address = "tcp://" + nameserver + ":" + str(port) + return nameserver_address + + +def _fetch_address_using_socket(socket, name, timeout): + try: + socket_receiver = SocketReceiver() + socket_receiver.register(socket) + + message = Message("/oper/ns", "request", {"service": name}) + socket.send_string(str(message)) + + # Get the reply. + for message, _ in socket_receiver.receive(socket, timeout=timeout): + return message.data + except TimeoutError: + raise TimeoutError("Didn't get an address after %d seconds." + % timeout) + finally: + socket_receiver.unregister(socket) + close_socket(socket) + + +def create_req_socket(timeout, nameserver_address): + """Create a REQ socket.""" + options = {LINGER: int(timeout * 1000)} + socket = set_up_client_socket(REQ, nameserver_address, options) + return socket + + +class ZMQNameServer: + """The name server.""" + + def __init__(self): + """Set up the nameserver.""" + self.running = True + self.listener = None + + def run(self, address_receiver): + """Run the listener and answer to requests.""" + port = get_configured_nameserver_port() + + try: + # stop was called before we could start running, exit + if not self.running: + return + address = "tcp://*:" + str(port) + self.listener, _, self._authenticator = set_up_server_socket(REP, address) + logger.debug(f"Nameserver listening on port {port}") + socket_receiver = SocketReceiver() + socket_receiver.register(self.listener) + while self.running: + try: + for msg, _ in socket_receiver.receive(self.listener, timeout=1): + logger.debug("Replying to request: " + str(msg)) + active_address = get_active_address(msg.data["service"], address_receiver) + self.listener.send_unicode(str(active_address)) + except TimeoutError: + continue + except KeyboardInterrupt: + # Needed to stop the nameserver. + pass + finally: + socket_receiver.unregister(self.listener) + self.close_sockets_and_threads() + + def close_sockets_and_threads(self): + """Close all sockets and threads.""" + with suppress(AttributeError): + close_socket(self.listener) + with suppress(AttributeError): + self._authenticator.stop() + + def stop(self): + """Stop the name server.""" + self.running = False diff --git a/posttroll/backends/zmq/publisher.py b/posttroll/backends/zmq/publisher.py new file mode 100644 index 0000000..8d2bec5 --- /dev/null +++ b/posttroll/backends/zmq/publisher.py @@ -0,0 +1,59 @@ +"""ZMQ implementation of the publisher.""" + +import logging +from contextlib import suppress +from threading import Lock + +import zmq + +from posttroll.backends.zmq import get_tcp_keepalive_options +from posttroll.backends.zmq.socket import close_socket, set_up_server_socket + +LOGGER = logging.getLogger(__name__) + + +class ZMQPublisher: + """Unsecure ZMQ implementation of the publisher class.""" + + def __init__(self, address, name="", min_port=None, max_port=None): + """Set up the publisher. + + Args: + address: the address to connect to. + name: the name of this publishing service. + min_port: the minimal port number to use. + max_port: the maximal port number to use. + + """ + self.name = name + self.destination = address + self.publish_socket = None + self.min_port = min_port + self.max_port = max_port + self.port_number = None + self._pub_lock = Lock() + self._authenticator = None + + def start(self): + """Start the publisher.""" + self._create_socket() + LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") + + return self + + def _create_socket(self): + options = get_tcp_keepalive_options() + self.publish_socket, port, self._authenticator = set_up_server_socket(zmq.PUB, self.destination, options, + (self.min_port, self.max_port)) + self.port_number = port + + def send(self, msg): + """Send the given message.""" + with self._pub_lock: + self.publish_socket.send_string(msg) + + def stop(self): + """Stop the publisher.""" + close_socket(self.publish_socket) + with suppress(AttributeError): + self._authenticator.stop() diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py new file mode 100644 index 0000000..7adb295 --- /dev/null +++ b/posttroll/backends/zmq/socket.py @@ -0,0 +1,155 @@ +"""ZMQ socket handling functions.""" + +from urllib.parse import urlsplit, urlunsplit + +import zmq +from zmq.auth.thread import ThreadAuthenticator + +from posttroll import config, get_context +from posttroll.message import Message + + +def close_socket(sock): + """Close a zmq socket.""" + sock.setsockopt(zmq.LINGER, 1) + sock.close() + + +def set_up_client_socket(socket_type, address, options=None): + """Set up a client (connecting) zmq socket.""" + backend = config["backend"] + if backend == "unsecure_zmq": + sock = create_unsecure_client_socket(socket_type) + elif backend == "secure_zmq": + sock = create_secure_client_socket(socket_type) + add_options(sock, options) + sock.connect(address) + return sock + + +def create_unsecure_client_socket(socket_type): + """Create an unsecure client socket.""" + return get_context().socket(socket_type) + + +def add_options(sock, options=None): + """Add options to a socket.""" + if not options: + return + for param, val in options.items(): + sock.setsockopt(param, val) + + +def create_secure_client_socket(socket_type): + """Create a secure client socket.""" + subscriber = get_context().socket(socket_type) + + client_secret_key_file = config["client_secret_key_file"] + server_public_key_file = config["server_public_key_file"] + client_public, client_secret = zmq.auth.load_certificate(client_secret_key_file) + subscriber.curve_secretkey = client_secret + subscriber.curve_publickey = client_public + + server_public, _ = zmq.auth.load_certificate(server_public_key_file) + # The client must know the server's public key to make a CURVE connection. + subscriber.curve_serverkey = server_public + return subscriber + + +def set_up_server_socket(socket_type, destination, options=None, port_interval=(None, None)): + """Set up a server (binding) socket.""" + if options is None: + options = {} + backend = config["backend"] + if backend == "unsecure_zmq": + sock = create_unsecure_server_socket(socket_type) + authenticator = None + elif backend == "secure_zmq": + sock, authenticator = create_secure_server_socket(socket_type) + + add_options(sock, options) + + port = bind(sock, destination, port_interval) + return sock, port, authenticator + + +def create_unsecure_server_socket(socket_type): + """Create an unsecure server socket.""" + return get_context().socket(socket_type) + + +def bind(sock, destination, port_interval): + """Bind the socket to a destination. + + If a random port is to be chosen, the port_interval is used. + """ + # Check for port 0 (random port) + min_port, max_port = port_interval + u__ = urlsplit(destination) + port = u__.port + if port == 0: + dest = urlunsplit((u__.scheme, u__.hostname, + u__.path, u__.query, u__.fragment)) + port_number = sock.bind_to_random_port(dest, + min_port=min_port, + max_port=max_port) + netloc = u__.hostname + ":" + str(port_number) + destination = urlunsplit((u__.scheme, netloc, u__.path, + u__.query, u__.fragment)) + else: + sock.bind(destination) + port_number = port + return port_number + + +def create_secure_server_socket(socket_type): + """Create a secure server socket.""" + server_secret_key = config["server_secret_key_file"] + clients_public_keys_directory = config["clients_public_keys_directory"] + authorized_sub_addresses = config.get("authorized_client_addresses", []) + + ctx = get_context() + + # Start an authenticator for this context. + authenticator_thread = ThreadAuthenticator(ctx) + authenticator_thread.start() + authenticator_thread.allow(*authorized_sub_addresses) + # Tell authenticator to use the certificate in a directory + authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) + + server_socket = ctx.socket(socket_type) + + server_public, server_secret = zmq.auth.load_certificate(server_secret_key) + server_socket.curve_secretkey = server_secret + server_socket.curve_publickey = server_public + server_socket.curve_server = True + return server_socket, authenticator_thread + + +class SocketReceiver: + """A receiver for mulitple sockets.""" + + def __init__(self): + """Set up the receiver.""" + self._poller = zmq.Poller() + + def register(self, socket): + """Register the socket.""" + self._poller.register(socket, zmq.POLLIN) + + def unregister(self, socket): + """Unregister the socket.""" + self._poller.unregister(socket) + + def receive(self, *sockets, timeout=None): + """Timeout is in seconds.""" + if timeout: + timeout *= 1000 + socks = dict(self._poller.poll(timeout=timeout)) + if socks: + for sock in sockets: + if socks.get(sock) == zmq.POLLIN: + received = sock.recv_string(zmq.NOBLOCK) + yield Message.decode(received), sock + else: + raise TimeoutError("Did not receive anything on sockets.") diff --git a/posttroll/backends/zmq/subscriber.py b/posttroll/backends/zmq/subscriber.py new file mode 100644 index 0000000..836f590 --- /dev/null +++ b/posttroll/backends/zmq/subscriber.py @@ -0,0 +1,209 @@ +"""ZMQ implementation of the subscriber.""" + +import logging +from threading import Lock +from time import sleep +from urllib.parse import urlsplit + +from zmq import PULL, SUB, SUBSCRIBE, ZMQError + +from posttroll.backends.zmq import get_tcp_keepalive_options +from posttroll.backends.zmq.socket import SocketReceiver, close_socket, set_up_client_socket + +LOGGER = logging.getLogger(__name__) + + +class ZMQSubscriber: + """A ZMQ subscriber class.""" + + def __init__(self, addresses, topics="", message_filter=None, translate=False): + """Initialize the subscriber.""" + self._topics = topics + self._filter = message_filter + self._translate = translate + + self.sub_addr = {} + self.addr_sub = {} + + self._hooks = [] + self._hooks_cb = {} + + self._sock_receiver = SocketReceiver() + self._lock = Lock() + + self.update(addresses) + + self._loop = None + + @property + def running(self): + """Check if suscriber is running.""" + return self._loop + + def add(self, address, topics=None): + """Add *address* to the subscribing list for *topics*. + + It topics is None we will subscribe to already specified topics. + """ + with self._lock: + if address in self.addresses: + return + + topics = topics or self._topics + LOGGER.info("Subscriber adding address %s with topics %s", + str(address), str(topics)) + subscriber = self._add_sub_socket(address, topics) + self.sub_addr[subscriber] = address + self.addr_sub[address] = subscriber + + def remove(self, address): + """Remove *address* from the subscribing list for *topics*.""" + with self._lock: + try: + subscriber = self.addr_sub[address] + except KeyError: + return + LOGGER.info("Subscriber removing address %s", str(address)) + del self.addr_sub[address] + del self.sub_addr[subscriber] + self._remove_sub_socket(subscriber) + + def _remove_sub_socket(self, subscriber): + if self._sock_receiver: + self._sock_receiver.unregister(subscriber) + subscriber.close() + + def update(self, addresses): + """Update with a set of addresses.""" + if isinstance(addresses, str): + addresses = [addresses, ] + current_addresses, new_addresses = set(self.addresses), set(addresses) + addresses_to_remove = current_addresses.difference(new_addresses) + addresses_to_add = new_addresses.difference(current_addresses) + for addr in addresses_to_remove: + self.remove(addr) + for addr in addresses_to_add: + self.add(addr) + return bool(addresses_to_remove or addresses_to_add) + + def add_hook_sub(self, address, topics, callback): + """Specify a SUB *callback* in the same stream (thread) as the main receive loop. + + The callback will be called with the received messages from the + specified subscription. + + Good for operations, which is required to be done in the same thread as + the main recieve loop (e.q operations on the underlying sockets). + """ + topics = topics + LOGGER.info("Subscriber adding SUB hook %s for topics %s", + str(address), str(topics)) + socket = self._add_sub_socket(address, topics) + self._add_hook(socket, callback) + + def add_hook_pull(self, address, callback): + """Specify a PULL *callback* in the same stream (thread) as the main receive loop. + + The callback will be called with the received messages from the + specified subscription. Good for pushed 'inproc' messages from another thread. + """ + LOGGER.info("Subscriber adding PULL hook %s", str(address)) + socket = self._create_socket(PULL, address) + if self._sock_receiver: + self._sock_receiver.register(socket) + self._add_hook(socket, callback) + + def _add_hook(self, socket, callback): + """Add a generic hook. The passed socket has to be "receive only".""" + self._hooks.append(socket) + self._hooks_cb[socket] = callback + + @property + def addresses(self): + """Get the addresses.""" + return self.sub_addr.values() + + @property + def subscribers(self): + """Get the subscribers.""" + return self.sub_addr.keys() + + def recv(self, timeout=None): + """Receive, optionally with *timeout* in seconds.""" + for sub in list(self.subscribers) + self._hooks: + self._sock_receiver.register(sub) + self._loop = True + try: + while self._loop: + sleep(0) + yield from self._new_messages(timeout) + finally: + for sub in list(self.subscribers) + self._hooks: + self._sock_receiver.unregister(sub) + # self.poller.unregister(sub) + + def _new_messages(self, timeout): + """Check for new messages to yield and pass to the callbacks.""" + all_subs = list(self.subscribers) + self._hooks + try: + for m__, sock in self._sock_receiver.receive(*all_subs, timeout=timeout): + if sock in self.subscribers: + if not self._filter or self._filter(m__): + if self._translate: + url = urlsplit(self.sub_addr[sock]) + host = url[1].split(":")[0] + m__.sender = (m__.sender.split("@")[0] + + "@" + host) + yield m__ + elif sock in self._hooks: + self._hooks_cb[sock](m__) + except TimeoutError: + yield None + except ZMQError as err: + if self._loop: + LOGGER.exception("Receive failed: %s", str(err)) + + def __call__(self, **kwargs): + """Handle calls with class instance.""" + return self.recv(**kwargs) + + def stop(self): + """Stop the subscriber.""" + self._loop = False + + def close(self): + """Close the subscriber: stop it and close the local subscribers.""" + self.stop() + for sub in list(self.subscribers) + self._hooks: + try: + close_socket(sub) + except ZMQError: + pass + + def __del__(self): + """Clean up after the instance is deleted.""" + for sub in list(self.subscribers) + self._hooks: + try: + close_socket(sub) + except Exception: # noqa: E722 + pass + + def _add_sub_socket(self, address, topics): + + options = get_tcp_keepalive_options() + + subscriber = self._create_socket(SUB, address, options) + add_subscriptions(subscriber, topics) + + if self._sock_receiver: + self._sock_receiver.register(subscriber) + return subscriber + + def _create_socket(self, socket_type, address, options): + return set_up_client_socket(socket_type, address, options) + + +def add_subscriptions(socket, topics): + """Add subscriptions to a socket.""" + for t__ in topics: + socket.setsockopt_string(SUBSCRIBE, str(t__)) diff --git a/posttroll/bbmcast.py b/posttroll/bbmcast.py index 3aa6b70..d9f3ae0 100644 --- a/posttroll/bbmcast.py +++ b/posttroll/bbmcast.py @@ -31,24 +31,51 @@ import logging import os import struct -from socket import (AF_INET, INADDR_ANY, IP_ADD_MEMBERSHIP, IP_MULTICAST_LOOP, - IP_MULTICAST_TTL, IPPROTO_IP, SO_BROADCAST, SO_REUSEADDR, - SOCK_DGRAM, SOL_IP, SOL_SOCKET, gethostbyname, socket, - timeout, SO_LINGER) - -__all__ = ('MulticastSender', 'MulticastReceiver', 'mcast_sender', - 'mcast_receiver', 'SocketTimeout') +import warnings +from contextlib import suppress +from socket import ( + AF_INET, + INADDR_ANY, + IP_ADD_MEMBERSHIP, + IP_MULTICAST_IF, + IP_MULTICAST_LOOP, + IP_MULTICAST_TTL, + IPPROTO_IP, + SO_BROADCAST, + SO_LINGER, + SO_REUSEADDR, + SOCK_DGRAM, + SOL_IP, + SOL_SOCKET, + gethostbyname, + inet_aton, + socket, + timeout, +) + +from posttroll import config + +__all__ = ("MulticastSender", "MulticastReceiver", "mcast_sender", + "mcast_receiver", "SocketTimeout") # 224.0.0.0 through 224.0.0.255 is reserved administrative tasks -MC_GROUP = os.environ.get('PYTROLL_MC_GROUP', '225.0.0.212') +DEFAULT_MC_GROUP = "225.0.0.212" # local network multicast (<32) -TTL_LOCALNET = int(os.environ.get('PYTROLL_MC_TTL', 31)) +TTL_LOCALNET = int(os.environ.get("PYTROLL_MC_TTL", 31)) logger = logging.getLogger(__name__) SocketTimeout = timeout # for easy access to socket.timeout +DEFAULT_BROADCAST_PORT = 21200 + + +def get_configured_broadcast_port(): + """Get the configured nameserver port.""" + return config.get("broadcast_port", DEFAULT_BROADCAST_PORT) + + # ----------------------------------------------------------------------------- # # Sender. @@ -56,15 +83,15 @@ # ----------------------------------------------------------------------------- -class MulticastSender(object): +class MulticastSender: """Multicast sender on *port* and *mcgroup*.""" - def __init__(self, port, mcgroup=MC_GROUP): - """Initialize multicast sending.""" + def __init__(self, port, mcgroup=None): + """Set up the multicast sender.""" self.port = port self.group = mcgroup self.socket, self.group = mcast_sender(mcgroup) - logger.debug('Started multicast group %s', mcgroup) + logger.debug("Started multicast group %s", self.group) def __call__(self, data): """Send data to a socket.""" @@ -77,25 +104,43 @@ def close(self): # Allow non-object interface -def mcast_sender(mcgroup=MC_GROUP): +def mcast_sender(mcgroup=None): """Non-object interface for sending multicast messages.""" + if mcgroup is None: + mcgroup = get_mc_group() sock = socket(AF_INET, SOCK_DGRAM) try: sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) if _is_broadcast_group(mcgroup): - group = '' + group = "" sock.setsockopt(SOL_SOCKET, SO_BROADCAST, 1) - elif int(mcgroup.split(".")[0]) > 239 or int(mcgroup.split(".")[0]) < 224: - raise IOError("Invalid multicast address.") + elif ((int(mcgroup.split(".")[0]) > 239) or + (int(mcgroup.split(".")[0]) < 224)): + raise IOError(f"Invalid multicast address {mcgroup}") else: group = mcgroup - ttl = struct.pack('b', TTL_LOCALNET) # Time-to-live + ttl = struct.pack("b", TTL_LOCALNET) # Time-to-live sock.setsockopt(IPPROTO_IP, IP_MULTICAST_TTL, ttl) + + with suppress(KeyError): + multicast_interface = config.get("multicast_interface") + sock.setsockopt(IPPROTO_IP, IP_MULTICAST_IF, inet_aton(multicast_interface)) except Exception: sock.close() raise return sock, group + +def get_mc_group(): + try: + mcgroup = os.environ["PYTROLL_MC_GROUP"] + warnings.warn("PYTROLL_MC_GROUP is pending deprecation, please use POSTTROLL_MC_GROUP instead.", + PendingDeprecationWarning) + except KeyError: + mcgroup = DEFAULT_MC_GROUP + mcgroup = config.get("mc_group", mcgroup) + return mcgroup + # ----------------------------------------------------------------------------- # # Receiver. @@ -103,13 +148,13 @@ def mcast_sender(mcgroup=MC_GROUP): # ----------------------------------------------------------------------------- -class MulticastReceiver(object): +class MulticastReceiver: """Multicast receiver on *port* for an *mcgroup*.""" BUFSIZE = 1024 - def __init__(self, port, mcgroup=MC_GROUP): - """Initialize multicast receiver.""" + def __init__(self, port, mcgroup=None): + """Set up the multicast receiver.""" # Note: a multicast receiver will also receive broadcast on same port. self.port = port self.socket, self.group = mcast_receiver(port, mcgroup) @@ -129,14 +174,16 @@ def __call__(self): def close(self): """Close the receiver.""" - self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack('ii', 1, 1)) + self.socket.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 1)) self.socket.close() # Allow non-object interface -def mcast_receiver(port, mcgroup=MC_GROUP): +def mcast_receiver(port, mcgroup=None): """Open a UDP socket, bind it to a port and select a multicast group.""" + if mcgroup is None: + mcgroup = get_mc_group() if _is_broadcast_group(mcgroup): group = None else: @@ -154,22 +201,21 @@ def mcast_receiver(port, mcgroup=MC_GROUP): sock.setsockopt(SOL_IP, IP_MULTICAST_LOOP, 1) # default # Bind it to the port - sock.bind(('', port)) + sock.bind(("", port)) # Look up multicast group address in name server # (doesn't hurt if it is already in ddd.ddd.ddd.ddd format) if group: group = gethostbyname(group) - # Construct binary group address - bytes_ = [int(b) for b in group.split(".")] - grpaddr = 0 - for byte in bytes_: - grpaddr = (grpaddr << 8) | byte - - # Construct struct mreq from grpaddr and ifaddr - ifaddr = INADDR_ANY - mreq = struct.pack('!LL', grpaddr, ifaddr) + # Construct struct mreq + try: + multicast_interface = config.get("multicast_interface") + ifaddr = inet_aton(multicast_interface) + mreq = struct.pack("=4s4s", inet_aton(group), ifaddr) + except KeyError: + ifaddr = INADDR_ANY + mreq = struct.pack("=4sl", inet_aton(group), ifaddr) # Add group membership sock.setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, mreq) @@ -177,7 +223,7 @@ def mcast_receiver(port, mcgroup=MC_GROUP): sock.close() raise - return sock, group or '' + return sock, group or "" # ----------------------------------------------------------------------------- # @@ -188,6 +234,6 @@ def mcast_receiver(port, mcgroup=MC_GROUP): def _is_broadcast_group(group): """Check if *group* is a valid multicasting group.""" - if not group or gethostbyname(group) in ('0.0.0.0', '255.255.255.255'): + if not group or gethostbyname(group) in ("0.0.0.0", "255.255.255.255"): return True return False diff --git a/posttroll/listener.py b/posttroll/listener.py index 608d134..e89c486 100644 --- a/posttroll/listener.py +++ b/posttroll/listener.py @@ -23,11 +23,12 @@ """Listener module.""" -from posttroll.subscriber import create_subscriber_from_dict_config +import logging +import time from queue import Queue from threading import Thread -import time -import logging + +from posttroll.subscriber import create_subscriber_from_dict_config class ListenerContainer: @@ -106,11 +107,11 @@ def create_subscriber(self): def _get_subscriber_config(self): config = { - 'services': self.services, - 'topics': self.topics, - 'addr_listener': True, - 'addresses': self.addresses, - 'nameserver': self.nameserver, + "services": self.services, + "topics": self.topics, + "addr_listener": True, + "addresses": self.addresses, + "nameserver": self.nameserver, } return config diff --git a/posttroll/logger.py b/posttroll/logger.py index 31fd152..2155a76 100644 --- a/posttroll/logger.py +++ b/posttroll/logger.py @@ -25,14 +25,14 @@ # TODO: remove old hanging subscriptions -from posttroll.subscriber import Subscribe -from posttroll.publisher import NoisyPublisher -from posttroll.message import Message -from threading import Thread - import copy import logging import logging.handlers +from threading import Thread + +from posttroll.message import Message +from posttroll.publisher import NoisyPublisher +from posttroll.subscriber import Subscribe LOGGER = logging.getLogger(__name__) @@ -75,11 +75,11 @@ def close(self): BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) COLORS = { - 'WARNING': YELLOW, - 'INFO': GREEN, - 'DEBUG': BLUE, - 'CRITICAL': MAGENTA, - 'ERROR': RED + "WARNING": YELLOW, + "INFO": GREEN, + "DEBUG": BLUE, + "CRITICAL": MAGENTA, + "ERROR": RED } COLOR_SEQ = "\033[1;%dm" @@ -201,8 +201,8 @@ def run(): time.sleep(1) except KeyboardInterrupt: tlogger.stop() - print("Thanks for using pytroll/logger. See you soon on www.pytroll.org!") + print("Thanks for using pytroll/logger. See you soon on www.pytroll.org!") # noqa -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/posttroll/message.py b/posttroll/message.py index 6c10c0a..ab68484 100644 --- a/posttroll/message.py +++ b/posttroll/message.py @@ -48,8 +48,8 @@ from posttroll import strp_isoformat -_MAGICK = 'pytroll:/' -_VERSION = 'v1.01' +_MAGICK = "pytroll:/" +_VERSION = "v1.01" class MessageError(Exception): @@ -117,17 +117,17 @@ class Message(object): - It will make a Message pickleable. """ - def __init__(self, subject='', atype='', data='', binary=False, rawstr=None): + def __init__(self, subject="", atype="", data="", binary=False, rawstr=None): """Initialize a Message from a subject, type and data, or from a raw string.""" if rawstr: self.__dict__ = _decode(rawstr) else: try: - self.subject = subject.decode('utf-8') + self.subject = subject.decode("utf-8") except AttributeError: self.subject = subject try: - self.type = atype.decode('utf-8') + self.type = atype.decode("utf-8") except AttributeError: self.type = atype self.type = atype @@ -142,17 +142,17 @@ def __init__(self, subject='', atype='', data='', binary=False, rawstr=None): def user(self): """Try to return a user from a sender.""" try: - return self.sender[:self.sender.index('@')] + return self.sender[:self.sender.index("@")] except ValueError: - return '' + return "" @property def host(self): """Try to return a host from a sender.""" try: - return self.sender[self.sender.index('@') + 1:] + return self.sender[self.sender.index("@") + 1:] except ValueError: - return '' + return "" @property def head(self): @@ -181,7 +181,7 @@ def __unicode__(self): def __str__(self): """Return the human readable representation of the Message.""" try: - return unicode(self).encode('utf-8') + return unicode(self).encode("utf-8") except NameError: return self.encode() @@ -243,36 +243,17 @@ def datetime_decoder(dct): def _decode(rawstr): """Convert a raw string to a Message.""" - # Check for the magick word. - try: - rawstr = rawstr.decode('utf-8') - except (AttributeError, UnicodeEncodeError): - pass - except (UnicodeDecodeError): - try: - rawstr = rawstr.decode('iso-8859-1') - except (UnicodeDecodeError): - rawstr = rawstr.decode('utf-8', 'ignore') - if not rawstr.startswith(_MAGICK): - raise MessageError("This is not a '%s' message (wrong magick word)" - % _MAGICK) - rawstr = rawstr[len(_MAGICK):] + rawstr = _check_for_magic_word(rawstr) - # Check for element count and version - raw = re.split(r"\s+", rawstr, maxsplit=6) - if len(raw) < 5: - raise MessageError("Could node decode raw string: '%s ...'" - % str(rawstr[:36])) - version = raw[4][:len(_VERSION)] - if not _is_valid_version(version): - raise MessageError("Invalid Message version: '%s'" % str(version)) + raw = _check_for_element_count(rawstr) + version = _check_for_version(raw) # Start to build message - msg = dict((('subject', raw[0].strip()), - ('type', raw[1].strip()), - ('sender', raw[2].strip()), - ('time', strp_isoformat(raw[3].strip())), - ('version', version))) + msg = dict((("subject", raw[0].strip()), + ("type", raw[1].strip()), + ("sender", raw[2].strip()), + ("time", strp_isoformat(raw[3].strip())), + ("version", version))) # Data part try: @@ -282,26 +263,59 @@ def _decode(rawstr): mimetype = None if mimetype is None: - msg['data'] = '' - msg['binary'] = False - elif mimetype == 'application/json': + msg["data"] = "" + msg["binary"] = False + elif mimetype == "application/json": try: - msg['data'] = json.loads(raw[6], object_hook=datetime_decoder) - msg['binary'] = False + msg["data"] = json.loads(raw[6], object_hook=datetime_decoder) + msg["binary"] = False except ValueError: raise MessageError("JSON decode failed on '%s ...'" % raw[6][:36]) - elif mimetype == 'text/ascii': - msg['data'] = str(data) - msg['binary'] = False - elif mimetype == 'binary/octet-stream': - msg['data'] = data - msg['binary'] = True + elif mimetype == "text/ascii": + msg["data"] = str(data) + msg["binary"] = False + elif mimetype == "binary/octet-stream": + msg["data"] = data + msg["binary"] = True else: raise MessageError("Unknown mime-type '%s'" % mimetype) return msg +def _check_for_version(raw): + version = raw[4][:len(_VERSION)] + if not _is_valid_version(version): + raise MessageError("Invalid Message version: '%s'" % str(version)) + return version + + +def _check_for_element_count(rawstr): + raw = re.split(r"\s+", rawstr, maxsplit=6) + if len(raw) < 5: + raise MessageError("Could node decode raw string: '%s ...'" + % str(rawstr[:36])) + + return raw + + +def _check_for_magic_word(rawstr): + """Check for the magick word.""" + try: + rawstr = rawstr.decode("utf-8") + except (AttributeError, UnicodeEncodeError): + pass + except (UnicodeDecodeError): + try: + rawstr = rawstr.decode("iso-8859-1") + except (UnicodeDecodeError): + rawstr = rawstr.decode("utf-8", "ignore") + if not rawstr.startswith(_MAGICK): + raise MessageError("This is not a '%s' message (wrong magick word)" + % _MAGICK) + return rawstr[len(_MAGICK):] + + def datetime_encoder(obj): """Encode datetimes into iso format.""" try: @@ -317,15 +331,15 @@ def _encode(msg, head=False, binary=False): if not head and msg.data: if not binary and isinstance(msg.data, str): - return (rawstr + ' ' + - 'text/ascii' + ' ' + msg.data) + return (rawstr + " " + + "text/ascii" + " " + msg.data) elif not binary: - return (rawstr + ' ' + - 'application/json' + ' ' + + return (rawstr + " " + + "application/json" + " " + json.dumps(msg.data, default=datetime_encoder)) else: - return (rawstr + ' ' + - 'binary/octet-stream' + ' ' + msg.data) + return (rawstr + " " + + "binary/octet-stream" + " " + msg.data) return rawstr diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 32103c8..4990c36 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -23,58 +23,34 @@ """Message broadcast module.""" -import time -import threading -import logging import errno +import logging +import threading -from posttroll import message -from posttroll.bbmcast import MulticastSender, MC_GROUP -from posttroll import get_context -from zmq import REQ, LINGER +from posttroll import config, message +from posttroll.bbmcast import MulticastSender, get_configured_broadcast_port -__all__ = ('MessageBroadcaster', 'AddressBroadcaster', 'sendaddress') +__all__ = ("MessageBroadcaster", "AddressBroadcaster", "sendaddress") LOGGER = logging.getLogger(__name__) -broadcast_port = 21200 - -class DesignatedReceiversSender(object): +class DesignatedReceiversSender: """Sends message to multiple *receivers* on *port*.""" - def __init__(self, default_port, receivers): """Set settings.""" - self.default_port = default_port - self.receivers = receivers + backend = config.get("backend", "unsecure_zmq") + if backend == "unsecure_zmq": + from posttroll.backends.zmq.message_broadcaster import ZMQDesignatedReceiversSender + self._sender = ZMQDesignatedReceiversSender(default_port, receivers) def __call__(self, data): """Send messages from all receivers.""" - for receiver in self.receivers: - self._send_to_address(receiver, data) - - def _send_to_address(self, address, data, timeout=10): - """Send data to *address* and *port* without verification of response.""" - # Socket to talk to server - socket = get_context().socket(REQ) - try: - socket.setsockopt(LINGER, timeout * 1000) - if address.find(":") == -1: - socket.connect("tcp://%s:%d" % (address, self.default_port)) - else: - socket.connect("tcp://%s" % address) - socket.send_string(data) - message = socket.recv_string() - if message != "ok": - LOGGER.warn("invalid acknowledge received: %s" % message) - - finally: - socket.close() + return self._sender(data) def close(self): """Close the sender.""" - pass - + return self._sender.close() # ---------------------------------------------------------------------------- # @@ -83,51 +59,48 @@ def close(self): # ---------------------------------------------------------------------------- -class MessageBroadcaster(object): +class MessageBroadcaster: """Class to broadcast stuff. If *interval* is 0 or negative, no broadcasting is done. """ def __init__(self, msg, port, interval, designated_receivers=None): - """Initialize message broadcaster.""" + """Set up the message broadcaster.""" if designated_receivers: self._sender = DesignatedReceiversSender(port, designated_receivers) else: - # mcgroup = None or '' is broadcast - # mcgroup = MC_GROUP is default multicast group - self._sender = MulticastSender(port, mcgroup=MC_GROUP) + self._sender = MulticastSender(port) self._interval = interval self._message = msg - self._do_run = False - self._is_running = False + self._shutdown_event = threading.Event() self._thread = threading.Thread(target=self._run) def start(self): """Start the broadcasting.""" if self._interval > 0: - if not self._is_running: - self._do_run = True + if not self._thread.is_alive(): self._thread.start() return self def is_running(self): """Are we running.""" - return self._is_running + return self._thread.is_alive() def stop(self): """Stop the broadcasting.""" - self._do_run = False + self._shutdown_event.set() + self._sender.close() + self._thread.join() return self def _run(self): """Broadcasts forever.""" - self._is_running = True network_fail = False try: - while self._do_run: + while not self._shutdown_event.is_set(): try: if network_fail is True: LOGGER.info("Network connection re-established!") @@ -141,9 +114,8 @@ def _run(self): network_fail = True else: raise - time.sleep(self._interval) + self._shutdown_event.wait(self._interval) finally: - self._is_running = False self._sender.close() @@ -155,13 +127,13 @@ def _run(self): class AddressBroadcaster(MessageBroadcaster): - """Class to broadcast stuff.""" + """Class to broadcast addresses.""" def __init__(self, name, address, interval, nameservers): - """Initialize address broadcasting.""" + """Set up the Address broadcaster.""" msg = message.Message("/address/%s" % name, "info", {"URI": "%s:%d" % address}).encode() - MessageBroadcaster.__init__(self, msg, broadcast_port, interval, + MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) @@ -184,7 +156,7 @@ def __init__(self, name, address, data_type, interval=2, nameservers=None): msg = message.Message("/address/%s" % name, "info", {"URI": address, "service": data_type}).encode() - MessageBroadcaster.__init__(self, msg, broadcast_port, interval, + MessageBroadcaster.__init__(self, msg, get_configured_broadcast_port(), interval, nameservers) diff --git a/posttroll/ns.py b/posttroll/ns.py index 1ecff05..1bf05b4 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -23,35 +23,36 @@ """Manage other's subscriptions. -Default port is 5557, if $NAMESERVER_PORT is not defined. +Default port is 5557, if $POSTTROLL_NAMESERVER_PORT is not defined. """ import datetime as dt import logging import os import time +import warnings -from threading import Lock -# pylint: disable=E0611 -from zmq import LINGER, NOBLOCK, POLLIN, REP, REQ, Poller - -from posttroll import get_context +from posttroll import config from posttroll.address_receiver import AddressReceiver from posttroll.message import Message # pylint: enable=E0611 -PORT = int(os.environ.get("NAMESERVER_PORT", 5557)) +DEFAULT_NAMESERVER_PORT = 5557 logger = logging.getLogger(__name__) -nslock = Lock() - -class TimeoutError(BaseException): - """A timeout.""" +def get_configured_nameserver_port(): + """Get the configured nameserver port.""" + try: + port = int(os.environ["NAMESERVER_PORT"]) + warnings.warn("NAMESERVER_PORT is pending deprecation, please use POSTTROLL_NAMESERVER_PORT instead.", + PendingDeprecationWarning) + except KeyError: + port = DEFAULT_NAMESERVER_PORT + return config.get("nameserver_port", port) - pass # Client functions. @@ -79,34 +80,16 @@ def get_pub_addresses(names=None, timeout=10, nameserver="localhost"): def get_pub_address(name, timeout=10, nameserver="localhost"): """Get the address of the named publisher. - Kwargs: - - name: name of the publishers - - nameserver: nameserver address to query the publishers from (default: localhost). + Args: + name: name of the publishers + timeout: how long to wait for an address, in seconds. + nameserver: nameserver address to query the publishers from (default: localhost). """ - # Socket to talk to server - socket = get_context().socket(REQ) - try: - socket.setsockopt(LINGER, int(timeout * 1000)) - socket.connect("tcp://" + nameserver + ":" + str(PORT)) - logger.debug('Connecting to %s', - "tcp://" + nameserver + ":" + str(PORT)) - poller = Poller() - poller.register(socket, POLLIN) - - message = Message("/oper/ns", "request", {"service": name}) - socket.send_string(str(message)) - - # Get the reply. - sock = poller.poll(timeout=timeout * 1000) - if sock: - if sock[0][0] == socket: - message = Message.decode(socket.recv_string(NOBLOCK)) - return message.data - else: - raise TimeoutError("Didn't get an address after %d seconds." - % timeout) - finally: - socket.close() + if config["backend"] not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"Did not recognize backend: {config['backend']}") + from posttroll.backends.zmq.ns import zmq_get_pub_address + return zmq_get_pub_address(name, timeout, nameserver) + # Server part. @@ -130,6 +113,11 @@ def __init__(self, max_age=None, multicast_enabled=True, restrict_to_localhost=F self._max_age = max_age or dt.timedelta(minutes=10) self._multicast_enabled = multicast_enabled self._restrict_to_localhost = restrict_to_localhost + backend = config["backend"] + if backend not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"Did not recognize backend: {backend}") + from posttroll.backends.zmq.ns import ZMQNameServer + self._ns = ZMQNameServer() def run(self, *args): """Run the listener and answer to requests.""" @@ -139,37 +127,11 @@ def run(self, *args): multicast_enabled=self._multicast_enabled, restrict_to_localhost=self._restrict_to_localhost) arec.start() - port = PORT - try: - with nslock: - self.listener = get_context().socket(REP) - self.listener.bind("tcp://*:" + str(port)) - logger.debug('Listening on port %s', str(port)) - poller = Poller() - poller.register(self.listener, POLLIN) - while self.loop: - with nslock: - socks = dict(poller.poll(1000)) - if socks: - if socks.get(self.listener) == POLLIN: - msg = self.listener.recv_string() - else: - continue - logger.debug("Replying to request: " + str(msg)) - msg = Message.decode(msg) - active_address = get_active_address(msg.data["service"], arec) - self.listener.send_unicode(str(active_address)) - except KeyboardInterrupt: - # Needed to stop the nameserver. - pass + return self._ns.run(arec) finally: arec.stop() - self.stop() def stop(self): - """Stop the name server.""" - self.listener.setsockopt(LINGER, 1) - self.loop = False - with nslock: - self.listener.close() + """Stop the nameserver.""" + return self._ns.stop() diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 4a9bdec..dee85cc 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -26,15 +26,10 @@ import datetime as dt import logging import socket -from threading import Lock -from urllib.parse import urlsplit, urlunsplit -import zmq -from posttroll import get_context -from posttroll import _set_tcp_keepalive +from posttroll import config from posttroll.message import Message from posttroll.message_broadcaster import sendaddressservice -from posttroll import config LOGGER = logging.getLogger(__name__) @@ -92,56 +87,31 @@ class Publisher: def __init__(self, address, name="", min_port=None, max_port=None): """Bind the publisher class to a port.""" - self.name = name - self.destination = address - self.publish_socket = None # Limit port range or use the defaults when no port is defined # by the user - self.min_port = min_port or int(config.get('pub_min_port', 49152)) - self.max_port = max_port or int(config.get('pub_max_port', 65536)) - self.port_number = None - + min_port = min_port or int(config.get("pub_min_port", 49152)) + max_port = max_port or int(config.get("pub_max_port", 65536)) # Initialize no heartbeat self._heartbeat = None - self._pub_lock = Lock() + + backend = config.get("backend", "unsecure_zmq") + if backend not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") + from posttroll.backends.zmq.publisher import ZMQPublisher + self._publisher = ZMQPublisher(address, name=name, min_port=min_port, max_port=max_port) def start(self): """Start the publisher.""" - self.publish_socket = get_context().socket(zmq.PUB) - _set_tcp_keepalive(self.publish_socket) - - self.bind() - LOGGER.info("publisher started on port %s", str(self.port_number)) + self._publisher.start() return self - def bind(self): - """Bind the port.""" - # Check for port 0 (random port) - u__ = urlsplit(self.destination) - port = u__.port - if port == 0: - dest = urlunsplit((u__.scheme, u__.hostname, - u__.path, u__.query, u__.fragment)) - self.port_number = self.publish_socket.bind_to_random_port( - dest, - min_port=self.min_port, - max_port=self.max_port) - netloc = u__.hostname + ":" + str(self.port_number) - self.destination = urlunsplit((u__.scheme, netloc, u__.path, - u__.query, u__.fragment)) - else: - self.publish_socket.bind(self.destination) - self.port_number = port - def send(self, msg): """Send the given message.""" - with self._pub_lock: - self.publish_socket.send_string(msg) + return self._publisher.send(msg) def stop(self): """Stop the publisher.""" - self.publish_socket.setsockopt(zmq.LINGER, 1) - self.publish_socket.close() + return self._publisher.stop() def close(self): """Alias for stop.""" @@ -153,13 +123,23 @@ def heartbeat(self, min_interval=0): self._heartbeat = _PublisherHeartbeat(self) self._heartbeat(min_interval) + @property + def name(self): + """Get the name of the publisher.""" + return self._publisher.name + + @property + def port_number(self): + """Get the port number from the actual publisher.""" + return self._publisher.port_number + class _PublisherHeartbeat: """Publisher for heartbeat.""" def __init__(self, publisher): self.publisher = publisher - self.subject = '/heartbeat/' + publisher.name + self.subject = "/heartbeat/" + publisher.name self.lastbeat = dt.datetime(1900, 1, 1) def __call__(self, min_interval=0): @@ -211,17 +191,18 @@ def __init__(self, name, port=0, aliases=None, broadcast_interval=2, def start(self): """Start the publisher.""" - pub_addr = _get_publish_address(self._port) - self._publisher = self._publisher_class(pub_addr, self._name, + pub_addr = _create_tcp_publish_address(self._port) + self._publisher = self._publisher_class(pub_addr, name=self._name, min_port=self.min_port, - max_port=self.max_port).start() - LOGGER.debug("entering publish %s", str(self._publisher.destination)) - addr = _get_publish_address(self._publisher.port_number, str(get_own_ip())) + max_port=self.max_port) + self._publisher.start() + addr = _create_tcp_publish_address(self._publisher.port_number, str(get_own_ip())) self._broadcaster = sendaddressservice(self._name, addr, self._aliases, self._broadcast_interval, - self._nameservers).start() - return self._publisher + self._nameservers) + self._broadcaster.start() + return self def send(self, msg): """Send a *msg*.""" @@ -241,12 +222,17 @@ def close(self): """Alias for stop.""" self.stop() + @property + def port_number(self): + """Get the port number.""" + return self._publisher.port_number + def heartbeat(self, min_interval=0): - """Publish a heartbeat.""" - self._publisher.heartbeat(min_interval=min_interval) + """Send a heartbeat ... but only if *min_interval* seconds has passed since last beat.""" + self._publisher.heartbeat(min_interval) -def _get_publish_address(port, ip_address="*"): +def _create_tcp_publish_address(port, ip_address="*"): return "tcp://" + ip_address + ":" + str(port) @@ -255,7 +241,7 @@ class Publish: See :class:`Publisher` and :class:`NoisyPublisher` for more information on the arguments. - The publisher is selected based on the arguments, see :function:`create_publisher_from_dict_config` for + The publisher is selected based on the arguments, see :func:`create_publisher_from_dict_config` for information how the selection is done. Example on how to use the :class:`Publish` context:: @@ -281,9 +267,9 @@ class Publish: def __init__(self, name, port=0, aliases=None, broadcast_interval=2, nameservers=None, min_port=None, max_port=None): """Initialize the class.""" - settings = {'name': name, 'port': port, 'min_port': min_port, 'max_port': max_port, - 'aliases': aliases, 'broadcast_interval': broadcast_interval, - 'nameservers': nameservers} + settings = {"name": name, "port": port, "min_port": min_port, "max_port": max_port, + "aliases": aliases, "broadcast_interval": broadcast_interval, + "nameservers": nameservers} self.publisher = create_publisher_from_dict_config(settings) def __enter__(self): @@ -315,18 +301,21 @@ def create_publisher_from_dict_config(settings): described in the docstrings of the respective classes, namely :class:`~posttroll.publisher.Publisher` and :class:`~posttroll.publisher.NoisyPublisher`. """ - if settings.get('port') and settings.get('nameservers') is False: + if (settings.get("port") or settings.get("address")) and settings.get("nameservers") is False: return _get_publisher_instance(settings) return _get_noisypublisher_instance(settings) def _get_publisher_instance(settings): - publisher_address = _get_publish_address(settings['port']) - publisher_name = settings.get("name", "") - min_port = settings.get("min_port") - max_port = settings.get("max_port") - - return Publisher(publisher_address, name=publisher_name, min_port=min_port, max_port=max_port) + settings = settings.copy() + publisher_address = settings.pop("address", None) + port = settings.pop("port", None) + if not publisher_address: + publisher_address = _create_tcp_publish_address(port) + settings.pop("nameservers", None) + settings.pop("aliases", None) + settings.pop("broadcast_interval", None) + return Publisher(publisher_address, **settings) def _get_noisypublisher_instance(settings): diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index d966548..fc3a8c1 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -27,17 +27,10 @@ import datetime as dt import logging import time -from time import sleep -from threading import Lock -from urllib.parse import urlsplit -# pylint: disable=E0611 -from zmq import LINGER, NOBLOCK, POLLIN, PULL, SUB, SUBSCRIBE, Poller, ZMQError - -# pylint: enable=E0611 -from posttroll import get_context -from posttroll import _set_tcp_keepalive -from posttroll.message import _MAGICK, Message +from posttroll import config +from posttroll.address_receiver import get_configured_address_port +from posttroll.message import _MAGICK from posttroll.ns import get_pub_address LOGGER = logging.getLogger(__name__) @@ -67,81 +60,31 @@ class Subscriber: """ - def __init__(self, addresses, topics='', message_filter=None, translate=False): + def __init__(self, addresses, topics="", message_filter=None, translate=False): """Initialize the subscriber.""" - self._topics = self._magickfy_topics(topics) - self._filter = message_filter - self._translate = translate - - self.sub_addr = {} - self.addr_sub = {} - - self._hooks = [] - self._hooks_cb = {} - - self.poller = Poller() - self._lock = Lock() - - self.update(addresses) + topics = self._magickfy_topics(topics) + backend = config.get("backend", "unsecure_zmq") + if backend not in ["unsecure_zmq", "secure_zmq"]: + raise NotImplementedError(f"No support for backend {backend} implemented (yet?).") - self._loop = None + from posttroll.backends.zmq.subscriber import ZMQSubscriber + self._subscriber = ZMQSubscriber(addresses, topics=topics, + message_filter=message_filter, translate=translate) def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. - It topics is None we will subscibe to already specified topics. + It topics is None we will subscribe to already specified topics. """ - with self._lock: - if address in self.addresses: - return - - topics = self._magickfy_topics(topics) or self._topics - LOGGER.info("Subscriber adding address %s with topics %s", - str(address), str(topics)) - subscriber = self._add_sub_socket(address, topics) - self.sub_addr[subscriber] = address - self.addr_sub[address] = subscriber - - def _add_sub_socket(self, address, topics): - subscriber = get_context().socket(SUB) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) - - if self.poller: - self.poller.register(subscriber, POLLIN) - return subscriber + return self._subscriber.add(address, self._magickfy_topics(topics)) def remove(self, address): """Remove *address* from the subscribing list for *topics*.""" - with self._lock: - try: - subscriber = self.addr_sub[address] - except KeyError: - return - LOGGER.info("Subscriber removing address %s", str(address)) - del self.addr_sub[address] - del self.sub_addr[subscriber] - self._remove_sub_socket(subscriber) - - def _remove_sub_socket(self, subscriber): - if self.poller: - self.poller.unregister(subscriber) - subscriber.close() + return self._subscriber.remove(address) def update(self, addresses): """Update with a set of addresses.""" - if isinstance(addresses, str): - addresses = [addresses, ] - current_addresses, new_addresses = set(self.addresses), set(addresses) - addresses_to_remove = current_addresses.difference(new_addresses) - addresses_to_add = new_addresses.difference(current_addresses) - for addr in addresses_to_remove: - self.remove(addr) - for addr in addresses_to_add: - self.add(addr) - return bool(addresses_to_remove or addresses_to_add) + return self._subscriber.update(addresses) def add_hook_sub(self, address, topics, callback): """Specify a SUB *callback* in the same stream (thread) as the main receive loop. @@ -152,11 +95,7 @@ def add_hook_sub(self, address, topics, callback): Good for operations, which is required to be done in the same thread as the main recieve loop (e.q operations on the underlying sockets). """ - topics = self._magickfy_topics(topics) - LOGGER.info("Subscriber adding SUB hook %s for topics %s", - str(address), str(topics)) - socket = self._add_sub_socket(address, topics) - self._add_hook(socket, callback) + return self._subscriber.add_hook_sub(address, self._magickfy_topics(topics), callback) def add_hook_pull(self, address, callback): """Specify a PULL *callback* in the same stream (thread) as the main receive loop. @@ -164,84 +103,38 @@ def add_hook_pull(self, address, callback): The callback will be called with the received messages from the specified subscription. Good for pushed 'inproc' messages from another thread. """ - LOGGER.info("Subscriber adding PULL hook %s", str(address)) - socket = get_context().socket(PULL) - socket.connect(address) - if self.poller: - self.poller.register(socket, POLLIN) - self._add_hook(socket, callback) - - def _add_hook(self, socket, callback): - """Add a generic hook. The passed socket has to be "receive only".""" - self._hooks.append(socket) - self._hooks_cb[socket] = callback + return self._subscriber.add_hook_pull(address, callback) @property def addresses(self): """Get the addresses.""" - return self.sub_addr.values() + return self._subscriber.addresses @property def subscribers(self): """Get the subscribers.""" - return self.sub_addr.keys() + return self._subscriber.subscribers def recv(self, timeout=None): """Receive, optionally with *timeout* in seconds.""" - if timeout: - timeout *= 1000. - - for sub in list(self.subscribers) + self._hooks: - self.poller.register(sub, POLLIN) - self._loop = True - try: - while self._loop: - sleep(0) - try: - socks = dict(self.poller.poll(timeout=timeout)) - if socks: - for sub in self.subscribers: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - if not self._filter or self._filter(m__): - if self._translate: - url = urlsplit(self.sub_addr[sub]) - host = url[1].split(":")[0] - m__.sender = (m__.sender.split("@")[0] - + "@" + host) - yield m__ - - for sub in self._hooks: - if sub in socks and socks[sub] == POLLIN: - m__ = Message.decode(sub.recv_string(NOBLOCK)) - self._hooks_cb[sub](m__) - else: - # timeout - yield None - except ZMQError as err: - if self._loop: - LOGGER.exception("Receive failed: %s", str(err)) - finally: - for sub in list(self.subscribers) + self._hooks: - self.poller.unregister(sub) + return self._subscriber.recv(timeout) def __call__(self, **kwargs): """Handle calls with class instance.""" - return self.recv(**kwargs) + return self._subscriber(**kwargs) def stop(self): """Stop the subscriber.""" - self._loop = False + return self._subscriber.stop() def close(self): """Close the subscriber: stop it and close the local subscribers.""" - self.stop() - for sub in list(self.subscribers) + self._hooks: - try: - sub.setsockopt(LINGER, 1) - sub.close() - except ZMQError: - pass + return self._subscriber.close() + + @property + def running(self): + """Check if suscriber is running.""" + return self._subscriber.running @staticmethod def _magickfy_topics(topics): @@ -255,21 +148,13 @@ def _magickfy_topics(topics): ts_ = [] for t__ in topics: if not t__.startswith(_MAGICK): - if t__ and t__[0] == '/': + if t__ and t__[0] == "/": t__ = _MAGICK + t__ else: - t__ = _MAGICK + '/' + t__ + t__ = _MAGICK + "/" + t__ ts_.append(t__) return ts_ - def __del__(self): - """Clean up after the instance is deleted.""" - for sub in list(self.subscribers) + self._hooks: - try: - sub.close() - except Exception: # noqa: E722 - pass - class NSSubscriber: """Automatically subscribe to *services*. @@ -296,9 +181,9 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False, Default is to listen to all available services. """ - self._services = _to_array(services) - self._topics = _to_array(topics) - self._addresses = _to_array(addresses) + self._services = _to_list(services) + self._topics = _to_list(topics) + self._addresses = _to_list(addresses) self._timeout = timeout self._translate = translate @@ -313,7 +198,7 @@ def _get_addr_loop(service, timeout): """Try to get the address of *service* until for *timeout* seconds.""" then = dt.datetime.now() + dt.timedelta(seconds=timeout) while dt.datetime.now() < then: - addrs = get_pub_address(service, nameserver=self._nameserver) + addrs = get_pub_address(service, self._timeout, nameserver=self._nameserver) if addrs: return [addr["URI"] for addr in addrs] time.sleep(1) @@ -360,7 +245,7 @@ class Subscribe: See :class:`NSSubscriber` and :class:`Subscriber` for initialization parameters. - The subscriber is selected based on the arguments, see :function:`create_subscriber_from_dict_config` for + The subscriber is selected based on the arguments, see :func:`create_subscriber_from_dict_config` for information how the selection is done. Example:: @@ -379,14 +264,14 @@ def __init__(self, services="", topics=_MAGICK, addr_listener=False, message_filter=None): """Initialize the class.""" settings = { - 'services': services, - 'topics': topics, - 'message_filter': message_filter, - 'translate': translate, - 'addr_listener': addr_listener, - 'addresses': addresses, - 'timeout': timeout, - 'nameserver': nameserver, + "services": services, + "topics": topics, + "message_filter": message_filter, + "translate": translate, + "addr_listener": addr_listener, + "addresses": addresses, + "timeout": timeout, + "nameserver": nameserver, } self.subscriber = create_subscriber_from_dict_config(settings) @@ -399,7 +284,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.subscriber.stop() -def _to_array(obj): +def _to_list(obj): """Convert *obj* to list if not already one.""" if isinstance(obj, str): return [obj, ] @@ -417,16 +302,17 @@ def __init__(self, subscriber, services="", nameserver="localhost"): services = [services, ] self.services = services self.subscriber = subscriber - self.subscriber.add_hook_sub("tcp://" + nameserver + ":16543", + address_publish_port = get_configured_address_port() + self.subscriber.add_hook_sub("tcp://" + nameserver + ":" + str(address_publish_port), ["pytroll://address", ], self.handle_msg) def handle_msg(self, msg): """Handle the message *msg*.""" addr_ = msg.data["URI"] - status = msg.data.get('status', True) + status = msg.data.get("status", True) if status: - msg_services = msg.data.get('service') + msg_services = msg.data.get("service") for service in self.services: if not service or service in msg_services: LOGGER.debug("Adding address %s %s", str(addr_), @@ -455,28 +341,29 @@ def create_subscriber_from_dict_config(settings): :class:`~posttroll.subscriber.Subscriber` and :class:`~posttroll.subscriber.NSSubscriber`. """ - if settings.get('addresses') and settings.get('nameserver') is False: + if settings.get("addresses") and settings.get("nameserver") is False: return _get_subscriber_instance(settings) return _get_nssubscriber_instance(settings).start() def _get_subscriber_instance(settings): - addresses = settings['addresses'] - topics = settings.get('topics', '') - message_filter = settings.get('message_filter', None) - translate = settings.get('translate', False) + _ = settings.pop("nameserver", None) + _ = settings.pop("port", None) + _ = settings.pop("services", None) + _ = settings.pop("addr_listener", None), + _ = settings.pop("timeout", None) - return Subscriber(addresses, topics=topics, message_filter=message_filter, translate=translate) + return Subscriber(**settings) def _get_nssubscriber_instance(settings): - services = settings.get('services', '') - topics = settings.get('topics', _MAGICK) - addr_listener = settings.get('addr_listener', False) - addresses = settings.get('addresses', None) - timeout = settings.get('timeout', 10) - translate = settings.get('translate', False) - nameserver = settings.get('nameserver', 'localhost') or 'localhost' + services = settings.get("services", "") + topics = settings.get("topics", _MAGICK) + addr_listener = settings.get("addr_listener", False) + addresses = settings.get("addresses", None) + timeout = settings.get("timeout", 10) + translate = settings.get("translate", False) + nameserver = settings.get("nameserver", "localhost") or "localhost" return NSSubscriber( services=services, diff --git a/posttroll/testing.py b/posttroll/testing.py index d9cbc60..ceb79dc 100644 --- a/posttroll/testing.py +++ b/posttroll/testing.py @@ -10,7 +10,7 @@ def patched_subscriber_recv(messages): def interuptible_recv(self): """Yield message until the subscriber is closed.""" for msg in messages: - if self._loop is False: + if self.running is False: break yield msg @@ -20,7 +20,7 @@ def interuptible_recv(self): @contextmanager def patched_publisher(): - """Patch the Subscriber object to return given messages.""" + """Patch the Publisher object to return given messages.""" from unittest import mock published = [] diff --git a/posttroll/tests/__init__.py b/posttroll/tests/__init__.py index 88e1689..5e53b2b 100644 --- a/posttroll/tests/__init__.py +++ b/posttroll/tests/__init__.py @@ -20,25 +20,4 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Tests package. -""" - -from posttroll.tests import test_bbmcast, test_message, test_pubsub -import unittest - - -def suite(): - """The global test suite. - """ - mysuite = unittest.TestSuite() - # Test the documentation strings - # Use the unittests also - mysuite.addTests(test_bbmcast.suite()) - mysuite.addTests(test_message.suite()) - mysuite.addTests(test_pubsub.suite()) - - return mysuite - - -def load_tests(loader, tests, pattern): - return suite() +"""Tests package.""" diff --git a/posttroll/tests/test_bbmcast.py b/posttroll/tests/test_bbmcast.py index a7bea8d..b61b26c 100644 --- a/posttroll/tests/test_bbmcast.py +++ b/posttroll/tests/test_bbmcast.py @@ -20,100 +20,177 @@ # You should have received a copy of the GNU General Public License along with # pytroll. If not, see . -import unittest -import random -from socket import SOL_SOCKET, SO_BROADCAST, error - -from posttroll import bbmcast - - -class TestBB(unittest.TestCase): - - """Test class. - """ - - def test_mcast_sender(self): - """Unit test for mcast_sender. - """ - mcgroup = (str(random.randint(224, 239)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - socket, group = bbmcast.mcast_sender(mcgroup) - if mcgroup in ("0.0.0.0", "255.255.255.255"): - self.assertEqual(group, "") - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1) - else: - self.assertEqual(group, mcgroup) - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 0) - - socket.close() +"""Test multicasting and broadcasting.""" - mcgroup = "0.0.0.0" - socket, group = bbmcast.mcast_sender(mcgroup) - self.assertEqual(group, "") - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1) - socket.close() +import os +import random +from socket import SO_BROADCAST, SOL_SOCKET, error +from threading import Thread - mcgroup = "255.255.255.255" - socket, group = bbmcast.mcast_sender(mcgroup) - self.assertEqual(group, "") - self.assertEqual(socket.getsockopt(SOL_SOCKET, SO_BROADCAST), 1) - socket.close() +import pytest - mcgroup = (str(random.randint(0, 223)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup) - - mcgroup = (str(random.randint(240, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(IOError, bbmcast.mcast_sender, mcgroup) - - def test_mcast_receiver(self): - """Unit test for mcast_receiver. - """ - mcport = random.randint(1025, 65535) - mcgroup = "0.0.0.0" - socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - self.assertEqual(group, "") - socket.close() +from posttroll import bbmcast - mcgroup = "255.255.255.255" - socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - self.assertEqual(group, "") - socket.close() - # Valid multicast range is 224.0.0.0 to 239.255.255.255 - mcgroup = (str(random.randint(224, 239)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - socket, group = bbmcast.mcast_receiver(mcport, mcgroup) - self.assertEqual(group, mcgroup) +def test_mcast_sender_works_with_valid_addresses(): + """Unit test for mcast_sender.""" + mcgroup = (str(random.randint(224, 239)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + socket, group = bbmcast.mcast_sender(mcgroup) + if mcgroup in ("0.0.0.0", "255.255.255.255"): + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 + else: + assert group == mcgroup + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 0 + + socket.close() + + +def test_mcast_sender_uses_broadcast_for_0s(): + """Test mcast_sender uses broadcast for 0.0.0.0.""" + mcgroup = "0.0.0.0" + socket, group = bbmcast.mcast_sender(mcgroup) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 + socket.close() + + +def test_mcast_sender_uses_broadcast_for_255s(): + """Test mcast_sender uses broadcast for 255.255.255.255.""" + mcgroup = "255.255.255.255" + socket, group = bbmcast.mcast_sender(mcgroup) + assert group == "" + assert socket.getsockopt(SOL_SOCKET, SO_BROADCAST) == 1 + socket.close() + + +def test_mcast_sender_raises_for_invalit_adresses(): + """Test mcast_sender uses broadcast for 0.0.0.0.""" + mcgroup = (str(random.randint(0, 223)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(OSError, match="Invalid multicast address .*"): + bbmcast.mcast_sender(mcgroup) + + mcgroup = (str(random.randint(240, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(OSError, match="Invalid multicast address .*"): + bbmcast.mcast_sender(mcgroup) + + +def test_mcast_receiver_works_with_valid_addresses(): + """Unit test for mcast_receiver.""" + mcport = random.randint(1025, 65535) + mcgroup = "0.0.0.0" + socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + assert group == "" + socket.close() + + mcgroup = "255.255.255.255" + socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + assert group == "" + socket.close() + + # Valid multicast range is 224.0.0.0 to 239.255.255.255 + mcgroup = (str(random.randint(224, 239)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + socket, group = bbmcast.mcast_receiver(mcport, mcgroup) + assert group == mcgroup + socket.close() + + mcgroup = (str(random.randint(0, 223)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(error, match=".*Invalid argument.*"): + bbmcast.mcast_receiver(mcport, mcgroup) + + mcgroup = (str(random.randint(240, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255)) + "." + + str(random.randint(0, 255))) + with pytest.raises(error, match=".*Invalid argument.*"): + bbmcast.mcast_receiver(mcport, mcgroup) + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_multicast_roundtrip(reraise): + """Test sending and receiving a multicast message.""" + mcgroup = bbmcast.DEFAULT_MC_GROUP + mcport = 5555 + rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + rec_socket.settimeout(.1) + + message = "Ho Ho Ho!" + + def check_message(sock, message): + with reraise: + data, _ = sock.recvfrom(1024) + assert data.decode() == message + + snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup) + + thr = Thread(target=check_message, args=(rec_socket, message)) + thr.start() + + snd_socket.sendto(message.encode(), (mcgroup, mcport)) + + thr.join() + rec_socket.close() + snd_socket.close() + + +def test_broadcast_roundtrip(reraise): + """Test sending and receiving a broadcast message.""" + mcgroup = "0.0.0.0" + mcport = 5555 + rec_socket, _rec_group = bbmcast.mcast_receiver(mcport, mcgroup) + + message = "Ho Ho Ho!" + + def check_message(sock, message): + with reraise: + data, _ = sock.recvfrom(1024) + assert data.decode() == message + + snd_socket, _snd_group = bbmcast.mcast_sender(mcgroup) + + thr = Thread(target=check_message, args=(rec_socket, message)) + thr.start() + + snd_socket.sendto(message.encode(), (mcgroup, mcport)) + + thr.join() + rec_socket.close() + snd_socket.close() + + +def test_posttroll_mc_group_is_used(): + """Test that configured mc_group is used.""" + from posttroll import config + other_group = "226.0.0.13" + with config.set(mc_group=other_group): + socket, group = bbmcast.mcast_sender() socket.close() + assert group == "226.0.0.13" - mcgroup = (str(random.randint(0, 223)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup) - - mcgroup = (str(random.randint(240, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255)) + "." + - str(random.randint(0, 255))) - self.assertRaises(error, bbmcast.mcast_receiver, mcport, mcgroup) - - -def suite(): - """The suite for test_bbmcast. - """ - loader = unittest.TestLoader() - mysuite = unittest.TestSuite() - mysuite.addTest(loader.loadTestsFromTestCase(TestBB)) - return mysuite +def test_pytroll_mc_group_is_deprecated(monkeypatch): + """Test that PYTROLL_MC_GROUP is used but pending deprecation.""" + other_group = "226.0.0.13" + monkeypatch.setenv("PYTROLL_MC_GROUP", other_group) + with pytest.deprecated_call(): + socket, group = bbmcast.mcast_sender() + socket.close() + assert group == "226.0.0.13" diff --git a/posttroll/tests/test_message.py b/posttroll/tests/test_message.py index 3a4ee6e..af97236 100644 --- a/posttroll/tests/test_message.py +++ b/posttroll/tests/test_message.py @@ -23,26 +23,25 @@ """Test module for the message class.""" +import copy import os import sys import unittest -import copy from datetime import datetime -from posttroll.message import Message, _MAGICK - +from posttroll.message import _MAGICK, Message -HOME = os.path.dirname(__file__) or '.' -sys.path = [os.path.abspath(HOME + '/../..'), ] + sys.path +HOME = os.path.dirname(__file__) or "." +sys.path = [os.path.abspath(HOME + "/../.."), ] + sys.path -DATADIR = HOME + '/data' -SOME_METADATA = {'timestamp': datetime(2010, 12, 3, 16, 28, 39), - 'satellite': 'metop2', - 'uri': 'file://data/my/path/to/hrpt/files/myfile', - 'orbit': 1222, - 'format': 'hrpt', - 'afloat': 1.2345} +DATADIR = HOME + "/data" +SOME_METADATA = {"timestamp": datetime(2010, 12, 3, 16, 28, 39), + "satellite": "metop2", + "uri": "file://data/my/path/to/hrpt/files/myfile", + "orbit": 1222, + "format": "hrpt", + "afloat": 1.2345} class Test(unittest.TestCase): @@ -50,41 +49,32 @@ class Test(unittest.TestCase): def test_encode_decode(self): """Test the encoding/decoding of the message class.""" - msg1 = Message('/test/whatup/doc', 'info', data='not much to say') + msg1 = Message("/test/whatup/doc", "info", data="not much to say") - sender = '%s@%s' % (msg1.user, msg1.host) - self.assertTrue(sender == msg1.sender, - msg='Messaging, decoding user, host from sender failed') + sender = "%s@%s" % (msg1.user, msg1.host) + assert sender == msg1.sender, "Messaging, decoding user, host from sender failed" msg2 = Message.decode(msg1.encode()) - self.assertTrue(str(msg2) == str(msg1), - msg='Messaging, encoding, decoding failed') + assert str(msg2) == str(msg1), "Messaging, encoding, decoding failed" def test_decode(self): """Test the decoding of a message.""" rawstr = (_MAGICK + - r'/test/1/2/3 info ras@hawaii 2008-04-11T22:13:22.123000 v1.01' + + r"/test/1/2/3 info ras@hawaii 2008-04-11T22:13:22.123000 v1.01" + r' text/ascii "what' + r"'" + r's up doc"') msg = Message.decode(rawstr) - self.assertTrue(str(msg) == rawstr, - msg='Messaging, decoding of message failed') + assert str(msg) == rawstr, "Messaging, decoding of message failed" def test_encode(self): """Test the encoding of a message.""" - subject = '/test/whatup/doc' + subject = "/test/whatup/doc" atype = "info" - data = 'not much to say' + data = "not much to say" msg1 = Message(subject, atype, data=data) - sender = '%s@%s' % (msg1.user, msg1.host) - self.assertEqual(_MAGICK + - subject + " " + - atype + " " + - sender + " " + - str(msg1.time.isoformat()) + " " + - msg1.version + " " + - 'text/ascii' + " " + - data, - msg1.encode()) + sender = "%s@%s" % (msg1.user, msg1.host) + full_message = (_MAGICK + subject + " " + atype + " " + sender + " " + + str(msg1.time.isoformat()) + " " + msg1.version + " " + "text/ascii" + " " + data) + assert full_message == msg1.encode() def test_unicode(self): """Test handling of unicode.""" @@ -92,72 +82,66 @@ def test_unicode(self): msg = ('pytroll://PPS-monitorplot/3/norrköping/utv/polar/direct_readout/ file ' 'safusr.u@lxserv1096.smhi.se 2018-11-16T12:19:29.934025 v1.01 application/json' ' {"start_time": "2018-11-16T12:02:43.700000"}') - self.assertEqual(msg, str(Message(rawstr=msg))) + assert msg == str(Message(rawstr=msg)) except UnicodeDecodeError: - self.fail('Unexpected unicode decoding error') + self.fail("Unexpected unicode decoding error") try: msg = (u'pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN 2019-01-07T12:52:19.872171' r' v1.01 application/json {"station": "norrk\u00f6ping"}') try: - self.assertEqual(msg, str(Message(rawstr=msg)).decode('utf-8')) + assert msg == str(Message(rawstr=msg)).decode("utf-8") except AttributeError: - self.assertEqual(msg, str(Message(rawstr=msg))) + assert msg == str(Message(rawstr=msg)) except UnicodeDecodeError: - self.fail('Unexpected unicode decoding error') + self.fail("Unexpected unicode decoding error") def test_iso(self): """Test handling of iso-8859-1.""" msg = ('pytroll://oper/polar/direct_readout/norrköping pong sat@MERLIN ' '2019-01-07T12:52:19.872171 v1.01 application/json {"station": "norrköping"}') try: - iso_msg = msg.decode('utf-8').encode('iso-8859-1') + iso_msg = msg.decode("utf-8").encode("iso-8859-1") except AttributeError: - iso_msg = msg.encode('iso-8859-1') + iso_msg = msg.encode("iso-8859-1") try: Message(rawstr=iso_msg) except UnicodeDecodeError: - self.fail('Unexpected iso decoding error') + self.fail("Unexpected iso decoding error") def test_pickle(self): """Test pickling.""" import pickle - msg1 = Message('/test/whatup/doc', 'info', data='not much to say') + msg1 = Message("/test/whatup/doc", "info", data="not much to say") try: - fp_ = open("pickle.message", 'wb') + fp_ = open("pickle.message", "wb") pickle.dump(msg1, fp_) fp_.close() - fp_ = open("pickle.message", 'rb') + fp_ = open("pickle.message", "rb") msg2 = pickle.load(fp_) fp_.close() - self.assertTrue(str(msg1) == str(msg2), - msg='Messaging, pickle failed') + assert str(msg1) == str(msg2), "Messaging, pickle failed" finally: try: - os.remove('pickle.message') + os.remove("pickle.message") except OSError: pass def test_metadata(self): """Test metadata encoding/decoding.""" metadata = copy.copy(SOME_METADATA) - msg = Message.decode(Message('/sat/polar/smb/level1', 'file', + msg = Message.decode(Message("/sat/polar/smb/level1", "file", data=metadata).encode()) - self.assertTrue(msg.data == metadata, - msg='Messaging, metadata decoding / encoding failed') + assert msg.data == metadata, "Messaging, metadata decoding / encoding failed" def test_serialization(self): """Test json serialization.""" - compare_file = '/message_metadata.dumps' - try: - import json - except ImportError: - import simplejson as json - compare_file += '.simplejson' + compare_file = "/message_metadata.dumps" + import json metadata = copy.copy(SOME_METADATA) - metadata['timestamp'] = metadata['timestamp'].isoformat() + metadata["timestamp"] = metadata["timestamp"].isoformat() fp_ = open(DATADIR + compare_file) dump = fp_.read() fp_.close() @@ -165,21 +149,8 @@ def test_serialization(self): msg = json.loads(dump) for key, val in msg.items(): - self.assertEqual(val, metadata.get(key)) + assert val == metadata.get(key) msg = json.loads(local_dump) for key, val in msg.items(): - self.assertEqual(val, metadata.get(key)) - - -def suite(): - """Create the suite for test_message.""" - loader = unittest.TestLoader() - mysuite = unittest.TestSuite() - mysuite.addTest(loader.loadTestsFromTestCase(Test)) - - return mysuite - - -if __name__ == '__main__': - unittest.main() + assert val == metadata.get(key) diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py new file mode 100644 index 0000000..f4fe81d --- /dev/null +++ b/posttroll/tests/test_nameserver.py @@ -0,0 +1,257 @@ +"""Tests for communication involving the nameserver for service discovery.""" + +import os +import time +import unittest +from contextlib import contextmanager +from datetime import timedelta +from threading import Thread +from unittest import mock + +import pytest + +from posttroll import config +from posttroll.message import Message +from posttroll.ns import NameServer, get_pub_address +from posttroll.publisher import Publish +from posttroll.subscriber import Subscribe + + +def free_port(): + """Get a free port. + + From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0 + + Returns a factory that finds the next free port that is available on the OS + This is a bit of a hack, it does this by creating a new socket, and calling + bind with the 0 port. The operating system will assign a brand new port, + which we can find out using getsockname(). Once we have the new port + information we close the socket thereby returning it to the free pool. + This means it is technically possible for this function to return the same + port twice (for example if run in very quick succession), however operating + systems return a random port number in the default range (1024 - 65535), + and it is highly unlikely for two processes to get the same port number. + In other words, it is possible to flake, but incredibly unlikely. + """ + import socket + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", 0)) + portnum = s.getsockname()[1] + s.close() + + return portnum + + +@contextmanager +def create_nameserver_instance(max_age=3, multicast_enabled=True): + """Create a nameserver instance.""" + config.set(nameserver_port=free_port()) + config.set(address_publish_port=free_port()) + ns = NameServer(max_age=timedelta(seconds=max_age), multicast_enabled=multicast_enabled) + thr = Thread(target=ns.run) + thr.start() + + try: + yield + finally: + ns.stop() + thr.join() + + +class TestAddressReceiver(unittest.TestCase): + """Test the AddressReceiver.""" + + @mock.patch("posttroll.address_receiver.Message") + @mock.patch("posttroll.address_receiver.Publish") + @mock.patch("posttroll.address_receiver.MulticastReceiver") + def test_localhost_restriction(self, mcrec, pub, msg): + """Test address receiver restricted only to localhost.""" + mocked_publish_instance = mock.Mock() + pub.return_value.__enter__.return_value = mocked_publish_instance + mcr_instance = mock.Mock() + mcrec.return_value = mcr_instance + mcr_instance.return_value = "blabla", ("255.255.255.255", 12) + + from posttroll.address_receiver import AddressReceiver + adr = AddressReceiver(restrict_to_localhost=True) + adr.start() + time.sleep(3) + try: + msg.decode.assert_not_called() + mocked_publish_instance.send.assert_not_called() + finally: + adr.stop() + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_addresses(multicast_enabled): + """Test retrieving addresses.""" + from posttroll.ns import get_pub_addresses + from posttroll.publisher import Publish + + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(multicast_enabled=multicast_enabled): + with Publish(str("data_provider"), 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1): + time.sleep(.3) + res = get_pub_addresses(["this_data"], timeout=.5) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + res = get_pub_addresses([str("data_provider")]) + assert len(res) == 1 + expected = {u"status": True, + u"service": [u"data_provider", u"this_data"], + u"name": u"address"} + for key, val in expected.items(): + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_ctx(multicast_enabled): + """Test publish and subscribe.""" + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(multicast_enabled=multicast_enabled): + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers, broadcast_interval=0.1) as pub: + with Subscribe("this_data", "counter") as sub: + for counter in range(5): + message = Message("/counter", "info", str(counter)) + pub.send(str(message)) + time.sleep(.1) + msg = next(sub.recv(.2)) + if msg is not None: + assert str(msg) == str(message) + tested = True + assert tested + + +@pytest.mark.parametrize( + "multicast_enabled", + [True, False] +) +def test_pub_sub_add_rm(multicast_enabled): + """Test adding and removing publishers.""" + if multicast_enabled: + if os.getenv("DISABLED_MULTICAST"): + pytest.skip("Multicast tests disabled.") + nameservers = None + else: + nameservers = ["localhost"] + + max_age = 0.5 + with config.set(broadcast_port=free_port()): + with create_nameserver_instance(max_age=max_age, multicast_enabled=multicast_enabled): + with Subscribe("this_data", "counter", addr_listener=True, timeout=.2) as sub: + assert len(sub.addresses) == 0 + with Publish("data_provider", 0, ["this_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) + assert len(sub.addresses) == 1 + time.sleep(max_age * 4) + for msg in sub.recv(.1): + if msg is None: + break + time.sleep(0.3) + assert len(sub.addresses) == 0 + with Publish("data_provider_2", 0, ["another_data"], nameservers=nameservers): + time.sleep(.1) + next(sub.recv(.1)) + assert len(sub.addresses) == 0 + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_listener_container(): + """Test listener container.""" + from posttroll.listener import ListenerContainer + from posttroll.message import Message + from posttroll.publisher import NoisyPublisher + + with create_nameserver_instance(): + pub = NoisyPublisher("test", broadcast_interval=0.1) + pub.start() + sub = ListenerContainer(topics=["/counter"]) + time.sleep(.1) + for counter in range(5): + tested = False + msg_out = Message("/counter", "info", str(counter)) + pub.send(str(msg_out)) + + msg_in = sub.output_queue.get(True, 1) + if msg_in is not None: + assert str(msg_in) == str(msg_out) + tested = True + assert tested + pub.stop() + sub.stop() + + +@pytest.mark.skipif( + os.getenv("DISABLED_MULTICAST"), + reason="Multicast tests disabled.", +) +def test_noisypublisher_heartbeat(): + """Test that the heartbeat in the NoisyPublisher works.""" + from posttroll.publisher import NoisyPublisher + from posttroll.subscriber import Subscribe + + min_interval = 10 + + try: + with config.set(address_publish_port=free_port(), nameserver_port=free_port()): + ns_ = NameServer() + thr = Thread(target=ns_.run) + thr.start() + + pub = NoisyPublisher("test") + pub.start() + time.sleep(0.2) + + with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: + time.sleep(0.2) + pub.heartbeat(min_interval=min_interval) + msg = next(sub.recv(1)) + assert msg.type == "beat" + assert msg.data == {"min_interval": min_interval} + finally: + pub.stop() + ns_.stop() + thr.join() + + +def test_switch_backend_for_nameserver(): + """Test switching backend for nameserver.""" + with config.set(backend="spurious_backend"): + with pytest.raises(NotImplementedError): + NameServer() + with pytest.raises(NotImplementedError): + get_pub_address("some_name") diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index a593bb2..c152bf8 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -23,200 +23,49 @@ """Test the publishing and subscribing facilities.""" -import unittest -from unittest import mock -from datetime import timedelta -from threading import Thread, Lock import time +import unittest from contextlib import contextmanager +from threading import Lock +from unittest import mock -import posttroll import pytest from donfig import Config -test_lock = Lock() - - -class TestNS(unittest.TestCase): - """Test the nameserver.""" - - def setUp(self): - """Set up the testing class.""" - from posttroll.ns import NameServer - test_lock.acquire() - self.ns = NameServer(max_age=timedelta(seconds=3)) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() - - def test_pub_addresses(self): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish - - with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): - time.sleep(.3) - res = get_pub_addresses(["this_data"], timeout=.5) - assert len(res) == 1 - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - res = get_pub_addresses([str("data_provider")]) - assert len(res) == 1 - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - assert res[0][key] == val - assert "receive_time" in res[0] - assert "URI" in res[0] - - def test_pub_sub_ctx(self): - """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - with Publish("data_provider", 0, ["this_data"]) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(1) - msg = next(sub.recv(2)) - if msg is not None: - assert str(msg) == str(message) - tested = True - sub.close() - assert tested +import posttroll +from posttroll import config +from posttroll.message import Message +from posttroll.publisher import Publish, Publisher, create_publisher_from_dict_config +from posttroll.subscriber import Subscribe, Subscriber - def test_pub_sub_add_rm(self): - """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - time.sleep(4) - with Subscribe("this_data", "counter", True) as sub: - assert len(sub.sub_addr) == 0 - with Publish("data_provider", 0, ["this_data"]): - time.sleep(4) - next(sub.recv(2)) - assert len(sub.sub_addr) == 1 - time.sleep(3) - for msg in sub.recv(2): - if msg is None: - break - time.sleep(3) - assert len(sub.sub_addr) == 0 - with Publish("data_provider_2", 0, ["another_data"]): - time.sleep(4) - next(sub.recv(2)) - assert len(sub.sub_addr) == 0 - sub.close() - - -class TestNSWithoutMulticasting(unittest.TestCase): - """Test the nameserver.""" +test_lock = Lock() - def setUp(self): - """Set up the testing class.""" - from posttroll.ns import NameServer - test_lock.acquire() - self.nameservers = ['localhost'] - self.ns = NameServer(max_age=timedelta(seconds=3), - multicast_enabled=False) - self.thr = Thread(target=self.ns.run) - self.thr.start() - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() +def free_port(): + """Get a free port. - def test_pub_addresses(self): - """Test retrieving addresses.""" - from posttroll.ns import get_pub_addresses - from posttroll.publisher import Publish + From https://gist.github.com/bertjwregeer/0be94ced48383a42e70c3d9fff1f4ad0 - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): - time.sleep(3) - res = get_pub_addresses(["this_data"]) - self.assertEqual(len(res), 1) - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) - res = get_pub_addresses(["data_provider"]) - self.assertEqual(len(res), 1) - expected = {u'status': True, - u'service': [u'data_provider', u'this_data'], - u'name': u'address'} - for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) - - def test_pub_sub_ctx(self): - """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe + Returns a factory that finds the next free port that is available on the OS + This is a bit of a hack, it does this by creating a new socket, and calling + bind with the 0 port. The operating system will assign a brand new port, + which we can find out using getsockname(). Once we have the new port + information we close the socket thereby returning it to the free pool. + This means it is technically possible for this function to return the same + port twice (for example if run in very quick succession), however operating + systems return a random port number in the default range (1024 - 65535), + and it is highly unlikely for two processes to get the same port number. + In other words, it is possible to flake, but incredibly unlikely. + """ + import socket - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers) as pub: - with Subscribe("this_data", "counter") as sub: - for counter in range(5): - message = Message("/counter", "info", str(counter)) - pub.send(str(message)) - time.sleep(1) - msg = next(sub.recv(2)) - if msg is not None: - self.assertEqual(str(msg), str(message)) - tested = True - sub.close() - self.assertTrue(tested) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("0.0.0.0", 0)) + portnum = s.getsockname()[1] + s.close() - def test_pub_sub_add_rm(self): - """Test adding and removing publishers.""" - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - - time.sleep(4) - with Subscribe("this_data", "counter", True) as sub: - self.assertEqual(len(sub.sub_addr), 0) - with Publish("data_provider", 0, ["this_data"], - nameservers=self.nameservers): - time.sleep(4) - next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 1) - time.sleep(3) - for msg in sub.recv(2): - if msg is None: - break - - time.sleep(3) - self.assertEqual(len(sub.sub_addr), 0) - with Publish("data_provider_2", 0, ["another_data"], - nameservers=self.nameservers): - time.sleep(4) - next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 0) + return portnum class TestPubSub(unittest.TestCase): @@ -233,27 +82,24 @@ def tearDown(self): def test_pub_address_timeout(self): """Test timeout in offline nameserver.""" from posttroll.ns import get_pub_address - from posttroll.ns import TimeoutError - - self.assertRaises(TimeoutError, - get_pub_address, ["this_data", 0.5]) + with pytest.raises(TimeoutError): + get_pub_address("this_data", 0.05) def test_pub_suber(self): """Test publisher and subscriber.""" - from posttroll.message import Message - from posttroll.publisher import Publisher from posttroll.publisher import get_own_ip - from posttroll.subscriber import Subscriber - pub_address = "tcp://" + str(get_own_ip()) + ":0" pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) - sub = Subscriber([addr], '/counter') + sub = Subscriber([addr], topics="/counter") + # wait a bit before sending the first message so that the subscriber is ready + time.sleep(.002) + tested = False for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) - time.sleep(1) + time.sleep(.05) msg = next(sub.recv(2)) if msg is not None: @@ -264,22 +110,21 @@ def test_pub_suber(self): def test_pub_sub_ctx_no_nameserver(self): """Test publish and subscribe.""" - from posttroll.message import Message - from posttroll.publisher import Publish - from posttroll.subscriber import Subscribe - with Publish("data_provider", 40000, nameservers=False) as pub: with Subscribe(topics="counter", nameserver=False, addresses=["tcp://127.0.0.1:40000"]) as sub: + assert isinstance(sub, Subscriber) + # wait a bit before sending the first message so that the subscriber is ready + time.sleep(.002) for counter in range(5): message = Message("/counter", "info", str(counter)) pub.send(str(message)) - time.sleep(1) + time.sleep(.05) msg = next(sub.recv(2)) if msg is not None: - self.assertEqual(str(msg), str(message)) + assert str(msg) == str(message) tested = True sub.close() - self.assertTrue(tested) + assert tested class TestPub(unittest.TestCase): @@ -293,47 +138,50 @@ def tearDown(self): """Clean up after the tests have run.""" test_lock.release() - def test_pub_unicode(self): + def test_pub_supports_unicode(self): """Test publishing messages in Unicode.""" from posttroll.message import Message from posttroll.publisher import Publish - message = Message("/pџтяöll", "info", 'hej') - with Publish("a_service", 9000) as pub: + message = Message("/pџтяöll", "info", "hej") + with Publish("a_service", 0) as pub: try: pub.send(message.encode()) except UnicodeDecodeError: self.fail("Sending raised UnicodeDecodeError unexpectedly!") - def test_pub_minmax_port(self): - """Test user defined port range.""" + def test_pub_minmax_port_from_config(self): + """Test config defined port range.""" # Using environment variables to set port range # Try over a range of ports just in case the single port is reserved for port in range(40000, 50000): # Set the port range to config with posttroll.config.set(pub_min_port=str(port), pub_max_port=str(port + 1)): - res = _get_port(min_port=None, max_port=None) + res = _get_port_from_publish_instance(min_port=None, max_port=None) if res is False: # The port wasn't free, try another one continue # Port was selected, make sure it's within the "range" of one - self.assertEqual(res, port) + assert res == port break + def test_pub_minmax_port_from_instanciation(self): + """Test port range defined at instanciation.""" # Using range of ports defined at instantation time, this # should override environment variables for port in range(50000, 60000): - res = _get_port(min_port=port, max_port=port+1) + res = _get_port_from_publish_instance(min_port=port, max_port=port + 1) if res is False: # The port wasn't free, try again continue # Port was selected, make sure it's within the "range" of one - self.assertEqual(res, port) + assert res == port break -def _get_port(min_port=None, max_port=None): +def _get_port_from_publish_instance(min_port=None, max_port=None): from zmq.error import ZMQError + from posttroll.publisher import Publish try: @@ -347,48 +195,6 @@ def _get_port(min_port=None, max_port=None): return False -class TestListenerContainer(unittest.TestCase): - """Testing listener container.""" - - def setUp(self): - """Set up the testing class.""" - from posttroll.ns import NameServer - test_lock.acquire() - self.ns = NameServer(max_age=timedelta(seconds=3)) - self.thr = Thread(target=self.ns.run) - self.thr.start() - - def tearDown(self): - """Clean up after the tests have run.""" - self.ns.stop() - self.thr.join() - time.sleep(2) - test_lock.release() - - def test_listener_container(self): - """Test listener container.""" - from posttroll.message import Message - from posttroll.publisher import NoisyPublisher - from posttroll.listener import ListenerContainer - - pub = NoisyPublisher("test") - pub.start() - sub = ListenerContainer(topics=["/counter"]) - time.sleep(2) - for counter in range(5): - tested = False - msg_out = Message("/counter", "info", str(counter)) - pub.send(str(msg_out)) - - msg_in = sub.output_queue.get(True, 1) - if msg_in is not None: - self.assertEqual(str(msg_in), str(msg_out)) - tested = True - self.assertTrue(tested) - pub.stop() - sub.stop() - - class TestListenerContainerNoNameserver(unittest.TestCase): """Testing listener container without nameserver.""" @@ -402,9 +208,8 @@ def tearDown(self): def test_listener_container(self): """Test listener container.""" - from posttroll.message import Message - from posttroll.publisher import Publisher from posttroll.listener import ListenerContainer + from posttroll.message import Message pub_addr = "tcp://127.0.0.1:55000" pub = Publisher(pub_addr, name="test") @@ -419,144 +224,126 @@ def test_listener_container(self): msg_in = sub.output_queue.get(True, 1) if msg_in is not None: - self.assertEqual(str(msg_in), str(msg_out)) + assert str(msg_in) == str(msg_out) tested = True - self.assertTrue(tested) + assert tested pub.stop() sub.stop() -class TestAddressReceiver(unittest.TestCase): - """Test the AddressReceiver.""" +# Test create_publisher_from_config - @mock.patch("posttroll.address_receiver.Message") - @mock.patch("posttroll.address_receiver.Publish") - @mock.patch("posttroll.address_receiver.MulticastReceiver") - def test_localhost_restriction(self, mcrec, pub, msg): - """Test address receiver restricted only to localhost.""" - mcr_instance = mock.Mock() - mcrec.return_value = mcr_instance - mcr_instance.return_value = 'blabla', ('255.255.255.255', 12) - from posttroll.address_receiver import AddressReceiver - adr = AddressReceiver(restrict_to_localhost=True) - adr.start() - time.sleep(3) - msg.decode.assert_not_called() - adr.stop() +def test_publisher_with_invalid_arguments_crashes(): + """Test that only valid arguments are passed to Publisher.""" + settings = {"address": "ipc:///tmp/test.ipc", "nameservers": False, "invalid_arg": "bar"} + with pytest.raises(TypeError): + _ = create_publisher_from_dict_config(settings) -class TestPublisherDictConfig(unittest.TestCase): - """Test configuring publishers with a dictionary.""" +def test_publisher_is_selected(): + """Test that Publisher is selected as publisher class.""" + settings = {"port": 12345, "nameservers": False} - @mock.patch('posttroll.publisher.Publisher') - def test_publisher_is_selected(self, Publisher): - """Test that Publisher is selected as publisher class.""" - from posttroll.publisher import create_publisher_from_dict_config + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, Publisher) + assert pub is not None - settings = {'port': 12345, 'nameservers': False} - pub = create_publisher_from_dict_config(settings) - Publisher.assert_called_once() - assert pub is not None +@mock.patch("posttroll.publisher.Publisher") +def test_publisher_all_arguments(Publisher): + """Test that only valid arguments are passed to Publisher.""" + settings = {"port": 12345, "nameservers": False, "name": "foo", + "min_port": 40000, "max_port": 41000} + _ = create_publisher_from_dict_config(settings) + _check_valid_settings_in_call(settings, Publisher, ignore=["port", "nameservers"]) + assert Publisher.call_args[0][0].startswith("tcp://*:") + assert Publisher.call_args[0][0].endswith(str(settings["port"])) - @mock.patch('posttroll.publisher.Publisher') - def test_publisher_all_arguments(self, Publisher): - """Test that only valid arguments are passed to Publisher.""" - from posttroll.publisher import create_publisher_from_dict_config - settings = {'port': 12345, 'nameservers': False, 'name': 'foo', - 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar'} - _ = create_publisher_from_dict_config(settings) - _check_valid_settings_in_call(settings, Publisher, ignore=['port', 'nameservers']) - assert Publisher.call_args[0][0].startswith("tcp://*:") - assert Publisher.call_args[0][0].endswith(str(settings['port'])) +def test_no_name_raises_keyerror(): + """Trying to create a NoisyPublisher without a given name will raise KeyError.""" + with pytest.raises(KeyError): + _ = create_publisher_from_dict_config(dict()) - def test_no_name_raises_keyerror(self): - """Trying to create a NoisyPublisher without a given name will raise KeyError.""" - from posttroll.publisher import create_publisher_from_dict_config - with self.assertRaises(KeyError): - _ = create_publisher_from_dict_config(dict()) +def test_noisypublisher_is_selected_only_name(): + """Test that NoisyPublisher is selected as publisher class.""" + from posttroll.publisher import NoisyPublisher - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_noisypublisher_is_selected_only_name(self, NoisyPublisher): - """Test that NoisyPublisher is selected as publisher class.""" - from posttroll.publisher import create_publisher_from_dict_config + settings = {"name": "publisher_name"} - settings = {'name': 'publisher_name'} + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, NoisyPublisher) - pub = create_publisher_from_dict_config(settings) - NoisyPublisher.assert_called_once() - assert pub is not None - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_noisypublisher_is_selected_name_and_port(self, NoisyPublisher): - """Test that NoisyPublisher is selected as publisher class.""" - from posttroll.publisher import create_publisher_from_dict_config +def test_noisypublisher_is_selected_name_and_port(): + """Test that NoisyPublisher is selected as publisher class.""" + from posttroll.publisher import NoisyPublisher - settings = {'name': 'publisher_name', 'port': 40000} + settings = {"name": "publisher_name", "port": 40000} - _ = create_publisher_from_dict_config(settings) - NoisyPublisher.assert_called_once() + pub = create_publisher_from_dict_config(settings) + assert isinstance(pub, NoisyPublisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_noisypublisher_all_arguments(self, NoisyPublisher): - """Test that only valid arguments are passed to NoisyPublisher.""" - from posttroll.publisher import create_publisher_from_dict_config - settings = {'port': 12345, 'nameservers': ['foo'], 'name': 'foo', - 'min_port': 40000, 'max_port': 41000, 'invalid_arg': 'bar', - 'aliases': ['alias1', 'alias2'], 'broadcast_interval': 42} - _ = create_publisher_from_dict_config(settings) - _check_valid_settings_in_call(settings, NoisyPublisher, ignore=['name']) - assert NoisyPublisher.call_args[0][0] == settings["name"] +@mock.patch("posttroll.publisher.NoisyPublisher") +def test_noisypublisher_all_arguments(NoisyPublisher): + """Test that only valid arguments are passed to NoisyPublisher.""" + from posttroll.publisher import create_publisher_from_dict_config - @mock.patch('posttroll.publisher.Publisher') - def test_publish_is_not_noisy(self, Publisher): - """Test that Publisher is selected with the context manager when it should be.""" - from posttroll.publisher import Publish + settings = {"port": 12345, "nameservers": ["foo"], "name": "foo", + "min_port": 40000, "max_port": 41000, "invalid_arg": "bar", + "aliases": ["alias1", "alias2"], "broadcast_interval": 42} + _ = create_publisher_from_dict_config(settings) + _check_valid_settings_in_call(settings, NoisyPublisher, ignore=["name"]) + assert NoisyPublisher.call_args[0][0] == settings["name"] - with Publish("service_name", port=40000, nameservers=False): - Publisher.assert_called_once() - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_publish_is_noisy_only_name(self, NoisyPublisher): - """Test that NoisyPublisher is selected with the context manager when only name is given.""" - from posttroll.publisher import Publish +def test_publish_is_not_noisy(): + """Test that Publisher is selected with the context manager when it should be.""" + from posttroll.publisher import Publish - with Publish("service_name"): - NoisyPublisher.assert_called_once() + with Publish("service_name", port=40000, nameservers=False) as pub: + assert isinstance(pub, Publisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_publish_is_noisy_with_port(self, NoisyPublisher): - """Test that NoisyPublisher is selected with the context manager when port is given.""" - from posttroll.publisher import Publish - with Publish("service_name", port=40000): - NoisyPublisher.assert_called_once() +def test_publish_is_noisy_only_name(): + """Test that NoisyPublisher is selected with the context manager when only name is given.""" + from posttroll.publisher import NoisyPublisher, Publish + + with Publish("service_name") as pub: + assert isinstance(pub, NoisyPublisher) + + +def test_publish_is_noisy_with_port(): + """Test that NoisyPublisher is selected with the context manager when port is given.""" + from posttroll.publisher import NoisyPublisher, Publish + + with Publish("service_name", port=40001) as pub: + assert isinstance(pub, NoisyPublisher) - @mock.patch('posttroll.publisher.NoisyPublisher') - def test_publish_is_noisy_with_nameservers(self, NoisyPublisher): - """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" - from posttroll.publisher import Publish - with Publish("service_name", nameservers=['a', 'b']): - NoisyPublisher.assert_called_once() +def test_publish_is_noisy_with_nameservers(): + """Test that NoisyPublisher is selected with the context manager when nameservers are given.""" + from posttroll.publisher import NoisyPublisher, Publish + + with Publish("service_name", nameservers=["a", "b"]) as pub: + assert isinstance(pub, NoisyPublisher) def _check_valid_settings_in_call(settings, pub_class, ignore=None): ignore = ignore or [] for key in settings: - if key == 'invalid_arg': - assert 'invalid_arg' not in pub_class.call_args[1] + if key == "invalid_arg": + assert "invalid_arg" not in pub_class.call_args[1] continue if key in ignore: continue assert pub_class.call_args[1][key] == settings[key] -@mock.patch('posttroll.subscriber.Subscriber') -@mock.patch('posttroll.subscriber.NSSubscriber') +@mock.patch("posttroll.subscriber.Subscriber") +@mock.patch("posttroll.subscriber.NSSubscriber") def test_dict_config_minimal(NSSubscriber, Subscriber): """Test that without any settings NSSubscriber is created.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -567,31 +354,31 @@ def test_dict_config_minimal(NSSubscriber, Subscriber): Subscriber.assert_not_called() -@mock.patch('posttroll.subscriber.Subscriber') -@mock.patch('posttroll.subscriber.NSSubscriber') +@mock.patch("posttroll.subscriber.Subscriber") +@mock.patch("posttroll.subscriber.NSSubscriber") def test_dict_config_nameserver_false(NSSubscriber, Subscriber): """Test that NSSubscriber is created with 'localhost' nameserver when no addresses are given.""" from posttroll.subscriber import create_subscriber_from_dict_config - subscriber = create_subscriber_from_dict_config({'nameserver': False}) + subscriber = create_subscriber_from_dict_config({"nameserver": False}) NSSubscriber.assert_called_once() assert subscriber == NSSubscriber().start() Subscriber.assert_not_called() -@mock.patch('posttroll.subscriber.Subscriber') -@mock.patch('posttroll.subscriber.NSSubscriber') +@mock.patch("posttroll.subscriber.Subscriber") +@mock.patch("posttroll.subscriber.NSSubscriber") def test_dict_config_subscriber(NSSubscriber, Subscriber): """Test that Subscriber is created when nameserver is False and addresses are given.""" from posttroll.subscriber import create_subscriber_from_dict_config - subscriber = create_subscriber_from_dict_config({'nameserver': False, 'addresses': ['addr1']}) + subscriber = create_subscriber_from_dict_config({"nameserver": False, "addresses": ["addr1"]}) assert subscriber == Subscriber.return_value Subscriber.assert_called_once() NSSubscriber.assert_not_called() -@mock.patch('posttroll.subscriber.NSSubscriber.start') +@mock.patch("posttroll.subscriber.NSSubscriber.start") def test_dict_config_full_nssubscriber(NSSubscriber_start): """Test that all NSSubscriber options are passed.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -611,8 +398,7 @@ def test_dict_config_full_nssubscriber(NSSubscriber_start): NSSubscriber_start.assert_called_once() -@mock.patch('posttroll.subscriber.Subscriber.update') -def test_dict_config_full_subscriber(Subscriber_update): +def test_dict_config_full_subscriber(): """Test that all Subscriber options are passed.""" from posttroll.subscriber import create_subscriber_from_dict_config @@ -620,7 +406,7 @@ def test_dict_config_full_subscriber(Subscriber_update): "services": "val1", "topics": "val2", "addr_listener": "val3", - "addresses": "val4", + "addresses": "ipc://bla.ipc", "timeout": "val5", "translate": "val6", "nameserver": False, @@ -629,14 +415,10 @@ def test_dict_config_full_subscriber(Subscriber_update): _ = create_subscriber_from_dict_config(settings) -@pytest.fixture -def tcp_keepalive_settings(monkeypatch): +@pytest.fixture() +def _tcp_keepalive_settings(monkeypatch): """Set TCP Keepalive settings.""" - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE", "1") - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_CNT", "10") - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_IDLE", "1") - monkeypatch.setenv("POSTTROLL_TCP_KEEPALIVE_INTVL", "1") - with reset_config_for_tests(): + with config.set(tcp_keepalive=1, tcp_keepalive_cnt=10, tcp_keepalive_idle=1, tcp_keepalive_intvl=1): yield @@ -649,97 +431,75 @@ def reset_config_for_tests(): posttroll.config = old_config -@pytest.fixture -def tcp_keepalive_no_settings(monkeypatch): +@pytest.fixture() +def _tcp_keepalive_no_settings(): """Set TCP Keepalive settings.""" - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE", raising=False) - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_CNT", raising=False) - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_IDLE", raising=False) - monkeypatch.delenv("POSTTROLL_TCP_KEEPALIVE_INTVL", raising=False) - with reset_config_for_tests(): + with config.set(tcp_keepalive=None, tcp_keepalive_cnt=None, tcp_keepalive_idle=None, tcp_keepalive_intvl=None): yield -def test_publisher_tcp_keepalive(tcp_keepalive_settings): +@pytest.mark.usefixtures("_tcp_keepalive_settings") +def test_publisher_tcp_keepalive(): """Test that TCP Keepalive is set for Publisher if the environment variables are present.""" - socket = mock.MagicMock() - with mock.patch('posttroll.publisher.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.publisher import Publisher - - _ = Publisher("tcp://127.0.0.1:9000").start() - - _assert_tcp_keepalive(socket) + from posttroll.backends.zmq.publisher import ZMQPublisher + pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() + _assert_tcp_keepalive(pub.publish_socket) + pub.stop() -def test_publisher_tcp_keepalive_not_set(tcp_keepalive_no_settings): +@pytest.mark.usefixtures("_tcp_keepalive_no_settings") +def test_publisher_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" - socket = mock.MagicMock() - with mock.patch('posttroll.publisher.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.publisher import Publisher - - _ = Publisher("tcp://127.0.0.1:9000").start() - _assert_no_tcp_keepalive(socket) + from posttroll.backends.zmq.publisher import ZMQPublisher + pub = ZMQPublisher(f"tcp://127.0.0.1:{str(free_port())}").start() + _assert_no_tcp_keepalive(pub.publish_socket) + pub.stop() -def test_subscriber_tcp_keepalive(tcp_keepalive_settings): +@pytest.mark.usefixtures("_tcp_keepalive_settings") +def test_subscriber_tcp_keepalive(): """Test that TCP Keepalive is set for Subscriber if the environment variables are present.""" - socket = mock.MagicMock() - with mock.patch('posttroll.subscriber.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.subscriber import Subscriber - - _ = Subscriber("tcp://127.0.0.1:9000") + from posttroll.backends.zmq.subscriber import ZMQSubscriber + sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") + assert len(sub.addr_sub.values()) == 1 + _assert_tcp_keepalive(list(sub.addr_sub.values())[0]) + sub.stop() - _assert_tcp_keepalive(socket) - -def test_subscriber_tcp_keepalive_not_set(tcp_keepalive_no_settings): +@pytest.mark.usefixtures("_tcp_keepalive_no_settings") +def test_subscriber_tcp_keepalive_not_set(): """Test that TCP Keepalive is not set on by default.""" - socket = mock.MagicMock() - with mock.patch('posttroll.subscriber.get_context') as get_context: - get_context.return_value.socket.return_value = socket - from posttroll.subscriber import Subscriber - - _ = Subscriber("tcp://127.0.0.1:9000") - - _assert_no_tcp_keepalive(socket) + from posttroll.backends.zmq.subscriber import ZMQSubscriber + sub = ZMQSubscriber(f"tcp://127.0.0.1:{str(free_port())}") + assert len(sub.addr_sub.values()) == 1 + _assert_no_tcp_keepalive(list(sub.addr_sub.values())[0]) + sub.close() def _assert_tcp_keepalive(socket): import zmq - assert mock.call(zmq.TCP_KEEPALIVE, 1) in socket.setsockopt.mock_calls - assert mock.call(zmq.TCP_KEEPALIVE_CNT, 10) in socket.setsockopt.mock_calls - assert mock.call(zmq.TCP_KEEPALIVE_IDLE, 1) in socket.setsockopt.mock_calls - assert mock.call(zmq.TCP_KEEPALIVE_INTVL, 1) in socket.setsockopt.mock_calls + assert socket.getsockopt(zmq.TCP_KEEPALIVE) == 1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == 10 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_IDLE) == 1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == 1 def _assert_no_tcp_keepalive(socket): - assert "TCP_KEEPALIVE" not in str(socket.setsockopt.mock_calls) - + import zmq -def test_noisypublisher_heartbeat(): - """Test that the heartbeat in the NoisyPublisher works.""" - from posttroll.ns import NameServer - from posttroll.publisher import NoisyPublisher - from posttroll.subscriber import Subscribe - - ns_ = NameServer() - thr = Thread(target=ns_.run) - thr.start() - - pub = NoisyPublisher("test") - pub.start() - time.sleep(0.2) - - with Subscribe("test", topics="/heartbeat/test", nameserver="localhost") as sub: - time.sleep(0.2) - pub.heartbeat(min_interval=10) - msg = next(sub.recv(1)) - assert msg.type == "beat" - assert msg.data == {'min_interval': 10} - pub.stop() - ns_.stop() - thr.join() + assert socket.getsockopt(zmq.TCP_KEEPALIVE) == -1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == -1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_IDLE) == -1 + assert socket.getsockopt(zmq.TCP_KEEPALIVE_INTVL) == -1 + + +def test_switch_to_unknown_backend(): + """Test switching to unknown backend.""" + from posttroll.publisher import Publisher + from posttroll.subscriber import Subscriber + with config.set(backend="unsecure_and_deprecated"): + with pytest.raises(NotImplementedError): + Publisher("ipc://bla.ipc") + with pytest.raises(NotImplementedError): + Subscriber("ipc://bla.ipc") diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py new file mode 100644 index 0000000..f912125 --- /dev/null +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -0,0 +1,171 @@ + +"""Test the curve-based zmq backend.""" + +import os +import shutil +import time +from threading import Thread + +import zmq.auth + +from posttroll import config +from posttroll.backends.zmq import generate_keys +from posttroll.message import Message +from posttroll.ns import get_pub_address +from posttroll.publisher import Publisher, create_publisher_from_dict_config +from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config +from posttroll.tests.test_nameserver import create_nameserver_instance + + +def create_keys(tmp_path): + """Create keys.""" + base_dir = tmp_path + keys_dir = base_dir / "certificates" + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + keys_dir.mkdir() + public_keys_dir.mkdir() + secret_keys_dir.mkdir() + + # create new keys in certificates dir + _server_public_file, _server_secret_file = zmq.auth.create_certificates( + keys_dir, "server" + ) + _client_public_file, _client_secret_file = zmq.auth.create_certificates( + keys_dir, "client" + ) + + # move public keys to appropriate directory + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key"): + shutil.move( + os.path.join(keys_dir, key_file), os.path.join(public_keys_dir, ".") + ) + + # move secret keys to appropriate directory + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key_secret"): + shutil.move( + os.path.join(keys_dir, key_file), os.path.join(secret_keys_dir, ".") + ) + + +def test_ipc_pubsub_with_sec(tmp_path): + """Test pub-sub on a secure ipc socket.""" + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") + + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) + sub = create_subscriber_from_dict_config(subscriber_settings) + + pub = Publisher(ipc_address) + + pub.start() + + def delayed_send(msg): + time.sleep(.2) + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + thr = Thread(target=delayed_send, args=["very sensitive message"]) + thr.start() + try: + for msg in sub.recv(): + assert msg.data == "very sensitive message" + break + finally: + sub.stop() + thr.join() + pub.stop() + + +def test_switch_to_secure_zmq_backend(tmp_path): + """Test switching to the secure_zmq backend.""" + create_keys(tmp_path) + + base_dir = tmp_path + public_keys_dir = base_dir / "public_keys" + secret_keys_dir = base_dir / "private_keys" + + server_secret_key = secret_keys_dir / "server.key_secret" + public_keys_directory = public_keys_dir + + client_secret_key = secret_keys_dir / "client.key_secret" + server_public_key = public_keys_dir / "server.key" + + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key, + clients_public_keys_directory=public_keys_directory, + server_public_key_file=server_public_key, + server_secret_key_file=server_secret_key): + Publisher("ipc://bla.ipc") + Subscriber("ipc://bla.ipc") + + +def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): + """Test pub-sub on a secure ipc socket.""" + # create_keys(tmp_path) + + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") + + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10203) + sub = create_subscriber_from_dict_config(subscriber_settings) + pub_settings = dict(address=ipc_address, + nameservers=False, port=1789) + pub = create_publisher_from_dict_config(pub_settings) + + pub.start() + + def delayed_send(msg): + time.sleep(.2) + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + thr = Thread(target=delayed_send, args=["very sensitive message"]) + thr.start() + try: + for msg in sub.recv(): + assert msg.data == "very sensitive message" + break + finally: + sub.stop() + thr.join() + pub.stop() + + +def test_switch_to_secure_backend_for_nameserver(tmp_path): + """Test switching backend for nameserver.""" + server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") + client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client") + with config.set(backend="secure_zmq", + client_secret_key_file=client_secret_key_file, + clients_public_keys_directory=os.path.dirname(client_public_key_file), + server_public_key_file=server_public_key_file, + server_secret_key_file=server_secret_key_file): + + with create_nameserver_instance(): + res = get_pub_address("some_name") + assert res == "" + + +def test_create_certificates_cli(tmp_path): + """Test the certificate creation cli.""" + name = "server" + args = [name, "-d", str(tmp_path)] + generate_keys(args) + assert (tmp_path / (name + ".key")).exists() + assert (tmp_path / (name + ".key_secret")).exists() diff --git a/posttroll/tests/test_unsecure_zmq_backend.py b/posttroll/tests/test_unsecure_zmq_backend.py new file mode 100644 index 0000000..1b2b469 --- /dev/null +++ b/posttroll/tests/test_unsecure_zmq_backend.py @@ -0,0 +1,67 @@ +"""Tests for the unsecure zmq backend.""" + +import time + +import pytest + +from posttroll import config +from posttroll.publisher import Publisher, create_publisher_from_dict_config +from posttroll.subscriber import Subscriber, create_subscriber_from_dict_config + + +def test_ipc_pubsub(tmp_path): + """Test pub-sub on an ipc socket.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202) + sub = create_subscriber_from_dict_config(subscriber_settings) + pub = Publisher(ipc_address) + pub.start() + + def delayed_send(msg): + time.sleep(.2) + from posttroll.message import Message + msg = Message(subject="/hi", atype="string", data=msg) + pub.send(str(msg)) + pub.stop() + from threading import Thread + Thread(target=delayed_send, args=["hi"]).start() + for msg in sub.recv(): + assert msg.data == "hi" + break + sub.stop() + + +def test_switch_to_unsecure_zmq_backend(tmp_path): + """Test switching to the secure_zmq backend.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + Publisher(ipc_address) + Subscriber(ipc_address) + + +def test_ipc_pub_crashes_when_passed_key_files(tmp_path): + """Test pub-sub on an ipc socket.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + subscriber_settings = dict(addresses=ipc_address, topics="", nameserver=False, port=10202, + client_secret_key_file="my_secret_key", + server_public_key_file="server_public_key") + with pytest.raises(TypeError): + create_subscriber_from_dict_config(subscriber_settings) + + +def test_ipc_sub_crashes_when_passed_key_files(tmp_path): + """Test pub-sub on a secure ipc socket.""" + ipc_address = f"ipc://{str(tmp_path)}/bla.ipc" + + with config.set(backend="unsecure_zmq"): + pub_settings = dict(address=ipc_address, + server_secret_key="server.key_secret", + public_keys_directory="public_keys_dir", + nameservers=False, port=1789) + with pytest.raises(TypeError): + create_publisher_from_dict_config(pub_settings) diff --git a/posttroll/version.py b/posttroll/version.py deleted file mode 100644 index aabec7c..0000000 --- a/posttroll/version.py +++ /dev/null @@ -1,657 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "None" - cfg.versionfile_source = "posttroll/version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..09c9663 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,77 @@ +[project] +name = "posttroll" +dynamic = ["version"] +description = "Messaging system for pytroll" +authors = [ + { name = "The Pytroll Team", email = "pytroll@googlegroups.com" } +] +dependencies = [ + "pyzmq", + "netifaces-plus", + "donfig", +] +readme = "README.md" +requires-python = ">=3.10" +license = { text = "GPLv3" } +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Communications" +] + +[project.scripts] +pytroll-logger = "posttroll.logger:run" +posttroll-generate-keys = "posttroll.backends.zmq:generate_keys" + +[project.urls] +Homepage = "https://github.com/pytroll/posttroll" +"Bug Tracker" = "https://github.com/pytroll/posttroll/issues" +Documentation = "https://posttroll.readthedocs.io/" +"Source Code" = "https://github.com/pytroll/posttroll" +Organization = "https://pytroll.github.io/" +Slack = "https://pytroll.slack.com/" +"Release Notes" = "https://github.com/pytroll/posttroll/blob/main/CHANGELOG.md" + +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["posttroll"] + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.build.hooks.vcs] +version-file = "posttroll/version.py" + +[tool.isort] +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +skip_gitignore = true +default_section = "THIRDPARTY" +known_first_party = "posttroll" +line_length = 120 + + +[tool.ruff] +# See https://docs.astral.sh/ruff/rules/ +# In the future, add "B", "S", "N" +lint.select = ["A", "D", "E", "W", "F", "I", "PT", "TID", "C90", "Q", "T10", "T20"] +line-length = 120 +exclude = ["versioneer.py", + "posttroll/version.py", + "doc"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 10 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 96fbf05..0000000 --- a/setup.cfg +++ /dev/null @@ -1,14 +0,0 @@ -[bdist_rpm] -requires=python-daemon pyzmq -release=1 - -[versioneer] -VCS = git -style = pep440 -versionfile_source = posttroll/version.py -versionfile_build = -tag_prefix = v -#parentdir_prefix = myproject- - -[flake8] -max-line-length = 120 diff --git a/setup.py b/setup.py deleted file mode 100644 index e80d8d6..0000000 --- a/setup.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright (c) 2011, 2012, 2014, 2015, 2020. - -# Author(s): - -# The pytroll team: -# Martin Raspaud - -# This file is part of pytroll. - -# This is free software: you can redistribute it and/or modify it under the -# terms of the GNU General Public License as published by the Free Software -# Foundation, either version 3 of the License, or (at your option) any later -# version. - -# This program is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS -# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more -# details. - -# You should have received a copy of the GNU General Public License along with -# this program. If not, see . - -from setuptools import setup -import versioneer - - -requirements = ['pyzmq', 'netifaces', "donfig"] - - -setup(name="posttroll", - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - description='Messaging system for pytroll', - author='The pytroll team', - author_email='pytroll@googlegroups.com', - url="http://github.com/pytroll/posttroll", - packages=['posttroll'], - entry_points={ - 'console_scripts': ['pytroll-logger = posttroll.logger:run', ]}, - scripts=['bin/nameserver'], - zip_safe=False, - license="GPLv3", - install_requires=requirements, - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', - 'Programming Language :: Python', - 'Operating System :: OS Independent', - 'Intended Audience :: Science/Research', - 'Topic :: Scientific/Engineering', - 'Topic :: Communications' - ], - python_requires='>=3.7', - test_suite='posttroll.tests.suite', - ) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 070e384..0000000 --- a/versioneer.py +++ /dev/null @@ -1,2146 +0,0 @@ - -# Version: 0.23 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/python-versioneer/python-versioneer -* Brian Warner -* License: Public Domain (CC0-1.0) -* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3 -* [![Latest Version][pypi-image]][pypi-url] -* [![Build Status][travis-image]][travis-url] - -This is a tool for managing a recorded version number in distutils/setuptools-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere in your $PATH -* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md)) -* run `versioneer install` in your source tree, commit the results -* Verify version information with `python setup.py version` - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes). - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/python-versioneer/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - -## Similar projects - -* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time - dependency -* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of - versioneer -* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools - plugin - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg -[pypi-url]: https://pypi.python.org/pypi/versioneer/ -[travis-image]: -https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg -[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer - -""" -# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring -# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements -# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error -# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with -# pylint:disable=attribute-defined-outside-init,too-many-arguments - -import configparser -import errno -import json -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - my_path = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(my_path)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise OSError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.ConfigParser() - with open(setup_cfg, "r") as cfg_file: - parser.read_file(cfg_file) - VCS = parser.get("versioneer", "VCS") # mandatory - - # Dict-like interface for non-mandatory entries - section = parser["versioneer"] - - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = section.get("style", "") - cfg.versionfile_source = section.get("versionfile_source") - cfg.versionfile_build = section.get("versionfile_build") - cfg.tag_prefix = section.get("tag_prefix") - if cfg.tag_prefix in ("''", '""', None): - cfg.tag_prefix = "" - cfg.parentdir_prefix = section.get("parentdir_prefix") - cfg.verbose = section.get("verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - HANDLERS.setdefault(vcs, {})[method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -LONG_VERSION_PY['git'] = r''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%%d" %% (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [versionfile_source] - if ipy: - files.append(ipy) - try: - my_path = __file__ - if my_path.endswith(".pyc") or my_path.endswith(".pyo"): - my_path = os.path.splitext(my_path)[0] + ".py" - versioneer_file = os.path.relpath(my_path) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - with open(".gitattributes", "r") as fobj: - for line in fobj: - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - break - except OSError: - pass - if not present: - with open(".gitattributes", "a+") as fobj: - fobj.write(f"{versionfile_source} export-subst\n") - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.23) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except OSError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(cmdclass=None): - """Get the custom setuptools subclasses used by Versioneer. - - If the package uses a different cmdclass (e.g. one from numpy), it - should be provide as an argument. - """ - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - - cmds = {} if cmdclass is None else cmdclass.copy() - - # we add "version" to setuptools - from setuptools import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # pip install -e . and setuptool/editable_wheel will invoke build_py - # but the build_py command is not expected to copy any files. - - # we override different "build_py" commands for both environments - if 'build_py' in cmds: - _build_py = cmds['build_py'] - else: - from setuptools.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - if getattr(self, "editable_mode", False): - # During editable installs `.py` and data files are - # not copied to build_lib - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if 'build_ext' in cmds: - _build_ext = cmds['build_ext'] - else: - from setuptools.command.build_ext import build_ext as _build_ext - - class cmd_build_ext(_build_ext): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_ext.run(self) - if self.inplace: - # build_ext --inplace will only build extensions in - # build/lib<..> dir with no _version.py to write to. - # As in place builds will already have a _version.py - # in the module dir, we do not need to write one. - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - if not os.path.exists(target_versionfile): - print(f"Warning: {target_versionfile} does not exist, skipping " - "version update. This can happen if you are running build_ext " - "without first running build_py.") - return - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_ext"] = cmd_build_ext - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? - from py2exe.distutils_buildexe import py2exe as _py2exe - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # sdist farms its file list building out to egg_info - if 'egg_info' in cmds: - _sdist = cmds['egg_info'] - else: - from setuptools.command.egg_info import egg_info as _egg_info - - class cmd_egg_info(_egg_info): - def find_sources(self): - # egg_info.find_sources builds the manifest list and writes it - # in one shot - super().find_sources() - - # Modify the filelist and normalize it - root = get_root() - cfg = get_config_from_root(root) - self.filelist.append('versioneer.py') - if cfg.versionfile_source: - # There are rare cases where versionfile_source might not be - # included by default, so we must be explicit - self.filelist.append(cfg.versionfile_source) - self.filelist.sort() - self.filelist.remove_duplicates() - - # The write method is hidden in the manifest_maker instance that - # generated the filelist and was thrown away - # We will instead replicate their final normalization (to unicode, - # and POSIX-style paths) - from setuptools import unicode_utils - normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') - for f in self.filelist.files] - - manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') - with open(manifest_filename, 'w') as fobj: - fobj.write('\n'.join(normalized)) - - cmds['egg_info'] = cmd_egg_info - - # we override different "sdist" commands for both environments - if 'sdist' in cmds: - _sdist = cmds['sdist'] - else: - from setuptools.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -OLD_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - -INIT_PY_SNIPPET = """ -from . import {0} -__version__ = {0}.get_versions()['version'] -""" - - -def do_setup(): - """Do main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (OSError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except OSError: - old = "" - module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] - snippet = INIT_PY_SNIPPET.format(module) - if OLD_SNIPPET in old: - print(" replacing boilerplate in %s" % ipy) - with open(ipy, "w") as f: - f.write(old.replace(OLD_SNIPPET, snippet)) - elif snippet not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(snippet) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)