diff --git a/flask_mailman/backends/smtp.py b/flask_mailman/backends/smtp.py index ceb60ee..4198b49 100644 --- a/flask_mailman/backends/smtp.py +++ b/flask_mailman/backends/smtp.py @@ -3,6 +3,8 @@ import ssl import threading +from werkzeug.utils import cached_property + from flask_mailman.backends.base import BaseEmailBackend from flask_mailman.message import sanitize_address from flask_mailman.utils import DNS_NAME @@ -48,6 +50,15 @@ def __init__( def connection_class(self): return smtplib.SMTP_SSL if self.use_ssl else smtplib.SMTP + @cached_property + def ssl_context(self): + if self.ssl_certfile or self.ssl_keyfile: + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_cert_chain(self.ssl_certfile, self.ssl_keyfile) + return ssl_context + else: + return ssl.create_default_context() + def open(self): """ Ensure an open connection to the email server. Return whether or not a @@ -64,12 +75,7 @@ def open(self): if self.timeout is not None: connection_params['timeout'] = self.timeout if self.use_ssl: - connection_params.update( - { - 'keyfile': self.ssl_keyfile, - 'certfile': self.ssl_certfile, - } - ) + connection_params["context"] = self.ssl_context try: self.connection = self.connection_class(self.host, self.port, **connection_params) diff --git a/tests/test_backend.py b/tests/test_backend.py index 425be08..7a5418f 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2,7 +2,9 @@ from email.utils import parseaddr import os import socket +from ssl import SSLError import tempfile +from unittest import mock import pytest from pathlib import Path @@ -341,6 +343,41 @@ def test_email_tls_attempts_starttls(self): with backend: pass + def test_email_ssl_attempts_ssl_connection(self): + fake_keyfile = os.path.join(os.path.dirname(__file__), "attachments", 'file.txt') + fake_certfile = os.path.join(os.path.dirname(__file__), "attachments", 'file_txt') + with SmtpdContext(self.app.extensions['mailman']): + self.app.extensions['mailman'].use_ssl = True + self.app.extensions['mailman'].ssl_keyfile = fake_keyfile + self.app.extensions['mailman'].ssl_certfile = fake_certfile + + backend = smtp.EmailBackend() + self.assertTrue(backend.use_ssl) + with self.assertRaises(SSLError): + with backend: + pass + + @mock.patch("ssl.SSLContext.load_cert_chain", return_value="") + def test_email_ssl_cached_context(self, result_mocked): + fake_keyfile = os.path.join(os.path.dirname(__file__), "attachments", 'file.txt') + fake_certfile = os.path.join(os.path.dirname(__file__), "attachments", 'file_txt') + + with SmtpdContext(self.app.extensions['mailman']): + self.app.extensions['mailman'].use_ssl = True + + backend_one = smtp.EmailBackend() + backend_another = smtp.EmailBackend() + + self.assertTrue(backend_one.ssl_context, backend_another.ssl_context) + + self.app.extensions['mailman'].ssl_keyfile = fake_keyfile + self.app.extensions['mailman'].ssl_certfile = fake_certfile + + backend_one = smtp.EmailBackend() + backend_another = smtp.EmailBackend() + + self.assertTrue(backend_one.ssl_context, backend_another.ssl_context) + def test_connection_timeout_default(self): """The connection's timeout value is None by default.""" self.app.extensions['mailman'].backend = "smtp"