Skip to content

Commit

Permalink
Report sftp_error when write fails (#216)
Browse files Browse the repository at this point in the history
* Report sftp_error when write fails

* sftp_get: write bytes directly, don't expect them to be unicode

* Change _get_sftp_error_str to actually return strs
  • Loading branch information
Qalthos authored Oct 28, 2021
1 parent 30d8dfe commit a74c798
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/changelog-fragments/216.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Changed ``sftp.sftp_get`` to write files as bytes rather than assuming files are valid UTF8 -- by :user:`Qalthos`
1 change: 1 addition & 0 deletions docs/changelog-fragments/216.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added additional details when SFTP write errors are raised -- by :user:`Qalthos`
26 changes: 17 additions & 9 deletions src/pylibsshext/sftp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,20 @@ cdef class SFTP:

rf = sftp.sftp_open(self._libssh_sftp_session, remote_file_b, O_WRONLY | O_CREAT | O_TRUNC, sftp.S_IRWXU)
if rf is NULL:
raise LibsshSFTPException("Opening remote file [%s] for write failed with error [%s]" % (remote_file, MSG_MAP.get(self._get_sftp_error_str())))
raise LibsshSFTPException("Opening remote file [%s] for write failed with error [%s]" % (remote_file, self._get_sftp_error_str()))
buffer = f.read(1024)

while buffer != b"":
length = len(buffer)
written = sftp.sftp_write(rf, PyBytes_AS_STRING(buffer), length)
if written != length:
sftp.sftp_close(rf)
raise LibsshSFTPException("Writing to remote file [%s] failed" % remote_file)
raise LibsshSFTPException(
"Writing to remote file [%s] failed with error [%s]" % (
remote_file,
self._get_sftp_error_str(),
)
)
buffer = f.read(1024)
sftp.sftp_close(rf)

Expand All @@ -84,7 +89,7 @@ cdef class SFTP:

rf = sftp.sftp_open(self._libssh_sftp_session, remote_file_b, O_RDONLY, sftp.S_IRWXU)
if rf is NULL:
raise LibsshSFTPException("Opening remote file [%s] for read failed with error [%s]" % (remote_file, MSG_MAP.get(self._get_sftp_error_str())))
raise LibsshSFTPException("Opening remote file [%s] for read failed with error [%s]" % (remote_file, self._get_sftp_error_str()))

while True:
file_data = sftp.sftp_read(rf, <void *>read_buffer, sizeof(char) * 1024)
Expand All @@ -93,16 +98,16 @@ cdef class SFTP:
elif file_data < 0:
sftp.sftp_close(rf)
raise LibsshSFTPException("Reading data from remote file [%s] failed with error [%s]"
% (remote_file, MSG_MAP.get(self._get_sftp_error_str())))
% (remote_file, self._get_sftp_error_str()))

with open(local_file, 'w+') as f:
bytes_wrote = f.write(read_buffer[:file_data].decode('utf-8'))
if bytes_wrote and file_data != bytes_wrote:
with open(local_file, 'wb+') as f:
bytes_written = f.write(read_buffer[:file_data])
if bytes_written and file_data != bytes_written:
sftp.sftp_close(rf)
raise LibsshSFTPException("Number of bytes [%s] read from remote file [%s]"
" does not match number of bytes [%s] written to local file [%s]"
" due to error [%s]"
% (file_data, remote_file, bytes_wrote, local_file, MSG_MAP.get(self._get_sftp_error_str())))
% (file_data, remote_file, bytes_written, local_file, self._get_sftp_error_str()))
sftp.sftp_close(rf)

def close(self):
Expand All @@ -111,4 +116,7 @@ cdef class SFTP:
self._libssh_sftp_session = NULL

def _get_sftp_error_str(self):
return sftp.sftp_get_error(self._libssh_sftp_session)
error = sftp.sftp_get_error(self._libssh_sftp_session)
if error in MSG_MAP and error != sftp.SSH_FX_FAILURE:
return MSG_MAP[error]
return "Generic failure: %s" % self.session._get_session_error_str()

0 comments on commit a74c798

Please sign in to comment.