diff --git a/.travis.yml b/.travis.yml index 3bcdc76e..42c51edb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,9 @@ script: - export LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu # For testing SSH agent related functionality - eval `ssh-agent -s` - - nosetests --with-coverage --cover-package=pssh + - nosetests --with-coverage --cover-package=pssh tests/test_native_tunnel.py + - nosetests --with-coverage --cover-package=pssh tests/test_native_*_client.py + - nosetests --with-coverage --cover-package=pssh tests/test_paramiko*.py - flake8 pssh - cd doc; make html; cd .. # Test building from source distribution diff --git a/Changelog.rst b/Changelog.rst index 70af87b2..4bcf794d 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -1,6 +1,24 @@ Change Log ============ +1.7.0 +++++++ + +Changes +-------- + +* Better tunneling implementation for native clients that supports multiple tunnels over single SSH connection for connecting multiple hosts through single proxy. +* Added ``greenlet_timeout`` setting to native client ``run_command`` to pass on to getting greenlet result to allow for greenlets to timeout. +* Native client raises specific exceptions on non-authentication errors connecting to host instead of generic ``SessionError``. + + +Fixes +------ + +* Native client tunneling would not work correctly - #123. +* ``timeout`` setting was not applied to native client sockets. +* Native client would have ``SessionError`` instead of ``Timeout`` exceptions on timeout errors connecting to hosts. + 1.6.3 ++++++ diff --git a/doc/advanced.rst b/doc/advanced.rst index 41f7d8b6..2a5f417c 100644 --- a/doc/advanced.rst +++ b/doc/advanced.rst @@ -111,7 +111,7 @@ To make use of this new client, ``ParallelSSHClient`` can be imported from ``pss `Feature comparison `_ for how the client features compare. - API documentation for `parallel `_ and `single `_ native clients. + API documentation for `parallel `_ and `single `_ native clients. Tunneling ********** diff --git a/doc/api.rst b/doc/api.rst index 955efe13..ab3797e3 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -11,5 +11,6 @@ API Documentation base_pssh output agent + tunnel utils exceptions diff --git a/doc/index.rst b/doc/index.rst index 4dab8a9c..0593d310 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -32,17 +32,19 @@ It uses non-blocking asynchronous SSH sessions and is to date the only publicly quickstart ssh2 advanced - Changelog api + Changelog In a nutshell ************** +Client will attempt to use all available keys under ``~/.ssh`` as well as any keys in an SSH agent, if one is available. + .. code-block:: python from __future__ import print_function - from pssh.pssh_client import ParallelSSHClient + from pssh.clients import ParallelSSHClient client = ParallelSSHClient(['localhost']) output = client.run_command('whoami') @@ -54,31 +56,11 @@ In a nutshell -`ssh2-python` (`libssh2`) based clients -****************************************** - -As of version ``1.2.0``, new single host and parallel clients are available based on the ``libssh2`` C library via its ``ssh2-python`` wrapper. - -They offer significantly enhanced performance and stability, at much less overhead, with a native non-blocking mode meaning *no monkey patching of the Python standard library* when using them. - -To use them, import from ``pssh2_client`` or ``ssh2_client`` for the parallel and single clients respectively. - -.. code-block:: python - - from __future__ import print_function - - from pssh.pssh2_client import ParallelSSHClient - - client = ParallelSSHClient(['localhost']) - output = client.run_command('whoami') - for line in output['localhost'].stdout: - print(line) - -The API is mostly identical to the current clients, though some features are not yet supported. See `client feature comparison `_ section for how feature support differs between the two clients. - .. note:: - From version ``2.x.x`` onwards, the ``ssh2-python`` based clients will *become the default*, replacing the current ``pssh_client.ParallelSSHClient``, with the current clients renamed. + There is also a now deprecated paramiko based client available under ``pssh.clients.miko`` that has much the same API. It supports some features not currently supported by the native client - see `feature comparison `_. + + From version ``2.x.x`` onwards, the clients under ``pssh.clients.miko`` will be an optional ``extras`` install. Indices and tables diff --git a/doc/ssh2.rst b/doc/ssh2.rst index 6bcc5798..094c46b6 100644 --- a/doc/ssh2.rst +++ b/doc/ssh2.rst @@ -8,7 +8,7 @@ Below is a comparison of feature support for the two client types. =============================== ============== ====================== Feature paramiko ssh2-python (libssh2) =============================== ============== ====================== -Agent forwarding Yes Not supported (*PR Pending*) +Agent forwarding Yes Yes (binary wheels or from source builds only) Proxying/tunnelling Yes Yes Kerberos (GSS) authentication Yes Not supported Private key file authentication Yes Yes @@ -20,8 +20,8 @@ Session timeout setting Yes Yes Per-channel timeout setting Yes Yes Programmatic SSH agent Yes Not supported OpenSSH config parsing Yes Not yet implemented -ECSA keys support Yes Not supported (*PR Pending*) -SCP functionality Not supported Not yet implemented +ECSA keys support Yes Yes +SCP functionality Not supported Yes =============================== ============== ====================== If any of missing features are required for a use case, then the paramiko based clients should be used instead. diff --git a/doc/tunnel.rst b/doc/tunnel.rst new file mode 100644 index 00000000..2c26e88d --- /dev/null +++ b/doc/tunnel.rst @@ -0,0 +1,5 @@ +Native Tunnel +============== + +.. automodule:: pssh.clients.native.tunnel + :member-order: groupwise diff --git a/pssh/clients/base_pssh.py b/pssh/clients/base_pssh.py index a29bf9f2..87616462 100644 --- a/pssh/clients/base_pssh.py +++ b/pssh/clients/base_pssh.py @@ -67,6 +67,7 @@ def run_command(self, command, user=None, stop_on_errors=True, host_args=None, use_pty=False, shell=None, encoding='utf-8', *args, **kwargs): + greenlet_timeout = kwargs.pop('greenlet_timeout', None) output = {} if host_args: try: @@ -88,7 +89,7 @@ def run_command(self, command, user=None, stop_on_errors=True, for host in self.hosts] for cmd in cmds: try: - self.get_output(cmd, output) + self.get_output(cmd, output, timeout=greenlet_timeout) except Exception: if stop_on_errors: raise @@ -122,7 +123,7 @@ def _get_host_config_values(self, host): def _run_command(self, host, command, *args, **kwargs): raise NotImplementedError - def get_output(self, cmd, output): + def get_output(self, cmd, output, timeout=None): """Get output from command. :param cmd: Command to get output from @@ -133,7 +134,7 @@ def get_output(self, cmd, output): :type output: dict :rtype: None""" try: - (channel, host, stdout, stderr, stdin) = cmd.get() + (channel, host, stdout, stderr, stdin) = cmd.get(timeout=timeout) except Exception as ex: host = ex.host self._update_host_output( diff --git a/pssh/clients/miko/__init__.py b/pssh/clients/miko/__init__.py index 39ea422c..db4de3f1 100644 --- a/pssh/clients/miko/__init__.py +++ b/pssh/clients/miko/__init__.py @@ -15,6 +15,5 @@ # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA -# flake8: noqa: F401 -from .parallel import ParallelSSHClient -from .single import SSHClient, logger +from .parallel import ParallelSSHClient # noqa: F401 +from .single import SSHClient, logger # noqa: F401 diff --git a/pssh/clients/native/parallel.py b/pssh/clients/native/parallel.py index 0a39321c..1c5bf40c 100644 --- a/pssh/clients/native/parallel.py +++ b/pssh/clients/native/parallel.py @@ -16,13 +16,15 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA import logging +from collections import deque from gevent import sleep +from gevent.lock import RLock from ..base_pssh import BaseParallelSSHClient from ...constants import DEFAULT_RETRIES, RETRY_DELAY from .single import SSHClient from ...exceptions import ProxyError, Timeout, HostArgumentException -from ...tunnel import Tunnel +from .tunnel import Tunnel logger = logging.getLogger(__name__) @@ -31,12 +33,12 @@ class ParallelSSHClient(BaseParallelSSHClient): """ssh2-python based parallel client.""" - def __init__(self, hosts, user=None, password=None, port=None, pkey=None, + def __init__(self, hosts, user=None, password=None, port=22, pkey=None, num_retries=DEFAULT_RETRIES, timeout=None, pool_size=10, allow_agent=True, host_config=None, retry_delay=RETRY_DELAY, proxy_host=None, proxy_port=22, proxy_user=None, proxy_password=None, proxy_pkey=None, - forward_ssh_agent=True): + forward_ssh_agent=True, tunnel_timeout=None): """ :param hosts: Hosts to connect to :type hosts: list(str) @@ -46,7 +48,7 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, no password :type password: str :param port: (Optional) Port number to use for SSH connection. Defaults - to ``None`` which uses SSH default (22) + to 22. :type port: int :param pkey: Private key file path to use. Note that the public key file pair *must* also exist in the same location with name ``.pub`` @@ -57,9 +59,10 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, :param retry_delay: Number of seconds to wait between retries. Defaults to :py:class:`pssh.constants.RETRY_DELAY` :type retry_delay: int - :param timeout: SSH session timeout setting in seconds. This controls - timeout setting of authenticated SSH sessions. - :type timeout: int + :param timeout: (Optional) SSH session timeout setting in seconds. + This controls timeout setting of socket operations used for SSH + sessions. Defaults to OS default - usually 60 seconds. + :type timeout: float :param pool_size: (Optional) Greenlet pool size. Controls concurrency, on how many hosts to execute tasks in parallel. Defaults to 10. Overhead in event @@ -94,6 +97,9 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, equivalent to `ssh -A` from the `ssh` command line utility. Defaults to True if not set. :type forward_ssh_agent: bool + :param tunnel_timeout: (Optional) Timeout setting for proxy tunnel + connections. + :type tunnel_timeout: float """ BaseParallelSSHClient.__init__( self, hosts, user=user, password=password, port=port, pkey=pkey, @@ -106,11 +112,16 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, self.proxy_user = proxy_user self.proxy_password = proxy_password self.forward_ssh_agent = forward_ssh_agent - self._tunnels = {} + self._tunnel = None + self._tunnel_in_q = None + self._tunnel_out_q = None + self._tunnel_lock = None + self._tunnel_timeout = tunnel_timeout + self._clients_lock = RLock() def run_command(self, command, sudo=False, user=None, stop_on_errors=True, use_pty=False, host_args=None, shell=None, - encoding='utf-8', timeout=None): + encoding='utf-8', timeout=None, greenlet_timeout=None): """Run command on all hosts in parallel, honoring self.pool_size, and return output dictionary. @@ -160,9 +171,21 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True, :type encoding: str :param timeout: (Optional) Timeout in seconds for reading from stdout or stderr. Defaults to no timeout. Reading from stdout/stderr will - timeout after this many seconds if remote output is not ready. + raise :py:class:`pssh.exceptions.Timeout` + after ``timeout`` number seconds if remote output is not ready. :type timeout: int - + :param greenlet_timeout: (Optional) Greenlet timeout setting. + Defaults to no timeout. If set, this function will raise + :py:class:`gevent.Timeout` after ``greenlet_timeout`` seconds + if no result is available from greenlets. + In some cases, such as when using proxy hosts, connection timeout + is controlled by proxy server and getting result from greenlets may + hang indefinitely if remote server is unavailable. Use this setting + to avoid blocking in such circumstances. + Note that ``gevent.Timeout`` is a special class that inherits from + ``BaseException`` and thus **can not be caught** by + ``stop_on_errors=False``. + :type greenlet_timeout: float :rtype: Dictionary with host as key and :py:class:`pssh.output.HostOutput` as value as per :py:func:`pssh.pssh_client.ParallelSSHClient.get_output` @@ -181,11 +204,17 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True, dict for cmd string format :raises: :py:class:`pssh.exceptions.ProxyError` on errors connecting to proxy if a proxy host has been set. + :raises: :py:class:`gevent.Timeout` on greenlet timeout. Gevent timeout + can not be caught by ``stop_on_errors=False``. + :raises: Exceptions from :py:mod:`ssh2.exceptions` for all other + specific errors such as + :py:class:`ssh2.exceptions.SocketDisconnectError` et al. """ return BaseParallelSSHClient.run_command( self, command, stop_on_errors=stop_on_errors, host_args=host_args, user=user, shell=shell, sudo=sudo, - encoding=encoding, use_pty=use_pty, timeout=timeout) + encoding=encoding, use_pty=use_pty, timeout=timeout, + greenlet_timeout=greenlet_timeout) def _run_command(self, host, command, sudo=False, user=None, shell=None, use_pty=False, @@ -198,7 +227,7 @@ def _run_command(self, host, command, sudo=False, user=None, use_pty=use_pty, encoding=encoding, timeout=timeout) except Exception as ex: ex.host = host - logger.error("Failed to run on host %s", host) + logger.error("Failed to run on host %s - %s", host, ex) raise ex def join(self, output, consume_output=False, timeout=None): @@ -290,40 +319,64 @@ def _get_exit_code(self, channel): return return channel.get_exit_status() - def _start_tunnel(self, host): - if host in self._tunnels: - return self._tunnels[host] - tunnel = Tunnel( - self.proxy_host, host, self.port, user=self.proxy_user, + def _start_tunnel_thread(self): + self._tunnel_lock = RLock() + self._tunnel_in_q = deque() + self._tunnel_out_q = deque() + self._tunnel = Tunnel( + self.proxy_host, self._tunnel_in_q, self._tunnel_out_q, + user=self.proxy_user, password=self.proxy_password, port=self.proxy_port, pkey=self.proxy_pkey, num_retries=self.num_retries, - timeout=self.timeout, retry_delay=self.retry_delay, + timeout=self._tunnel_timeout, retry_delay=self.retry_delay, allow_agent=self.allow_agent) - tunnel.daemon = True - tunnel.start() - while not tunnel.tunnel_open.is_set(): + self._tunnel.daemon = True + self._tunnel.start() + while not self._tunnel.tunnel_open.is_set(): logger.debug("Waiting for tunnel to become active") sleep(.1) - if not tunnel.is_alive(): + if not self._tunnel.is_alive(): msg = "Proxy authentication failed. " \ "Exception from tunnel client: %s" - logger.error(msg, tunnel.exception) - raise ProxyError(msg, tunnel.exception) - self._tunnels[host] = tunnel - return tunnel + logger.error(msg, self._tunnel.exception) + raise ProxyError(msg, self._tunnel.exception) def _make_ssh_client(self, host): - if host not in self.host_clients or self.host_clients[host] is None: - if self.proxy_host is not None: - tunnel = self._start_tunnel(host) - _user, _port, _password, _pkey = self._get_host_config_values(host) - proxy_host = None if self.proxy_host is None else '127.0.0.1' - _port = _port if self.proxy_host is None else tunnel.listen_port - self.host_clients[host] = SSHClient( - host, user=_user, password=_password, port=_port, pkey=_pkey, - num_retries=self.num_retries, timeout=self.timeout, - allow_agent=self.allow_agent, retry_delay=self.retry_delay, - proxy_host=proxy_host) + auth_thread_pool = True + if self.proxy_host is not None and self._tunnel is None: + self._start_tunnel_thread() + logger.debug("Make client request for host %s, host in clients: %s", + host, host in self.host_clients) + with self._clients_lock: + if host not in self.host_clients or self.host_clients[host] is None: + _user, _port, _password, _pkey = self._get_host_config_values( + host) + proxy_host = None if self.proxy_host is None else '127.0.0.1' + if proxy_host is not None: + auth_thread_pool = False + _wait = 0.0 + max_wait = self.timeout if self.timeout is not None else 60 + with self._tunnel_lock: + self._tunnel_in_q.append((host, _port)) + while True: + if _wait >= max_wait: + raise Timeout("Timed out waiting on tunnel to " + "open listening port") + try: + _port = self._tunnel_out_q.pop() + except IndexError: + logger.debug( + "Waiting on tunnel to open listening port") + sleep(.5) + _wait += .5 + else: + break + self.host_clients[host] = SSHClient( + host, user=_user, password=_password, port=_port, + pkey=_pkey, num_retries=self.num_retries, + timeout=self.timeout, + allow_agent=self.allow_agent, retry_delay=self.retry_delay, + proxy_host=proxy_host, _auth_thread_pool=auth_thread_pool) def copy_file(self, local_file, remote_file, recurse=False, copy_args=None): """Copy local file to remote file in parallel diff --git a/pssh/clients/native/single.py b/pssh/clients/native/single.py index a021bb11..1a9f62c6 100644 --- a/pssh/clients/native/single.py +++ b/pssh/clients/native/single.py @@ -1,16 +1,16 @@ # This file is part of parallel-ssh. - +# # Copyright (C) 2014-2018 Panos Kittenis. - +# # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation, version 2.1. - +# # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. - +# # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA @@ -28,7 +28,8 @@ from gevent import sleep, socket, get_hub from gevent.hub import Hub from ssh2.error_codes import LIBSSH2_ERROR_EAGAIN -from ssh2.exceptions import SFTPHandleError, SFTPProtocolError +from ssh2.exceptions import SFTPHandleError, SFTPProtocolError, \ + Timeout as SSH2Timeout from ssh2.session import Session from ssh2.sftp import LIBSSH2_FXF_READ, LIBSSH2_FXF_CREAT, LIBSSH2_FXF_WRITE, \ LIBSSH2_FXF_TRUNC, LIBSSH2_SFTP_S_IRUSR, LIBSSH2_SFTP_S_IRGRP, \ @@ -64,7 +65,8 @@ def __init__(self, host, retry_delay=RETRY_DELAY, allow_agent=True, timeout=None, forward_ssh_agent=True, - proxy_host=None): + proxy_host=None, + _auth_thread_pool=True): """:param host: Host name or IP to connect to. :type host: str :param user: User to connect as. Defaults to logged in user. @@ -116,7 +118,10 @@ def __init__(self, host, self.session = None self._host = proxy_host if proxy_host else host self._connect(self._host, self.port) - THREAD_POOL.apply(self._init) + if _auth_thread_pool: + THREAD_POOL.apply(self._init) + else: + self._init() def disconnect(self): """Disconnect session, close socket if needed.""" @@ -156,7 +161,10 @@ def _init(self, retries=1): while retries < self.num_retries: return self._connect_init_retry(retries) msg = "Error connecting to host %s:%s - %s" - raise SessionError(msg, self.host, self.port, ex) + logger.error(msg, self.host, self.port, ex) + if isinstance(ex, SSH2Timeout): + raise Timeout(msg, self.host, self.port, ex) + raise try: self.auth() except Exception as ex: @@ -168,6 +176,7 @@ def _init(self, retries=1): def _connect(self, host, port, retries=1): self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.settimeout(self.timeout) logger.debug("Connecting to %s:%s", host, port) try: self.sock.connect((host, port)) diff --git a/pssh/clients/native/tunnel.py b/pssh/clients/native/tunnel.py new file mode 100644 index 00000000..6c1279fa --- /dev/null +++ b/pssh/clients/native/tunnel.py @@ -0,0 +1,310 @@ +# This file is part of parallel-ssh. +# +# Copyright (C) 2014-2018 Panos Kittenis. +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation, version 2.1. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +from threading import Thread, Event +import logging + +from gevent import socket, spawn, joinall, get_hub, sleep +from gevent.select import select + +from ssh2.error_codes import LIBSSH2_ERROR_EAGAIN + +from .single import SSHClient +from ...constants import DEFAULT_RETRIES, RETRY_DELAY + + +logger = logging.getLogger(__name__) + + +class Tunnel(Thread): + + """SSH proxy implementation with direct TCP/IP tunnels. + + Each tunnel object runs in its own thread and can open any number of + direct tunnels to remote host:port destinations on local ports over + the same SSH connection. + + To use, append ``(host, port)`` tuples into ``Tunnel.in_q`` and read + listen port for tunnel connection from ``Tunnel.out_q``. + + ``Tunnel.tunnel_open`` is a *thread* event that will be set once tunnel is + ready.""" + + def __init__(self, host, in_q, out_q, user=None, + password=None, port=None, pkey=None, + num_retries=DEFAULT_RETRIES, + retry_delay=RETRY_DELAY, + allow_agent=True, timeout=None, + channel_retries=5): + """ + :param host: Remote SSH host to open tunnels with. + :type host: str + :param in_q: Deque for requesting new tunnel to given ``((host, port))`` + :type in_q: :py:class:`collections.deque` + :param out_q: Deque for feeding back tunnel listening ports. + :type out_q: :py:class:`collections.deque` + :param user: (Optional) User to login as. Defaults to logged in user + :type user: str + :param password: (Optional) Password to use for login. Defaults to + no password + :type password: str + :param port: (Optional) Port number to use for SSH connection. Defaults + to ``None`` which uses SSH default (22) + :type port: int + :param pkey: Private key file path to use. Note that the public key file + pair *must* also exist in the same location with name ``.pub`` + :type pkey: str + :param num_retries: (Optional) Number of connection and authentication + attempts before the client gives up. Defaults to 3. + :type num_retries: int + :param retry_delay: Number of seconds to wait between retries. Defaults + to :py:class:`pssh.constants.RETRY_DELAY` + :type retry_delay: int + :param timeout: SSH session timeout setting in seconds. This controls + timeout setting of authenticated SSH sessions. + :type timeout: int + :param allow_agent: (Optional) set to False to disable connecting to + the system's SSH agent. + :type allow_agent: bool + """ + Thread.__init__(self) + self.client = None + self.session = None + self._sockets = [] + self.in_q = in_q + self.out_q = out_q + self.host = host + self.port = port + self.user = user + self.password = password + self.pkey = pkey + self.num_retries = num_retries + self.retry_delay = retry_delay + self.allow_agent = allow_agent + self.timeout = timeout + self.exception = None + self.tunnel_open = Event() + self._tunnels = [] + self.channel_retries = channel_retries + + def __del__(self): + self.cleanup() + + def _read_forward_sock(self, forward_sock, channel): + while True: + if channel.eof(): + logger.debug("Channel closed") + return + try: + data = forward_sock.recv(1024) + except Exception: + logger.exception("Forward socket read error:") + sleep(1) + continue + data_len = len(data) + if data_len == 0: + continue + data_written = 0 + while data_written < data_len: + try: + rc = channel.write(data) + except Exception: + logger.exception("Channel write error:") + sleep(1) + continue + if rc == LIBSSH2_ERROR_EAGAIN: + select((), ((self.client.sock,)), (), timeout=0.001) + try: + rc = channel.write(data[data_written:]) + except Exception: + logger.exception("Channel write error:") + sleep(1) + continue + data_written += rc + try: + rc = channel.write(data[data_written:]) + except Exception: + logger.exception("Channel write error:") + sleep(1) + + def _read_channel(self, forward_sock, channel): + while True: + if channel.eof(): + logger.debug("Channel closed") + return + try: + size, data = channel.read() + except Exception as ex: + logger.error("Error reading from channel - %s", ex) + sleep(1) + continue + while size == LIBSSH2_ERROR_EAGAIN or size > 0: + if size == LIBSSH2_ERROR_EAGAIN: + select((self.client.sock,), (), (), timeout=0.001) + try: + size, data = channel.read() + except Exception as ex: + logger.error("Error reading from channel - %s", ex) + sleep(1) + continue + while size > 0: + try: + forward_sock.sendall(data) + except Exception as ex: + logger.error( + "Error sending data to forward socket - %s", ex) + sleep(.5) + continue + try: + size, data = channel.read() + except Exception as ex: + logger.error("Error reading from channel - %s", ex) + sleep(.5) + + def _init_tunnel_sock(self): + tunnel_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tunnel_socket.settimeout(self.timeout) + tunnel_socket.bind(('127.0.0.1', 0)) + tunnel_socket.listen(0) + listen_port = tunnel_socket.getsockname()[1] + self._sockets.append(tunnel_socket) + return tunnel_socket, listen_port + + def _init_tunnel_client(self): + self.client = SSHClient(self.host, user=self.user, port=self.port, + password=self.password, pkey=self.pkey, + num_retries=self.num_retries, + retry_delay=self.retry_delay, + allow_agent=self.allow_agent, + timeout=self.timeout, + _auth_thread_pool=False) + self.session = self.client.session + self.tunnel_open.set() + + def cleanup(self): + for _sock in self._sockets: + try: + _sock.close() + except Exception as ex: + logger.error("Exception while closing sockets - %s", ex) + if self.session is not None: + self.client.disconnect() + + def _consume_q(self): + while True: + try: + host, port = self.in_q.pop() + except IndexError: + sleep(1) + continue + logger.debug("Got request for tunnel to %s:%s", host, port) + tunnel = spawn(self._start_tunnel, host, port) + self._tunnels.append(tunnel) + + def _open_channel(self, fw_host, fw_port, local_port): + channel = self.session.direct_tcpip_ex( + fw_host, fw_port, '127.0.0.1', + local_port) + while channel == LIBSSH2_ERROR_EAGAIN: + select((self.client.sock,), (self.client.sock,), ()) + channel = self.session.direct_tcpip_ex( + fw_host, fw_port, '127.0.0.1', + local_port) + return channel + + def _open_channel_retries(self, fw_host, fw_port, local_port, + wait_time=0.1): + num_tries = 0 + while num_tries < self.channel_retries: + try: + channel = self._open_channel(fw_host, fw_port, local_port) + except Exception: + num_tries += 1 + if num_tries > self.num_retries: + raise + logger.error("Error opening channel to %s:%s, retries %s/%s", + fw_host, fw_port, num_tries, self.num_retries) + sleep(wait_time) + wait_time *= 5 + continue + return channel + + def _start_tunnel(self, fw_host, fw_port): + try: + listen_socket, listen_port = self._init_tunnel_sock() + except Exception as ex: + logger.error("Error initialising tunnel listen socket - %s", ex) + self.exception = ex + return + logger.debug("Tunnel listening on 127.0.0.1:%s on hub %s", + listen_port, get_hub().thread_ident) + self.out_q.append(listen_port) + try: + forward_sock, forward_addr = listen_socket.accept() + except Exception as ex: + logger.error("Error accepting connection from client - %s", ex) + self.exception = ex + listen_socket.close() + return + forward_sock.settimeout(self.timeout) + logger.debug("Client connected, forwarding %s:%s on" + " remote host to %s", + fw_host, fw_port, forward_addr) + local_port = forward_addr[1] + try: + channel = self._open_channel_retries(fw_host, fw_port, local_port) + except Exception as ex: + logger.exception("Could not establish channel to %s:%s:", + fw_host, fw_port) + self.exception = ex + forward_sock.close() + listen_socket.close() + return + source = spawn(self._read_forward_sock, forward_sock, channel) + dest = spawn(self._read_channel, forward_sock, channel) + logger.debug("Waiting for read/write greenlets") + self._wait_send_receive_lets(source, dest, channel, forward_sock) + + def _wait_send_receive_lets(self, source, dest, channel, forward_sock): + try: + joinall((source, dest), raise_error=True) + except Exception as ex: + logger.error(ex) + finally: + logger.debug("Closing channel and forward socket") + channel.close() + forward_sock.close() + + def run(self): + """Thread run target. Starts tunnel client and waits for incoming + tunnel connection requests from ``Tunnel.in_q``.""" + try: + self._init_tunnel_client() + except Exception as ex: + # logger.error("Tunnel initilisation failed - %s", ex) + self.exception = ex + return + logger.debug("Hub ID in run function: %s", get_hub().thread_ident) + consume_let = spawn(self._consume_q) + try: + consume_let.get() + except Exception as ex: + logger.error("Tunnel thread caught exception and will exit:", + exc_info=1) + self.exception = ex + finally: + self.cleanup() diff --git a/pssh/tunnel.py b/pssh/tunnel.py index 497ae22c..75d9ef99 100644 --- a/pssh/tunnel.py +++ b/pssh/tunnel.py @@ -1,157 +1,6 @@ -# This file is part of parallel-ssh. +from .clients.native.tunnel import Tunnel # noqa: F401 -# Copyright (C) 2014-2018 Panos Kittenis. - -# This library is free software; you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public -# License as published by the Free Software Foundation, version 2.1. - -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. - -# You should have received a copy of the GNU Lesser General Public -# License along with this library; if not, write to the Free Software -# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - -from threading import Thread, Event -import logging - -from gevent import socket, spawn, joinall, get_hub - -from ssh2.error_codes import LIBSSH2_ERROR_EAGAIN - -from .clients.native.single import SSHClient -from .native._ssh2 import wait_select -from .constants import DEFAULT_RETRIES, RETRY_DELAY - - -logger = logging.getLogger(__name__) - - -class Tunnel(Thread): - - def __init__(self, host, fw_host, fw_port, user=None, - password=None, port=None, pkey=None, - num_retries=DEFAULT_RETRIES, - retry_delay=RETRY_DELAY, - allow_agent=True, timeout=None, listen_port=0): - Thread.__init__(self) - self.client = None - self.session = None - self.socket = None - self.listen_port = listen_port - self.fw_host = fw_host - self.fw_port = fw_port if fw_port else 22 - self.channel = None - self.forward_sock = None - self.host = host - self.port = port - self.user = user - self.password = password - self.pkey = pkey - self.num_retries = num_retries - self.retry_delay = retry_delay - self.allow_agent = allow_agent - self.timeout = timeout - self.exception = None - self.tunnel_open = Event() - - def _read_forward_sock(self): - while True: - logger.debug("Waiting on forward socket read") - data = self.forward_sock.recv(1024) - data_len = len(data) - if data_len == 0: - logger.error("Client disconnected") - return - data_written = 0 - rc = self.channel.write(data) - while data_written < data_len: - if rc == LIBSSH2_ERROR_EAGAIN: - logger.debug("Waiting on channel write") - wait_select(self.channel.session) - continue - elif rc < 0: - logger.error("Channel write error %s", rc) - return - data_written += rc - logger.debug( - "Wrote %s bytes from forward socket to channel", rc) - rc = self.channel.write(data[data_written:]) - logger.debug("Total channel write size %s from %s received", - data_written, data_len) - - def _read_channel(self): - while True: - size, data = self.channel.read() - while size == LIBSSH2_ERROR_EAGAIN or size > 0: - if size == LIBSSH2_ERROR_EAGAIN: - logger.debug("Waiting on channel") - wait_select(self.channel.session) - size, data = self.channel.read() - elif size < 0: - logger.error("Error reading from channel") - return - while size > 0: - logger.debug("Read %s from channel..", size) - self.forward_sock.sendall(data) - logger.debug("Forwarded %s bytes from channel", size) - size, data = self.channel.read() - if self.channel.eof(): - logger.debug("Channel closed") - return - - def _init_tunnel_sock(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.bind(('127.0.0.1', self.listen_port)) - self.socket.listen(0) - self.listen_port = self.socket.getsockname()[1] - logger.debug("Tunnel listening on 127.0.0.1:%s on hub %s", - self.listen_port, get_hub()) - - def _init_tunnel_client(self): - self.client = SSHClient(self.host, user=self.user, port=self.port, - password=self.password, pkey=self.pkey, - num_retries=self.num_retries, - retry_delay=self.retry_delay, - allow_agent=self.allow_agent, - timeout=self.timeout) - self.session = self.client.session - - def run(self): - try: - self._init_tunnel_client() - self._init_tunnel_sock() - except Exception as ex: - logger.error("Tunnel initilisation failed with %s", ex) - self.exception = ex - return - logger.debug("Hub in run function: %s", get_hub()) - try: - while True: - logger.debug("Tunnel waiting for connection") - self.tunnel_open.set() - self.forward_sock, forward_addr = self.socket.accept() - logger.debug("Client connected, forwarding %s:%s on" - " remote host to local %s", - self.fw_host, self.fw_port, - forward_addr) - self.session.set_blocking(1) - self.channel = self.session.direct_tcpip_ex( - self.fw_host, self.fw_port, '127.0.0.1', forward_addr[1]) - if self.channel is None: - self.forward_sock.close() - self.socket.close() - raise Exception("Could not establish channel to %s:%s", - self.fw_host, self.fw_port) - self.session.set_blocking(0) - source = spawn(self._read_forward_sock) - dest = spawn(self._read_channel) - joinall((source, dest)) - self.channel.close() - self.forward_sock.close() - finally: - if not self.socket.closed: - self.socket.close() +from warnings import warn +__msg = "pssh.tunnel is deprecated and has been moved to " \ + "pssh.clients.native.tunnel" +warn(__msg) diff --git a/tests/base_ssh2_test.py b/tests/base_ssh2_test.py index a90b0dd9..5328de74 100644 --- a/tests/base_ssh2_test.py +++ b/tests/base_ssh2_test.py @@ -1,3 +1,20 @@ +# This file is part of parallel-ssh. +# +# Copyright (C) 2015-2018 Panos Kittenis +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation, version 2.1. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + import unittest import pwd import os diff --git a/tests/embedded_server/openssh.py b/tests/embedded_server/openssh.py index 5f416734..a0cb74c2 100644 --- a/tests/embedded_server/openssh.py +++ b/tests/embedded_server/openssh.py @@ -1,15 +1,15 @@ -# This file is part of paralle-ssh. -# Copyright (C) 2014-2017 Panos Kittenis - +# This file is part of parallel-ssh. +# Copyright (C) 2014-2018 Panos Kittenis +# # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation, version 2.1. - +# # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. - +# # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA @@ -18,7 +18,8 @@ import socket import random import string -from gevent.subprocess import Popen +from threading import Thread +from subprocess import Popen from time import sleep from sys import version_info @@ -65,11 +66,10 @@ def make_config(self): def start_server(self): cmd = ['/usr/sbin/sshd', '-D', '-p', str(self.port), '-h', SERVER_KEY, '-f', self.sshd_config] - server = Popen(cmd) - self.server_proc = server - self._wait_for_port() + self.server_proc = Popen(cmd) + self.wait_for_port() - def _wait_for_port(self): + def wait_for_port(self): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) while sock.connect_ex((self.listen_ip, self.port)) != 0: sleep(.1) @@ -78,9 +78,26 @@ def _wait_for_port(self): def stop(self): if self.server_proc is not None and self.server_proc.returncode is None: - self.server_proc.terminate() - self.server_proc.wait() + try: + self.server_proc.terminate() + self.server_proc.wait() + except OSError: + pass + try: + os.unlink(self.sshd_config) + except OSError: + pass def __del__(self): self.stop() - os.unlink(self.sshd_config) + + +class ThreadedOpenSSHServer(Thread, OpenSSHServer): + + def __init__(self, listen_ip='127.0.0.1', port=2222): + Thread.__init__(self) + OpenSSHServer.__init__(self, listen_ip=listen_ip, port=port) + + def run(self): + self.start_server() + self.server_proc.wait() diff --git a/tests/test_imports.py b/tests/test_imports.py index afb85df9..831782ea 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -7,6 +7,9 @@ def test_regular_import(self): from pssh.clients.native.parallel import ParallelSSHClient from pssh.pssh2_client import ParallelSSHClient as Client2 self.assertEqual(ParallelSSHClient, Client2) + from pssh.ssh2_client import SSHClient as Client2 + from pssh.clients.native.single import SSHClient + self.assertEqual(SSHClient, Client2) def test_deprecated_import(self): from pssh.pssh_client import ParallelSSHClient @@ -38,3 +41,6 @@ def test_client_imports(self): import pssh.clients.miko import pssh.clients.native import pssh.clients + + def test_tunnel_imports(self): + import pssh.tunnel diff --git a/tests/test_native_parallel_client.py b/tests/test_native_parallel_client.py index 5d19752c..aa2af175 100644 --- a/tests/test_native_parallel_client.py +++ b/tests/test_native_parallel_client.py @@ -1,23 +1,22 @@ -#!/usr/bin/env python # -*- coding: utf-8 -*- - # This file is part of parallel-ssh. - +# # Copyright (C) 2015-2018 Panos Kittenis - +# # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation, version 2.1. - +# # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. - +# # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + """Unittests for :mod:`pssh.ParallelSSHClient` class""" from __future__ import print_function @@ -35,18 +34,17 @@ import time -import gevent -from pssh.clients.native import ParallelSSHClient, logger as pssh_logger +from gevent import joinall +from pssh.clients.native import ParallelSSHClient from pssh.exceptions import UnknownHostException, \ AuthenticationException, ConnectionErrorException, SessionError, \ HostArgumentException, SFTPError, SFTPIOError, Timeout, SCPError, \ ProxyError +from pssh import logger as pssh_logger from .embedded_server.embedded_server import make_socket from .embedded_server.openssh import OpenSSHServer from .base_ssh2_test import PKEY_FILENAME, PUB_FILE -# from pssh.utils import load_private_key -# from pssh.agent import SSHAgent pssh_logger.setLevel(logging.DEBUG) @@ -59,10 +57,10 @@ class ParallelSSHClientTest(unittest.TestCase): def setUpClass(cls): _mask = int('0600') if version_info <= (2,) else 0o600 os.chmod(PKEY_FILENAME, _mask) - cls.server = OpenSSHServer() - cls.server.start_server() cls.host = '127.0.0.1' - cls.port = 2222 + cls.port = 2223 + cls.server = OpenSSHServer(listen_ip=cls.host, port=cls.port) + cls.server.start_server() cls.cmd = 'echo me' cls.resp = u'me' cls.user_key = PKEY_FILENAME @@ -80,7 +78,6 @@ def setUpClass(cls): def tearDownClass(cls): del cls.client cls.server.stop() - del cls.server def setUp(self): self.long_cmd = lambda lines: 'for (( i=0; i<%s; i+=1 )) do echo $i; sleep 1; done' % (lines,) @@ -89,7 +86,7 @@ def make_random_port(self, host=None): host = self.host if not host else host listen_socket = make_socket(host) listen_port = listen_socket.getsockname()[1] - del listen_socket + listen_socket.close() return listen_port def test_client_join_consume_output(self): @@ -149,20 +146,23 @@ def test_get_last_output(self): host = '127.0.0.9' server = OpenSSHServer(listen_ip=host, port=self.port) server.start_server() - hosts = [self.host, host] - client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key) - self.assertTrue(client.cmds is None) - self.assertTrue(client.get_last_output() is None) - client.run_command(self.cmd) - self.assertTrue(client.cmds is not None) - self.assertEqual(len(client.cmds), len(hosts)) - output = client.get_last_output() - self.assertTrue(len(output), len(hosts)) - client.join(output) - for host in hosts: - self.assertTrue(host in output) - exit_code = output[host].exit_code - self.assertTrue(exit_code == 0) + try: + hosts = [self.host, host] + client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key) + self.assertTrue(client.cmds is None) + self.assertTrue(client.get_last_output() is None) + client.run_command(self.cmd) + self.assertTrue(client.cmds is not None) + self.assertEqual(len(client.cmds), len(hosts)) + output = client.get_last_output() + self.assertTrue(len(output), len(hosts)) + client.join(output) + for host in hosts: + self.assertTrue(host in output) + exit_code = output[host].exit_code + self.assertTrue(exit_code == 0) + finally: + server.stop() def test_pssh_client_no_stdout_non_zero_exit_code_immediate_exit(self): output = self.client.run_command('exit 1') @@ -273,8 +273,8 @@ def test_pssh_client_timeout(self): timeout=client_timeout, num_retries=1) output = client.run_command('sleep 1', stop_on_errors=False) - self.assertTrue(isinstance(output[self.host].exception, - SessionError)) + self.assertIsInstance(output[self.host].exception, + Timeout) # def test_pssh_client_run_command_password(self): # """Test password authentication. Embedded server accepts any password @@ -367,7 +367,7 @@ def test_pssh_copy_file(self): remote_file_abspath = os.path.expanduser('~/' + remote_filepath) cmds = client.copy_file(local_filename, remote_filepath) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) except Exception: raise finally: @@ -381,11 +381,12 @@ def test_pssh_copy_file(self): def test_pssh_copy_file_per_host_args(self): """Test parallel copy file with per-host arguments""" host2, host3 = '127.0.0.6', '127.0.0.7' - server2 = OpenSSHServer(host2) - server3 = OpenSSHServer(host3) + server2 = OpenSSHServer(host2, port=self.port) + server3 = OpenSSHServer(host3, port=self.port) servers = [server2, server3] for server in servers: server.start_server() + # server.wait_for_port() time.sleep(1) hosts = [self.host, host2, host3] @@ -408,7 +409,7 @@ def test_pssh_copy_file_per_host_args(self): num_retries=2) greenlets = client.copy_file('%(local_file)s', '%(remote_file)s', copy_args=copy_args) - gevent.joinall(greenlets) + joinall(greenlets) self.assertRaises(HostArgumentException, client.copy_file, '%(local_file)s', '%(remote_file)s', @@ -466,7 +467,7 @@ def test_pssh_client_directory_relative_path(self): test_file.close() cmds = client.copy_file(local_test_path, remote_test_path_rel, recurse=True) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) for path in remote_file_paths: self.assertTrue(os.path.isfile(path)) finally: @@ -505,7 +506,7 @@ def test_pssh_client_directory_abs_path(self): test_file.close() cmds = client.copy_file(local_test_path, remote_test_path_abs, recurse=True) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) for path in remote_file_paths: self.assertTrue(os.path.isfile(path)) finally: @@ -544,7 +545,7 @@ def test_pssh_client_copy_file_failure(self): os.chmod(remote_test_path_abs, mask) cmds = self.client.copy_file(local_test_path, remote_test_path_abs, recurse=True) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) raise Exception("Expected SFTPError exception") except SFTPError: pass @@ -555,7 +556,7 @@ def test_pssh_client_copy_file_failure(self): remote_test_path_abs = os.sep.join((dir_name, remote_test_path)) cmds = self.client.copy_file(local_file_path, remote_test_path_abs, recurse=True) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) raise Exception("Expected SFTPError exception on creating remote " "directory") except SFTPError: @@ -613,10 +614,10 @@ def test_pssh_copy_remote_file(self): test_file.write(test_file_data) test_file.close() cmds = self.client.copy_remote_file(remote_test_path_abs, local_test_path) - self.assertRaises(ValueError, gevent.joinall, cmds, raise_error=True) + self.assertRaises(ValueError, joinall, cmds, raise_error=True) cmds = self.client.copy_remote_file(remote_test_path_abs, local_test_path, recurse=True) - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) try: self.assertTrue(os.path.isdir(local_copied_dir)) for path in local_file_paths: @@ -630,7 +631,7 @@ def test_pssh_copy_remote_file(self): # Relative path cmds = self.client.copy_remote_file(remote_test_path_rel, local_test_path, recurse=True) - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) try: self.assertTrue(os.path.isdir(local_copied_dir)) for path in local_file_paths: @@ -641,7 +642,7 @@ def test_pssh_copy_remote_file(self): # Different suffix cmds = self.client.copy_remote_file(remote_test_path_abs, local_test_path, suffix_separator='.', recurse=True) - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) new_local_copied_dir = '.'.join([local_test_path, self.host]) try: for path in local_file_paths: @@ -654,11 +655,12 @@ def test_pssh_copy_remote_file(self): def test_pssh_copy_remote_file_per_host_args(self): """Test parallel remote copy file with per-host arguments""" host2, host3 = '127.0.0.10', '127.0.0.11' - server2 = OpenSSHServer(host2) - server3 = OpenSSHServer(host3) + server2 = OpenSSHServer(host2, port=self.port) + server3 = OpenSSHServer(host3, port=self.port) servers = [server2, server3] for server in servers: server.start_server() + # server.wait_for_port() time.sleep(1) hosts = [self.host, host2, host3] @@ -683,7 +685,7 @@ def test_pssh_copy_remote_file_per_host_args(self): num_retries=2) greenlets = client.copy_remote_file('%(remote_file)s', '%(local_file)s', copy_args=copy_args) - gevent.joinall(greenlets) + joinall(greenlets) self.assertRaises(HostArgumentException, client.copy_remote_file, '%(remote_file)s', '%(local_file)s', @@ -736,7 +738,7 @@ def test_pssh_hosts_more_than_pool_size(self): get logs for all hosts""" # Make a second server on the same port as the first one host2 = '127.0.0.2' - server2 = OpenSSHServer(listen_ip=host2) + server2 = OpenSSHServer(listen_ip=host2, port=self.port) server2.start_server() hosts = [self.host, host2] client = ParallelSSHClient(hosts, @@ -759,10 +761,10 @@ def test_pssh_hosts_more_than_pool_size(self): def test_pssh_hosts_iterator_hosts_modification(self): """Test using iterator as host list and modifying host list in place""" host2, host3 = '127.0.0.2', '127.0.0.3' - server2 = OpenSSHServer(listen_ip=host2) - server3 = OpenSSHServer(listen_ip=host3) - server2.start_server() - server3.start_server() + server2 = OpenSSHServer(listen_ip=host2, port=self.port) + server3 = OpenSSHServer(listen_ip=host3, port=self.port) + for _server in (server2, server3): + _server.start_server() hosts = [self.host, '127.0.0.2'] client = ParallelSSHClient(iter(hosts), port=self.port, @@ -793,111 +795,6 @@ def test_pssh_hosts_iterator_hosts_modification(self): server2.stop() server3.stop() -# def test_ssh_proxy(self): -# """Test connecting to remote destination via SSH proxy -# client -> proxy -> destination -# Proxy SSH server accepts no commands and sends no responses, only -# proxies to destination. Destination accepts a command as usual.""" -# del self.client -# self.client = None -# self.server.kill() -# server, _ = start_server_from_ip(self.host, port=self.listen_port) -# proxy_host = '127.0.0.2' -# proxy_server, proxy_server_port = start_server_from_ip(proxy_host) -# client = ParallelSSHClient([self.host], port=self.listen_port, -# pkey=self.user_key, -# proxy_host=proxy_host, -# proxy_port=proxy_server_port, -# ) -# try: -# output = client.run_command(self.fake_cmd) -# stdout = list(output[self.host]['stdout']) -# expected_stdout = [self.fake_resp] -# self.assertEqual(expected_stdout, stdout, -# msg="Got unexpected stdout - %s, expected %s" % -# (stdout, -# expected_stdout,)) -# finally: -# del client -# server.kill() -# proxy_server.kill() - -# def test_ssh_proxy_target_host_failure(self): -# del self.client -# self.client = None -# self.server.kill() -# proxy_host = '127.0.0.2' -# proxy_server, proxy_server_port = start_server_from_ip(proxy_host) -# client = ParallelSSHClient([self.host], port=self.listen_port, -# pkey=self.user_key, -# proxy_host=proxy_host, -# proxy_port=proxy_server_port, -# ) -# try: -# self.assertRaises( -# ConnectionErrorException, client.run_command, self.fake_cmd) -# finally: -# del client -# proxy_server.kill() - -# def test_ssh_proxy_auth(self): -# """Test connecting to remote destination via SSH proxy -# client -> proxy -> destination -# Proxy SSH server accepts no commands and sends no responses, only -# proxies to destination. Destination accepts a command as usual.""" -# host2 = '127.0.0.2' -# proxy_server, proxy_server_port = start_server_from_ip(host2) -# proxy_user = 'proxy_user' -# proxy_password = 'fake' -# client = ParallelSSHClient([self.host], port=self.listen_port, -# pkey=self.user_key, -# proxy_host=host2, -# proxy_port=proxy_server_port, -# proxy_user=proxy_user, -# proxy_password='fake', -# proxy_pkey=self.user_key, -# num_retries=1, -# ) -# expected_stdout = [self.fake_resp] -# try: -# output = client.run_command(self.fake_cmd) -# stdout = list(output[self.host]['stdout']) -# self.assertEqual(expected_stdout, stdout, -# msg="Got unexpected stdout - %s, expected %s" % ( -# stdout, expected_stdout,)) -# self.assertEqual(client.host_clients[self.host].proxy_user, -# proxy_user) -# self.assertEqual(client.host_clients[self.host].proxy_password, -# proxy_password) -# self.assertTrue(client.host_clients[self.host].proxy_pkey) -# finally: -# del client -# proxy_server.kill() - -# def test_ssh_proxy_auth_fail(self): -# """Test failures while connecting via proxy""" -# proxy_host = '127.0.0.2' -# server, listen_port = start_server_from_ip(self.host, fail_auth=True) -# proxy_server, proxy_server_port = start_server_from_ip(proxy_host) -# proxy_user = 'proxy_user' -# proxy_password = 'fake' -# client = ParallelSSHClient([self.host], port=listen_port, -# pkey=self.user_key, -# proxy_host='127.0.0.2', -# proxy_port=proxy_server_port, -# proxy_user=proxy_user, -# proxy_password='fake', -# proxy_pkey=self.user_key, -# num_retries=1, -# ) -# try: -# self.assertRaises( -# AuthenticationException, client.run_command, self.fake_cmd) -# finally: -# del client -# server.kill() -# proxy_server.kill() - def test_bash_variable_substitution(self): """Test bash variables work correctly""" command = """for i in 1 2 3; do echo $i; done""" @@ -947,7 +844,6 @@ def test_connection_error_exception(self): def test_authentication_exception(self): """Test that we get authentication exception in output with correct arguments""" - # server, port = start_server_from_ip(self.host, fail_auth=True) hosts = [self.host] client = ParallelSSHClient(hosts, port=self.port, pkey='A REALLY FAKE KEY', @@ -1089,8 +985,8 @@ def test_get_exit_codes_bad_output(self): def test_per_host_tuple_args(self): host2, host3 = '127.0.0.4', '127.0.0.5' - server2 = OpenSSHServer(host2) - server3 = OpenSSHServer(host3) + server2 = OpenSSHServer(host2, port=self.port) + server3 = OpenSSHServer(host3, port=self.port) servers = [server2, server3] for server in servers: server.start_server() @@ -1126,8 +1022,8 @@ def test_per_host_tuple_args(self): def test_per_host_dict_args(self): host2, host3 = '127.0.0.2', '127.0.0.3' - server2 = OpenSSHServer(host2) - server3 = OpenSSHServer(host3) + server2 = OpenSSHServer(host2, port=self.port) + server3 = OpenSSHServer(host3, port=self.port) servers = [server2, server3] for server in servers: server.start_server() @@ -1390,7 +1286,7 @@ def test_scp_send_dir(self): remote_test_dir_abspath = os.path.expanduser('~/' + remote_test_dir) try: cmds = self.client.scp_send(local_filename, remote_filename) - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) time.sleep(.2) self.assertTrue(os.path.isdir(remote_test_dir_abspath)) self.assertTrue(os.path.isfile(remote_file_abspath)) @@ -1414,7 +1310,7 @@ def test_scp_send(self): remote_file_abspath = os.path.expanduser('~/' + remote_filepath) cmds = self.client.scp_send(local_filename, remote_filepath) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) except Exception: raise else: @@ -1433,7 +1329,7 @@ def test_scp_recv_failure(self): cmds = self.client.scp_recv( 'fakey fakey fake fake', 'equally fake') try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) except Exception as ex: self.assertEqual(ex.host, self.host) self.assertIsInstance(ex, SCPError) @@ -1475,11 +1371,11 @@ def test_scp_recv(self): test_file.write(test_file_data) test_file.close() cmds = self.client.scp_recv(remote_test_path_abs, local_test_path) - self.assertRaises(SCPError, gevent.joinall, cmds, raise_error=True) + self.assertRaises(SCPError, joinall, cmds, raise_error=True) cmds = self.client.scp_recv(remote_test_path_abs, local_test_path, recurse=True) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) self.assertTrue(os.path.isdir(local_copied_dir)) for path in local_file_paths: self.assertTrue(os.path.isfile(path)) @@ -1493,69 +1389,10 @@ def test_scp_recv(self): cmds = self.client.scp_recv(remote_test_path_rel, local_test_path, recurse=True) try: - gevent.joinall(cmds, raise_error=True) + joinall(cmds, raise_error=True) self.assertTrue(os.path.isdir(local_copied_dir)) for path in local_file_paths: self.assertTrue(os.path.isfile(path)) finally: shutil.rmtree(remote_test_path_abs) shutil.rmtree(local_copied_dir) - - # This is a unit test, no output is checked, due to race conditions - # with running server in same thread. - def test_tunnel(self): - proxy_host = '127.0.0.9' - server = OpenSSHServer(listen_ip=proxy_host, port=self.port) - server.start_server() - client = ParallelSSHClient( - [self.host], port=self.port, pkey=self.user_key, - proxy_host=proxy_host, proxy_port=self.port, num_retries=1, - proxy_pkey=self.user_key, - timeout=2) - output = client.run_command(self.cmd, stop_on_errors=False) - self.assertEqual(self.host, list(output.keys())[0]) - del client - server.stop() - - def test_tunnel_init_failure(self): - proxy_host = '127.0.0.20' - client = ParallelSSHClient( - [self.host], port=self.port, pkey=self.user_key, - proxy_host=proxy_host, proxy_port=self.port, num_retries=1, - proxy_pkey=self.user_key, - timeout=2) - output = client.run_command(self.cmd, stop_on_errors=False) - exc = output[self.host].exception - self.assertIsInstance(exc, ProxyError) - self.assertIsInstance(exc.args[1], ConnectionErrorException) - -# def test_proxy_remote_host_failure_timeout(self): -# """Test that timeout setting is passed on to proxy to be used for the -# proxy->remote host connection timeout -# """ -# self.server.kill() -# server_timeout=0.2 -# client_timeout=server_timeout-0.1 -# server, listen_port = start_server_from_ip(self.host, -# timeout=server_timeout) -# proxy_host = '127.0.0.2' -# proxy_server, proxy_server_port = start_server_from_ip(proxy_host) -# proxy_user = 'proxy_user' -# proxy_password = 'fake' -# client = ParallelSSHClient([self.host], port=listen_port, -# pkey=self.user_key, -# proxy_host='127.0.0.2', -# proxy_port=proxy_server_port, -# proxy_user=proxy_user, -# proxy_password='fake', -# proxy_pkey=self.user_key, -# num_retries=1, -# timeout=client_timeout, -# ) -# try: -# self.assertRaises( -# ConnectionErrorException, client.run_command, self.fake_cmd) -# finally: -# del client -# server.kill() -# proxy_server.kill() diff --git a/tests/test_native_single_client.py b/tests/test_native_single_client.py index e789ecf9..8b55ec4b 100644 --- a/tests/test_native_single_client.py +++ b/tests/test_native_single_client.py @@ -9,10 +9,10 @@ from .base_ssh2_test import SSH2TestCase from .embedded_server.openssh import OpenSSHServer from pssh.clients.native import SSHClient, logger as ssh_logger -from pssh.tunnel import Tunnel from ssh2.session import Session +from ssh2.exceptions import SocketDisconnectError from pssh.exceptions import AuthenticationException, ConnectionErrorException, \ - SessionError + SessionError, SFTPIOError, SFTPError, SCPError ssh_logger.setLevel(logging.DEBUG) @@ -21,6 +21,28 @@ class SSH2ClientTest(SSH2TestCase): + def test_context_manager(self): + with SSHClient(self.host, port=self.port, + pkey=self.user_key, + num_retries=1) as client: + self.assertIsInstance(client, SSHClient) + + def test_sftp_fail(self): + sftp = self.client._make_sftp() + self.assertRaises(SFTPIOError, self.client._mkdir, sftp, '/blah') + self.assertRaises(SFTPError, self.client.sftp_put, sftp, 'a file', '/blah') + + def test_scp_fail(self): + self.assertRaises(SCPError, self.client.scp_recv, 'fakey', 'fake') + try: + os.mkdir('adir') + except OSError: + pass + try: + self.assertRaises(ValueError, self.client.scp_send, 'adir', 'fake') + finally: + os.rmdir('adir') + def test_execute(self): channel, host, stdout, stderr, stdin = self.client.run_command( self.cmd) @@ -74,7 +96,7 @@ def test_handshake_fail(self): pkey=self.user_key, num_retries=1) client.session.disconnect() - self.assertRaises(SessionError, client._init) + self.assertRaises(SocketDisconnectError, client._init) def test_stdout_parsing(self): dir_list = os.listdir(os.path.expanduser('~')) @@ -114,24 +136,4 @@ def test_password_auth_failure(self): def test_retry_failure(self): self.assertRaises(ConnectionErrorException, SSHClient, self.host, port=12345, - num_retries=2) - - ## OpenSSHServer needs to run in its own thread for this test to work - ## Race conditions otherwise. - # - # def test_direct_tcpip(self): - # proxy_host = '127.0.0.9' - # server = OpenSSHServer(listen_ip=proxy_host, port=self.port) - # server.start_server() - # t = Tunnel(self.host, proxy_host, self.port, - # port=self.port, - # pkey=self.user_key, - # num_retries=1, - # timeout=5) - # t.daemon = True - # t.start() - # while not t.tunnel_open.is_set(): - # sleep(.1) - # client = SSHClient('127.0.0.1', port=t.listen_port, - # pkey=self.user_key, - # timeout=2) + num_retries=2, _auth_thread_pool=False) diff --git a/tests/test_native_tunnel.py b/tests/test_native_tunnel.py new file mode 100644 index 00000000..196a7f23 --- /dev/null +++ b/tests/test_native_tunnel.py @@ -0,0 +1,345 @@ +# This file is part of parallel-ssh. +# +# Copyright (C) 2015-2018 Panos Kittenis +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation, version 2.1. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +from __future__ import print_function + +import unittest +import pwd +import os +import shutil +import sys +import string +from socket import timeout as socket_timeout +from sys import version_info +import random +import time +from collections import deque + +from gevent import sleep, spawn, Timeout as GTimeout, socket +from pssh.clients.native.tunnel import Tunnel +from pssh.clients.native import SSHClient, ParallelSSHClient +from pssh.exceptions import UnknownHostException, \ + AuthenticationException, ConnectionErrorException, SessionError, \ + HostArgumentException, SFTPError, SFTPIOError, Timeout, SCPError, \ + ProxyError +from ssh2.exceptions import ChannelFailure + +from .embedded_server.openssh import ThreadedOpenSSHServer, OpenSSHServer +from .base_ssh2_test import PKEY_FILENAME, PUB_FILE + + +class TunnelTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + _mask = int('0600') if version_info <= (2,) else 0o600 + os.chmod(PKEY_FILENAME, _mask) + cls.host = '127.0.0.1' + cls.port = 2225 + cls.cmd = 'echo me' + cls.resp = u'me' + cls.user_key = PKEY_FILENAME + cls.user_pub_key = PUB_FILE + cls.user = pwd.getpwuid(os.geteuid()).pw_name + cls.proxy_host = '127.0.0.9' + cls.server = OpenSSHServer(listen_ip=cls.proxy_host, port=cls.port) + cls.server.start_server() + + @classmethod + def tearDownClass(cls): + cls.server.stop() + + def test_tunnel_retries(self): + local_port = 3050 + fw_host, fw_port = '127.0.0.1', 2100 + t = Tunnel(self.proxy_host, deque(), deque(), port=self.port, + pkey=self.user_key, num_retries=2) + t._init_tunnel_client() + self.assertRaises(ChannelFailure, t._open_channel_retries, fw_host, fw_port, local_port) + + def _connect_client(self, _socket): + while True: + _socket.read() + + def test_tunnel_channel_eof(self): + remote_host = '127.0.0.59' + server = OpenSSHServer(listen_ip=remote_host, port=self.port) + server.start_server() + in_q, out_q = deque(), deque() + try: + tunnel = Tunnel(self.proxy_host, in_q, out_q, port=self.port, + pkey=self.user_key, num_retries=1) + tunnel._init_tunnel_client() + channel = tunnel._open_channel_retries(self.proxy_host, self.port, 2150) + self.assertFalse(channel.eof()) + channel.close() + listen_socket, listen_port = tunnel._init_tunnel_sock() + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect(('127.0.0.1', listen_port)) + client = spawn(self._connect_client, client_socket) + tunnel._read_channel(client_socket, channel) + tunnel._read_forward_sock(client_socket, channel) + self.assertTrue(channel.eof()) + client.kill() + finally: + server.stop() + + def test_tunnel_sock_failure(self): + remote_host = '127.0.0.59' + server = OpenSSHServer(listen_ip=remote_host, port=self.port) + server.start_server() + in_q, out_q = deque(), deque() + try: + tunnel = Tunnel(self.proxy_host, in_q, out_q, port=self.port, + pkey=self.user_key, num_retries=1) + tunnel._init_tunnel_client() + channel = tunnel._open_channel_retries(self.proxy_host, self.port, 2150) + self.assertFalse(channel.eof()) + listen_socket, listen_port = tunnel._init_tunnel_sock() + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect(('127.0.0.1', listen_port)) + client_socket.send(b'blah\n') + client_socket.close() + gl1 = spawn(tunnel._read_channel, client_socket, channel) + gl2 = spawn(tunnel._read_forward_sock, client_socket, channel) + sleep(1) + gl1.kill() + gl2.kill() + tunnel._sockets.append(None) + tunnel.cleanup() + finally: + server.stop() + + def test_tunnel_init(self): + proxy_host = '127.0.0.49' + server = OpenSSHServer(listen_ip=proxy_host, port=self.port) + server.start_server() + in_q, out_q = deque(), deque() + try: + tunnel = Tunnel(proxy_host, in_q, out_q, port=self.port, + pkey=self.user_key, num_retries=1) + tunnel._init_tunnel_client() + consume_let = spawn(tunnel._consume_q) + in_q.append((self.host, self.port)) + while not tunnel.tunnel_open.is_set(): + sleep(.1) + if not tunnel.is_alive(): + raise ProxyError + self.assertTrue(tunnel.tunnel_open.is_set()) + self.assertIsNotNone(tunnel.client) + tunnel.cleanup() + for _sock in tunnel._sockets: + self.assertTrue(_sock.closed) + finally: + server.stop() + + def test_tunnel_channel_exc(self): + remote_host = '127.0.0.69' + server = OpenSSHServer(listen_ip=remote_host, port=self.port) + server.start_server() + in_q, out_q = deque(), deque() + try: + tunnel = Tunnel(remote_host, in_q, out_q, port=self.port, + pkey=self.user_key, num_retries=1) + tunnel._init_tunnel_client() + tunnel_accept = spawn(tunnel._start_tunnel, '127.0.0.255', self.port) + while len(out_q) == 0: + sleep(1) + listen_port = out_q.pop() + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect(('127.0.0.1', listen_port)) + client = spawn(self._connect_client, client_socket) + sleep(1) + client.kill() + tunnel_accept.kill() + for _sock in tunnel._sockets: + self.assertTrue(_sock.closed) + finally: + server.stop() + + def test_tunnel_channel_failure(self): + remote_host = '127.0.0.8' + remote_server = OpenSSHServer(listen_ip=remote_host, port=self.port) + remote_server.start_server() + in_q, out_q = deque(), deque() + try: + tunnel = Tunnel(self.proxy_host, in_q, out_q, port=self.port, + pkey=self.user_key, num_retries=1) + tunnel.daemon = True + tunnel.start() + in_q.append((remote_host, self.port)) + while not tunnel.tunnel_open.is_set(): + sleep(.1) + if not tunnel.is_alive(): + raise ProxyError + self.assertTrue(tunnel.tunnel_open.is_set()) + self.assertIsNotNone(tunnel.client) + while True: + try: + _port = out_q.pop() + except IndexError: + sleep(.5) + else: + break + proxy_client = SSHClient( + '127.0.0.1', pkey=self.user_key, port=_port, + num_retries=1, _auth_thread_pool=False) + tunnel.cleanup() + spawn(proxy_client.execute, 'echo me') + proxy_client.disconnect() + self.assertTrue(proxy_client.sock.closed) + finally: + remote_server.stop() + + def test_tunnel_server_failure(self): + proxy_host = '127.0.0.9' + remote_host = '127.0.0.8' + server = OpenSSHServer(listen_ip=proxy_host, port=self.port) + remote_server = OpenSSHServer(listen_ip=remote_host, port=self.port) + for _server in (server, remote_server): + _server.start_server() + in_q, out_q = deque(), deque() + try: + tunnel = Tunnel(proxy_host, in_q, out_q, port=self.port, + pkey=self.user_key, num_retries=1) + tunnel._init_tunnel_client() + consume_let = spawn(tunnel._consume_q) + in_q.append((remote_host, self.port)) + while not tunnel.tunnel_open.is_set(): + sleep(.1) + if not tunnel.is_alive(): + raise ProxyError + self.assertTrue(tunnel.tunnel_open.is_set()) + self.assertIsNotNone(tunnel.client) + while True: + try: + _port = out_q.pop() + except IndexError: + sleep(.5) + else: + break + proxy_client = spawn( + SSHClient, + '127.0.0.1', pkey=self.user_key, port=_port, + num_retries=1, _auth_thread_pool=False) + remote_server.stop() + tunnel.cleanup() + self.assertRaises(ConnectionErrorException, proxy_client.get) + finally: + for _server in (server, remote_server): + _server.stop() + + def test_tunnel(self): + remote_host = '127.0.0.8' + remote_server = OpenSSHServer(listen_ip=remote_host, port=self.port) + remote_server.start_server() + try: + client = ParallelSSHClient( + [remote_host], port=self.port, pkey=self.user_key, + proxy_host=self.proxy_host, proxy_port=self.port, num_retries=1, + proxy_pkey=self.user_key) + output = client.run_command(self.cmd) + client.join(output) + for host, host_out in output.items(): + _stdout = list(host_out.stdout) + self.assertListEqual(_stdout, [self.resp]) + self.assertEqual(remote_host, list(output.keys())[0]) + del client + finally: + remote_server.stop() + + def test_tunnel_init_failure(self): + proxy_host = '127.0.0.20' + client = ParallelSSHClient( + [self.host], port=self.port, pkey=self.user_key, + proxy_host=proxy_host, proxy_port=self.port, num_retries=1, + proxy_pkey=self.user_key) + output = client.run_command(self.cmd, stop_on_errors=False) + client.join(output) + exc = output[self.host].exception + self.assertIsInstance(exc, ProxyError) + self.assertIsInstance(exc.args[1], ConnectionErrorException) + + def test_tunnel_remote_host_timeout(self): + remote_host = '127.0.0.18' + proxy_host = '127.0.0.19' + server = ThreadedOpenSSHServer(listen_ip=proxy_host, port=self.port) + remote_server = ThreadedOpenSSHServer(listen_ip=remote_host, port=self.port) + for _server in (server, remote_server): + _server.start() + _server.wait_for_port() + try: + client = ParallelSSHClient( + [remote_host], port=self.port, pkey=self.user_key, + proxy_host=proxy_host, proxy_port=self.port, num_retries=1, + proxy_pkey=self.user_key) + output = client.run_command(self.cmd) + client.join(output) + client._tunnel.cleanup() + for _server in (server, remote_server): + _server.stop() + _server.join() + # Gevent timeout cannot be caught by stop_on_errors + self.assertRaises(GTimeout, client.run_command, self.cmd, + greenlet_timeout=1, stop_on_errors=False) + finally: + for _server in (server, remote_server): + _server.stop() + + def test_single_tunnel_multi_hosts(self): + remote_host = '127.0.0.8' + remote_server = ThreadedOpenSSHServer( + listen_ip=remote_host, port=self.port) + remote_server.start() + remote_server.wait_for_port() + hosts = [remote_host, remote_host, remote_host] + try: + client = ParallelSSHClient( + hosts, port=self.port, pkey=self.user_key, + proxy_host=self.proxy_host, proxy_port=self.port, num_retries=1, + proxy_pkey=self.user_key) + output = client.run_command(self.cmd, stop_on_errors=False) + client.join(output) + for host, host_out in output.items(): + _stdout = list(host_out.stdout) + self.assertListEqual(_stdout, [self.resp]) + self.assertEqual(len(hosts), len(list(output.keys()))) + del client + finally: + remote_server.stop() + remote_server.join() + + def test_single_tunnel_multi_hosts_timeout(self): + remote_host = '127.0.0.8' + remote_server = ThreadedOpenSSHServer( + listen_ip=remote_host, port=self.port) + remote_server.start() + remote_server.wait_for_port() + hosts = [remote_host, remote_host, remote_host] + try: + client = ParallelSSHClient( + hosts, port=self.port, pkey=self.user_key, + proxy_host=self.proxy_host, proxy_port=self.port, num_retries=1, + proxy_pkey=self.user_key, + timeout=.001) + output = client.run_command(self.cmd, stop_on_errors=False) + client.join(output) + for host, host_out in output.items(): + self.assertIsInstance(output[host].exception, Timeout) + finally: + remote_server.stop() + remote_server.join() diff --git a/tests/test_output.py b/tests/test_output.py index 629032dc..d33bd905 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -30,6 +30,9 @@ class TestHostOutput(unittest.TestCase): def setUp(self): self.output = HostOutput(None, None, None, None, None, None) + def test_print(self): + self.assertTrue(str(self.output)) + def test_update(self): host, cmd, chan, stdout, stderr, \ stdin, exit_code, exception = 'host', 'cmd', 'chan', 'stdout', \ diff --git a/tests/test_paramiko_parallel_client.py b/tests/test_paramiko_parallel_client.py index c11f3e60..5da0f4a3 100644 --- a/tests/test_paramiko_parallel_client.py +++ b/tests/test_paramiko_parallel_client.py @@ -1,18 +1,18 @@ # -*- coding: utf-8 -*- - +# # This file is part of parallel-ssh. - +# # Copyright (C) 2015 Panos Kittenis - +# # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation, version 2.1. - +# # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. - +# # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA diff --git a/tests/test_utils.py b/tests/test_utils.py index 18209093..45c20d59 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,20 @@ +# This file is part of parallel-ssh. +# +# Copyright (C) 2015 Panos Kittenis +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation, version 2.1. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + from pssh import utils import unittest import os @@ -12,6 +29,7 @@ DSA_KEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key_dsa']) ECDSA_KEY_FILENAME = os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key_ecdsa']) + class ParallelSSHUtilsTest(unittest.TestCase): def test_enabling_host_logger(self): @@ -22,13 +40,16 @@ def test_enabling_host_logger(self): utils.enable_host_logger() self.assertTrue(len([h for h in utils.host_logger.handlers if not isinstance(h, NullHandler)]) == 1) + utils.host_logger.handlers = [NullHandler()] def test_enabling_pssh_logger(self): self.assertTrue(len([h for h in utils.logger.handlers if isinstance(h, NullHandler)]) == 1) utils.enable_logger(utils.logger) - self.assertTrue(len([h for h in utils.host_logger.handlers + utils.enable_logger(utils.logger) + self.assertTrue(len([h for h in utils.logger.handlers if not isinstance(h, NullHandler)]) == 1) + utils.logger.handlers = [NullHandler()] def test_loading_key_files(self): for key_filename in [PKEY_FILENAME, DSA_KEY_FILENAME, ECDSA_KEY_FILENAME]: @@ -36,8 +57,15 @@ def test_loading_key_files(self): self.assertTrue(pkey, msg="Error loading key from file %s" % (key_filename,)) pkey = utils.load_private_key(open(key_filename)) self.assertTrue(pkey, msg="Error loading key from open file object for file %s" % (key_filename,)) - fake_key = BytesIO(b"blah blah fakey fakey key") + fake_key = BytesIO(b"blah blah fakey fakey key\n") self.assertFalse(utils.load_private_key(fake_key)) + fake_file = 'fake_key_file' + with open(fake_file, 'wb') as fh: + fh.write(b'fake key data\n') + try: + self.assertIsNone(utils.load_private_key(fake_file)) + finally: + os.unlink(fake_file) def test_openssh_config_missing(self): self.assertFalse(utils.read_openssh_config('test', config_file=str(uuid4())))