diff --git a/tests/test_backend.py b/tests/test_backend.py index ca0f2e8..d2cb08c 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -4,6 +4,7 @@ import socket from ssl import SSLError import tempfile +from unittest import mock import pytest from pathlib import Path @@ -75,6 +76,16 @@ def setUpClass(cls): def stop_smtp(cls): cls.smtp_controller.stop() + def setUp(self): + super().setUp() + _, self.tmpKey = tempfile.mkstemp() + _, self.tmpCert = tempfile.mkstemp() + + def tearDown(self): + super().tearDown() + Path(self.tmpKey).unlink() + Path(self.tmpCert).unlink() + def test_console_backend(self): self.app.extensions['mailman'].backend = 'console' msg = EmailMessage( @@ -344,11 +355,9 @@ def test_email_tls_attempts_starttls(self): def test_email_ssl_attempts_ssl_connection(self): with SmtpdContext(self.app.extensions['mailman']): - _, tmpKey = tempfile.mkstemp() - _, tmpCert = tempfile.mkstemp() self.app.extensions['mailman'].use_ssl = True - self.app.extensions['mailman'].ssl_keyfile = tmpKey - self.app.extensions['mailman'].ssl_certfile = tmpCert + self.app.extensions['mailman'].ssl_keyfile = self.tmpKey + self.app.extensions['mailman'].ssl_certfile = self.tmpCert backend = smtp.EmailBackend() self.assertTrue(backend.use_ssl) @@ -356,8 +365,23 @@ def test_email_ssl_attempts_ssl_connection(self): with backend: pass - Path(tmpKey).unlink() - Path(tmpCert).unlink() + @mock.patch("ssl.SSLContext.load_cert_chain", return_value="") + def test_email_ssl_cached_context(self, result_mocked): + with SmtpdContext(self.app.extensions['mailman']): + self.app.extensions['mailman'].use_ssl = True + + backend_one = smtp.EmailBackend() + backend_another = smtp.EmailBackend() + + self.assertTrue(backend_one.ssl_context, backend_another.ssl_context) + + self.app.extensions['mailman'].ssl_keyfile = self.tmpKey + self.app.extensions['mailman'].ssl_certfile = self.tmpCert + + backend_one = smtp.EmailBackend() + backend_another = smtp.EmailBackend() + + self.assertTrue(backend_one.ssl_context, backend_another.ssl_context) def test_connection_timeout_default(self): """The connection's timeout value is None by default."""