Skip to content

Commit

Permalink
Sftp perf (#328)
Browse files Browse the repository at this point in the history
* Updated tests, setup.cfg
* Fix regression in sftp put performance, added test
* Updated changelog
* Updated manifest
  • Loading branch information
pkittenis authored Nov 27, 2021
1 parent 32e62fd commit 3215d05
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 23 deletions.
10 changes: 10 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
Change Log
============

2.7.1
+++++

Fixes
------

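* ``copy_file`` performance was abnormally low when copying plain text files because the local file was read line by
line rather than in fixed-size chunks - reading in 2MB chunks gives a 100x performance increase. Binary
file copying performance has also increased.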


2.7.0
+++++

Expand Down
2 changes: 0 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,3 @@ include LICENSE
include COPYING
include COPYING.LESSER
recursive-exclude tests *
include pssh/native/*.c
include pssh/native/*.pyx
15 changes: 9 additions & 6 deletions examples/sftp_copy_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@


with open('file_copy', 'wb') as fh:
for _ in range(2000000):
fh.write(b'asdfa')
# 200MB
for _ in range(20055120):
fh.write(b'asdfartkj\n')


fileinfo = os.stat('file_copy')
client = ParallelSSHClient(['localhost'])
mb_size = fileinfo.st_size / (1024000.0)
client = ParallelSSHClient(['127.0.0.1'], timeout=1, num_retries=1)
print(f"Starting copy of {mb_size}MB file")
now = datetime.now()
cmd = client.copy_file('file_copy', '/tmp/file_copy')
joinall(cmd, raise_error=True)
taken = datetime.now() - now
mb_size = fileinfo.st_size / (1024000.0)
rate = mb_size / taken.total_seconds()
print("File size %sMB transfered in %s, transfer rate %s MB/s" % (
mb_size, taken, rate))
print("File size %sMB transfered in %s, transfer rate %s MB/s" % (mb_size, taken, rate))
os.unlink('file_copy')
os.unlink('/tmp/file_copy')
8 changes: 6 additions & 2 deletions pssh/clients/native/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

class SSHClient(BaseSSHClient):
"""ssh2-python (libssh2) based non-blocking SSH client."""
# 2MB buffer
_BUF_SIZE = 2048 * 1024

def __init__(self, host,
user=None, password=None, port=None,
Expand Down Expand Up @@ -411,9 +413,11 @@ def copy_file(self, local_file, remote_file, recurse=False, sftp=None):
local_file, self.host, remote_file)

def _sftp_put(self, remote_fh, local_file):
with open(local_file, 'rb', 2097152) as local_fh:
for data in local_fh:
with open(local_file, 'rb', self._BUF_SIZE) as local_fh:
data = local_fh.read(self._BUF_SIZE)
while data:
self.eagain_write(remote_fh.write, data)
data = local_fh.read(self._BUF_SIZE)

def sftp_put(self, sftp, local_file, remote_file):
mode = LIBSSH2_SFTP_S_IRUSR | \
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ tag_prefix = ''
max-line-length = 100

[tool:pytest]
addopts=-v --cov=pssh --cov-append --cov-report=term --cov-report=term-missing
addopts=-v --cov=pssh --cov-append --cov-report=term --cov-report=term-missing --durations=10
testpaths =
tests
1 change: 1 addition & 0 deletions tests/embedded_server/openssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def start_server(self):
pass
else:
logger.error(self.server_proc.stdout.read())
logger.error(self.server_proc.stderr.read())
raise Exception("Server could not start")

def stop(self):
Expand Down
16 changes: 5 additions & 11 deletions tests/native/test_parallel_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# This file is part of parallel-ssh.
#
# Copyright (C) 2014-2020 Panos Kittenis
# Copyright (C) 2014-2021 Panos Kittenis
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
Expand Down Expand Up @@ -29,11 +29,11 @@
from datetime import datetime
from unittest.mock import patch, MagicMock

from gevent import joinall, spawn, socket, Greenlet, sleep, Timeout as GTimeout
from gevent import joinall, spawn, socket, sleep, Timeout as GTimeout
from pssh.config import HostConfig
from pssh.clients.native import ParallelSSHClient
from pssh.exceptions import UnknownHostException, \
AuthenticationException, ConnectionErrorException, SessionError, \
AuthenticationException, ConnectionErrorException, \
HostArgumentException, SFTPError, SFTPIOError, Timeout, SCPError, \
PKeyFileError, ShellError, HostArgumentError, NoIPv6AddressFoundError
from pssh.output import HostOutput
Expand Down Expand Up @@ -1042,6 +1042,7 @@ def test_per_host_tuple_args(self):
pkey=self.user_key,
num_retries=2)
output = client.run_command(cmd, host_args=host_args)
client.join()
for i, host in enumerate(hosts):
expected = [host_args[i]]
stdout = list(output[i].stdout)
Expand All @@ -1050,6 +1051,7 @@ def test_per_host_tuple_args(self):
host_args = (('arg1', 'arg2'), ('arg3', 'arg4'), ('arg5', 'arg6'),)
cmd = 'echo %s %s'
output = client.run_command(cmd, host_args=host_args)
client.join()
for i, host in enumerate(hosts):
expected = ["%s %s" % host_args[i]]
stdout = list(output[i].stdout)
Expand Down Expand Up @@ -1198,14 +1200,6 @@ def test_run_command_sudo(self):
self.assertEqual(len(output), len(self.client.hosts))
self.assertTrue(output[0].channel is not None)

@unittest.skipUnless(bool(os.getenv('TRAVIS')), "Not on Travis CI - skipping")
def test_run_command_sudo_var(self):
command = """for i in 1 2 3; do echo $i; done"""
output = list(self.client.run_command(
command, sudo=True)[0].stdout)
expected = ['1','2','3']
self.assertListEqual(output, expected)

def test_conn_failure(self):
"""Test connection error failure case - ConnectionErrorException"""
client = ParallelSSHClient(['127.0.0.100'], port=self.port,
Expand Down
18 changes: 18 additions & 0 deletions tests/native/test_single_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import subprocess
import shutil
import tempfile
from tempfile import NamedTemporaryFile
from pytest import raises
from unittest.mock import MagicMock, call, patch
from hashlib import sha256
Expand Down Expand Up @@ -562,6 +563,23 @@ def test_copy_file_remote_dir_relpath(self):
except Exception:
pass

def test_copy_file_with_newlines(self):
with NamedTemporaryFile('wb') as temp_file:
# 2MB
for _ in range(200512):
temp_file.write(b'asdfartkj\n')
temp_file.flush()
now = datetime.now()
try:
self.client.copy_file(os.path.abspath(temp_file.name), 'write_file')
took = datetime.now() - now
assert took.total_seconds() < 1
finally:
try:
os.unlink(os.path.expanduser('~/write_file'))
except OSError:
pass

def test_sftp_mkdir_abspath(self):
remote_dir = '/tmp/dir_to_create/dir1/dir2/dir3'
_sftp = self.client._make_sftp()
Expand Down
2 changes: 1 addition & 1 deletion tests/native/test_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_tunnel_server_reconn(self):
remote_server.start_server()

reconn_n = 20 # Number of reconnect attempts
reconn_delay = 1 # Number of seconds to delay betwen reconnects
reconn_delay = .1 # Number of seconds to delay between reconnects
try:
for _ in range(reconn_n):
client = SSHClient(
Expand Down

0 comments on commit 3215d05

Please sign in to comment.