Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Platform independent temporary files #933

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions src/saml2/sigver.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
return signed_xml


def make_temp(content, suffix="", decode=True, delete_tmpfiles=True):
def make_temp(content, suffix="", decode=True):
"""
Create a temporary file with the given content.

Expand All @@ -345,19 +345,25 @@ def make_temp(content, suffix="", decode=True, delete_tmpfiles=True):
suffix in certain circumstances.
:param decode: The input content might be base64 coded. If so it
must, in some cases, be decoded before being placed in the file.
:param delete_tmpfiles: Whether to keep the tmp files or delete them when they are
no longer in use
:return: 2-tuple with file pointer ( so the calling function can
close the file) and filename (which is for instance needed by the
xmlsec function).
"""
content_encoded = content.encode("utf-8") if not isinstance(content, bytes) else content
content_raw = base64.b64decode(content_encoded) if decode else content_encoded
ntf = NamedTemporaryFile(suffix=suffix, delete=delete_tmpfiles)
ntf = NamedTemporaryFile(suffix=suffix, delete=False)
ntf.write(content_raw)
ntf.seek(0)
return ntf

def delete_filename(filename):
"""
Silent remove filename if exists
"""
try:
os.remove(filename)
except FileNotFoundError:
pass

def split_len(seq, length):
return [seq[i : i + length] for i in range(0, len(seq), length)]
Expand Down Expand Up @@ -673,7 +679,7 @@ def encrypt(self, text, recv_key, template, session_key_type, xpath=""):
:return:
"""
logger.debug("Encryption input len: %d", len(text))
tmp = make_temp(text, decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(text, decode=False)
com_list = [
self.xmlsec,
"--encrypt",
Expand All @@ -693,6 +699,10 @@ def encrypt(self, text, recv_key, template, session_key_type, xpath=""):
except XmlsecError as e:
raise EncryptError(com_list) from e

if self.delete_tmpfiles:
tmp.close()
delete_filename(tmp.name)

return output

def encrypt_assertion(self, statement, enc_key, template, key_type="des-192", node_xpath=None, node_id=None):
Expand All @@ -709,8 +719,8 @@ def encrypt_assertion(self, statement, enc_key, template, key_type="des-192", no
if isinstance(statement, SamlBase):
statement = pre_encrypt_assertion(statement)

tmp = make_temp(str(statement), decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp2 = make_temp(str(template), decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(str(statement), decode=False)
tmp2 = make_temp(str(template), decode=False)

if not node_xpath:
node_xpath = ASSERT_XPATH
Expand All @@ -736,6 +746,12 @@ def encrypt_assertion(self, statement, enc_key, template, key_type="des-192", no
except XmlsecError as e:
raise EncryptError(com_list) from e

if self.delete_tmpfiles:
tmp.close()
tmp2.close()
delete_filename(tmp.name)
delete_filename(tmp2.name)

return output.decode("utf-8")

def decrypt(self, enctext, key_file):
Expand All @@ -747,7 +763,7 @@ def decrypt(self, enctext, key_file):
"""

logger.debug("Decrypt input len: %d", len(enctext))
tmp = make_temp(enctext, decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(enctext, decode=False)

com_list = [
self.xmlsec,
Expand All @@ -763,6 +779,10 @@ def decrypt(self, enctext, key_file):
except XmlsecError as e:
raise DecryptError(com_list) from e

if self.delete_tmpfiles:
tmp.close()
delete_filename(tmp.name)

return output.decode("utf-8")

def sign_statement(self, statement, node_name, key_file, node_id):
Expand All @@ -778,7 +798,7 @@ def sign_statement(self, statement, node_name, key_file, node_id):
if isinstance(statement, SamlBase):
statement = str(statement)

tmp = make_temp(statement, suffix=".xml", decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(statement, suffix=".xml", decode=False)

com_list = [
self.xmlsec,
Expand All @@ -797,6 +817,10 @@ def sign_statement(self, statement, node_name, key_file, node_id):
except XmlsecError as e:
raise SignatureError(com_list) from e

if self.delete_tmpfiles:
tmp.close()
delete_filename(tmp.name)

# this does not work if --store-signatures is used
if output:
return output.decode("utf-8")
Expand All @@ -818,7 +842,7 @@ def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_i
if not isinstance(signedtext, bytes):
signedtext = signedtext.encode("utf-8")

tmp = make_temp(signedtext, suffix=".xml", decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(signedtext, suffix=".xml", decode=False)

com_list = [
self.xmlsec,
Expand All @@ -841,6 +865,10 @@ def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_i
except XmlsecError as e:
raise SignatureError(com_list) from e

if self.delete_tmpfiles:
tmp.close()
delete_filename(tmp.name)

return parse_xmlsec_verify_output(stderr, self.version_nums)

def _run_xmlsec(self, com_list, extra_args):
Expand All @@ -851,7 +879,7 @@ def _run_xmlsec(self, com_list, extra_args):
key-value parameters
:result: Whatever xmlsec wrote to an --output temporary file
"""
with NamedTemporaryFile(suffix=".xml") as ntf:
with NamedTemporaryFile(suffix=".xml", delete=False) as ntf:
com_list.extend(["--output", ntf.name])
if self.version_nums >= (1, 3):
com_list.extend(['--lax-key-search'])
Expand All @@ -870,7 +898,13 @@ def _run_xmlsec(self, com_list, extra_args):
raise XmlsecError(errmsg)

ntf.seek(0)
return p_out, p_err, ntf.read()
ntf_read = ntf.read()

if self.delete_tmpfiles:
ntf.close()
delete_filename(ntf.name)

return p_out, p_err, ntf_read


class CryptoBackendXMLSecurity(CryptoBackend):
Expand Down Expand Up @@ -1309,10 +1343,16 @@ def decrypt_keys(self, enctext, keys=None):

keys_filtered = (key for key in keys if key)
keys_encoded = (key.encode("ascii") if not isinstance(key, bytes) else key for key in keys_filtered)
key_files = list(make_temp(key, decode=False, delete_tmpfiles=self.delete_tmpfiles) for key in keys_encoded)
key_files = list(make_temp(key, decode=False) for key in keys_encoded)
key_file_names = list(tmp.name for tmp in key_files)

dectext = self.decrypt(enctext, key_file=key_file_names)

if self.delete_tmpfiles:
for tmp in key_files:
tmp.close()
delete_filename(tmp.name)

return dectext

def decrypt(self, enctext, key_file=None):
Expand Down Expand Up @@ -1387,7 +1427,7 @@ def _check_signature(
for cert_name, cert in _certs:
if isinstance(cert, str):
content = pem_format(cert)
tmp = make_temp(content, suffix=".pem", decode=False, delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(content, suffix=".pem", decode=False)
certs.append(tmp)
else:
certs.append(cert)
Expand All @@ -1397,7 +1437,7 @@ def _check_signature(
if not certs and not self.only_use_keys_in_metadata:
logger.debug("==== Certs from instance ====")
certs = [
make_temp(content=pem_format(cert), suffix=".pem", decode=False, delete_tmpfiles=self.delete_tmpfiles)
make_temp(content=pem_format(cert), suffix=".pem", decode=False)
for cert in cert_from_instance(item)
]
else:
Expand Down Expand Up @@ -1524,6 +1564,11 @@ def _check_signature(
else:
raise SignatureError("Failed to verify signature")

if self.delete_tmpfiles:
for tmp in certs:
tmp.close()
delete_filename(tmp.name)

return item

def check_signature(self, item, node_name=NODE_NAME, origdoc=None, must=False, issuer=None):
Expand Down Expand Up @@ -1686,12 +1731,16 @@ def sign_statement(self, statement, node_name, key=None, key_file=None, node_id=
"""
if not key_file and key:
content = str(key).encode()
tmp = make_temp(content, suffix=".pem", delete_tmpfiles=self.delete_tmpfiles)
tmp = make_temp(content, suffix=".pem")
key_file = tmp.name

if not key and not key_file:
key_file = self.key_file

if 'tmp' in locals() and self.delete_tmpfiles:
tmp.close()
delete_filename(tmp.name)

return self.crypto.sign_statement(
statement,
node_name,
Expand Down