Skip to content

Commit

Permalink
Merge pull request #68 from tkzt/fix/improve-key_cert-populating
Browse files Browse the repository at this point in the history
  • Loading branch information
greyli authored Dec 4, 2023
2 parents 3624515 + 94bc835 commit 39d4167
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
18 changes: 12 additions & 6 deletions flask_mailman/backends/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import ssl
import threading

from werkzeug.utils import cached_property

from flask_mailman.backends.base import BaseEmailBackend
from flask_mailman.message import sanitize_address
from flask_mailman.utils import DNS_NAME
Expand Down Expand Up @@ -48,6 +50,15 @@ def __init__(
def connection_class(self):
return smtplib.SMTP_SSL if self.use_ssl else smtplib.SMTP

@cached_property
def ssl_context(self):
if self.ssl_certfile or self.ssl_keyfile:
ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_cert_chain(self.ssl_certfile, self.ssl_keyfile)
return ssl_context
else:
return ssl.create_default_context()

def open(self):
"""
Ensure an open connection to the email server. Return whether or not a
Expand All @@ -64,12 +75,7 @@ def open(self):
if self.timeout is not None:
connection_params['timeout'] = self.timeout
if self.use_ssl:
connection_params.update(
{
'keyfile': self.ssl_keyfile,
'certfile': self.ssl_certfile,
}
)
connection_params["context"] = self.ssl_context
try:
self.connection = self.connection_class(self.host, self.port, **connection_params)

Expand Down
37 changes: 37 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from email.utils import parseaddr
import os
import socket
from ssl import SSLError
import tempfile
from unittest import mock
import pytest

from pathlib import Path
Expand Down Expand Up @@ -341,6 +343,41 @@ def test_email_tls_attempts_starttls(self):
with backend:
pass

def test_email_ssl_attempts_ssl_connection(self):
fake_keyfile = os.path.join(os.path.dirname(__file__), "attachments", 'file.txt')
fake_certfile = os.path.join(os.path.dirname(__file__), "attachments", 'file_txt')
with SmtpdContext(self.app.extensions['mailman']):
self.app.extensions['mailman'].use_ssl = True
self.app.extensions['mailman'].ssl_keyfile = fake_keyfile
self.app.extensions['mailman'].ssl_certfile = fake_certfile

backend = smtp.EmailBackend()
self.assertTrue(backend.use_ssl)
with self.assertRaises(SSLError):
with backend:
pass

@mock.patch("ssl.SSLContext.load_cert_chain", return_value="")
def test_email_ssl_cached_context(self, result_mocked):
fake_keyfile = os.path.join(os.path.dirname(__file__), "attachments", 'file.txt')
fake_certfile = os.path.join(os.path.dirname(__file__), "attachments", 'file_txt')

with SmtpdContext(self.app.extensions['mailman']):
self.app.extensions['mailman'].use_ssl = True

backend_one = smtp.EmailBackend()
backend_another = smtp.EmailBackend()

self.assertTrue(backend_one.ssl_context, backend_another.ssl_context)

self.app.extensions['mailman'].ssl_keyfile = fake_keyfile
self.app.extensions['mailman'].ssl_certfile = fake_certfile

backend_one = smtp.EmailBackend()
backend_another = smtp.EmailBackend()

self.assertTrue(backend_one.ssl_context, backend_another.ssl_context)

def test_connection_timeout_default(self):
"""The connection's timeout value is None by default."""
self.app.extensions['mailman'].backend = "smtp"
Expand Down

0 comments on commit 39d4167

Please sign in to comment.