Skip to content

Commit

Permalink
Merge branch 'concurrent-config' into main - closes #155
Browse files Browse the repository at this point in the history
  • Loading branch information
simonrob committed Sep 6, 2023
2 parents b676a9c + 33e7ecc commit 733278f
Showing 1 changed file with 91 additions and 68 deletions.
159 changes: 91 additions & 68 deletions emailproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__author__ = 'Simon Robinson'
__copyright__ = 'Copyright (c) 2023 Simon Robinson'
__license__ = 'Apache 2.0'
__version__ = '2023-09-03' # ISO 8601 (YYYY-MM-DD)
__version__ = '2023-09-06' # ISO 8601 (YYYY-MM-DD)

import abc
import argparse
Expand Down Expand Up @@ -405,15 +405,59 @@ def save(store_id, config_dict, create_secret=True):
Log.error('Unable to get AWS SDK client; cannot cache credentials to AWS Secrets Manager')


class ConcurrentConfigParser:
"""Helper wrapper to add locking to a ConfigParser object (note: only wraps the methods used in this script)"""

def __init__(self):
self.config = configparser.ConfigParser()
self.lock = threading.Lock()

def read(self, filename):
with self.lock:
self.config.read(filename)

def sections(self):
with self.lock:
return self.config.sections()

def add_section(self, section):
with self.lock:
self.config.add_section(section)

def get(self, section, option, fallback=None):
with self.lock:
return self.config.get(section, option, fallback=fallback)

def getint(self, section, option, fallback=None):
with self.lock:
return self.config.getint(section, option, fallback=fallback)

def getboolean(self, section, option, fallback=None):
with self.lock:
return self.config.getboolean(section, option, fallback=fallback)

def set(self, section, option, value):
with self.lock:
self.config.set(section, option, value)

def remove_option(self, section, option):
with self.lock:
self.config.remove_option(section, option)

def write(self, file):
with self.lock:
self.config.write(file)

def items(self):
with self.lock:
return self.config.items() # used in read_dict when saving to cache store


class AppConfig:
"""Helper wrapper around ConfigParser to cache servers/accounts, and avoid writing to the file until necessary"""

_PARSER = None
_LOADED = False

_GLOBALS = None
_SERVERS = []
_ACCOUNTS = []
_PARSER_LOCK = threading.Lock()

# note: removing the unencrypted version of `client_secret_encrypted` is not automatic with --cache-store (see docs)
_CACHED_OPTION_KEYS = ['token_salt', 'access_token', 'access_token_expiry', 'refresh_token', 'last_activity',
Expand All @@ -424,38 +468,26 @@ class AppConfig:

@staticmethod
def _load():
AppConfig.unload()
AppConfig._PARSER = configparser.ConfigParser()
AppConfig._PARSER.read(CONFIG_FILE_PATH)

config_sections = AppConfig._PARSER.sections()
if APP_SHORT_NAME in config_sections:
AppConfig._GLOBALS = AppConfig._PARSER[APP_SHORT_NAME]
else:
AppConfig._GLOBALS = configparser.SectionProxy(AppConfig._PARSER, APP_SHORT_NAME)
config_parser = ConcurrentConfigParser()
config_parser.read(CONFIG_FILE_PATH)

# cached account credentials can be stored in the configuration file (default) or, via `--cache-store`, a
# separate local file or external service (such as a secrets manager) - we combine these sources at load time
if CACHE_STORE != CONFIG_FILE_PATH:
# it would be cleaner to avoid specific options here, but best to load unexpected sections only when enabled
allow_catch_all_accounts = AppConfig._GLOBALS.getboolean('allow_catch_all_accounts', fallback=False)
allow_catch_all_accounts = config_parser.getboolean(APP_SHORT_NAME, 'allow_catch_all_accounts',
fallback=False)

cache_file_parser = AppConfig._load_cache(CACHE_STORE)
cache_file_accounts = [s for s in cache_file_parser.sections() if '@' in s]
for account in cache_file_accounts:
if allow_catch_all_accounts and account not in AppConfig._PARSER.sections(): # missing sub-accounts
AppConfig._PARSER.add_section(account)
if allow_catch_all_accounts and account not in config_parser.sections(): # missing sub-accounts
config_parser.add_section(account)
for option in cache_file_parser.options(account):
if option in AppConfig._CACHED_OPTION_KEYS:
AppConfig._PARSER.set(account, option, cache_file_parser.get(account, option))

if allow_catch_all_accounts:
config_sections = AppConfig._PARSER.sections() # new sections may have been added
config_parser.set(account, option, cache_file_parser.get(account, option))

AppConfig._SERVERS = [s for s in config_sections if CONFIG_SERVER_MATCHER.match(s)]
AppConfig._ACCOUNTS = [s for s in config_sections if '@' in s]

AppConfig._LOADED = True
return config_parser

@staticmethod
def _load_cache(cache_store_identifier):
Expand All @@ -469,59 +501,47 @@ def _load_cache(cache_store_identifier):

@staticmethod
def get():
if not AppConfig._LOADED:
AppConfig._load()
return AppConfig._PARSER
with AppConfig._PARSER_LOCK:
if AppConfig._PARSER is None:
AppConfig._PARSER = AppConfig._load()
return AppConfig._PARSER

@staticmethod
def unload():
AppConfig._PARSER = None
AppConfig._LOADED = False

AppConfig._GLOBALS = None
AppConfig._SERVERS = []
AppConfig._ACCOUNTS = []

@staticmethod
def reload():
AppConfig.unload()
return AppConfig.get()
with AppConfig._PARSER_LOCK:
AppConfig._PARSER = None

@staticmethod
def globals():
AppConfig.get() # make sure config is loaded
return AppConfig._GLOBALS
def get_global(name, fallback):
return AppConfig.get().getboolean(APP_SHORT_NAME, name, fallback)

@staticmethod
def servers():
AppConfig.get() # make sure config is loaded
return AppConfig._SERVERS
return [s for s in AppConfig.get().sections() if CONFIG_SERVER_MATCHER.match(s)]

@staticmethod
def accounts():
AppConfig.get() # make sure config is loaded
return AppConfig._ACCOUNTS

@staticmethod
def add_account(username):
AppConfig._PARSER.add_section(username)
AppConfig._ACCOUNTS = [s for s in AppConfig._PARSER.sections() if '@' in s]
return [s for s in AppConfig.get().sections() if '@' in s]

@staticmethod
def save():
if AppConfig._LOADED:
with AppConfig._PARSER_LOCK:
if AppConfig._PARSER is None: # intentionally using _PARSER not get() so we don't (re-)load if unloaded
return

if CACHE_STORE != CONFIG_FILE_PATH:
# in `--cache-store` mode we ignore everything except _CACHED_OPTION_KEYS (OAuth 2.0 tokens, etc)
output_config_parser = configparser.ConfigParser()
output_config_parser.read_dict(AppConfig._PARSER) # a deep copy of the current configuration
config_accounts = [s for s in output_config_parser.sections() if '@' in s]

for account in AppConfig._ACCOUNTS:
for account in config_accounts:
for option in output_config_parser.options(account):
if option not in AppConfig._CACHED_OPTION_KEYS:
output_config_parser.remove_option(account, option)

for section in output_config_parser.sections():
if section not in AppConfig._ACCOUNTS or len(output_config_parser.options(section)) <= 0:
if section not in config_accounts or len(output_config_parser.options(section)) <= 0:
output_config_parser.remove_section(section)

AppConfig._save_cache(CACHE_STORE, output_config_parser)
Expand Down Expand Up @@ -557,10 +577,11 @@ def get_oauth2_credentials(username, password, recurse_retries=True):
if invalid). Returns either (True, '[OAuth2 string for authentication]') or (False, '[Error message]')"""

# we support broader catch-all account names (e.g., `@domain.com` / `@`) if enabled
valid_accounts = [username in AppConfig.accounts()]
if AppConfig.globals().getboolean('allow_catch_all_accounts', fallback=False):
config_accounts = AppConfig.accounts()
valid_accounts = [username in config_accounts]
if AppConfig.get_global('allow_catch_all_accounts', fallback=False):
user_domain = '@%s' % username.split('@')[-1]
valid_accounts.extend([account in AppConfig.accounts() for account in [user_domain, '@']])
valid_accounts.extend([account in config_accounts for account in [user_domain, '@']])

if not any(valid_accounts):
Log.error('Proxy config file entry missing for account', username, '- aborting login')
Expand All @@ -572,7 +593,7 @@ def get_oauth2_credentials(username, password, recurse_retries=True):

def get_account_with_catch_all_fallback(option):
fallback = None
if AppConfig.globals().getboolean('allow_catch_all_accounts', fallback=False):
if AppConfig.get_global('allow_catch_all_accounts', fallback=False):
fallback = config.get(user_domain, option, fallback=config.get('@', option, fallback=None))
return config.get(username, option, fallback=fallback)

Expand Down Expand Up @@ -614,7 +635,7 @@ def get_account_with_catch_all_fallback(option):

# try reloading remotely cached tokens if possible
if not access_token and CACHE_STORE != CONFIG_FILE_PATH and recurse_retries:
AppConfig.reload()
AppConfig.unload()
return OAuth2Helper.get_oauth2_credentials(username, password, recurse_retries=False)

# we hash locally-stored tokens with the given password
Expand Down Expand Up @@ -682,8 +703,8 @@ def get_account_with_catch_all_fallback(option):
oauth2_flow, username, password)

access_token = response['access_token']
if not config.has_section(username):
AppConfig.add_account(username) # in wildcard mode the section may not yet exist
if username not in config.sections():
config.add_section(username) # in catch-all mode the section may not yet exist
REQUEST_QUEUE.put(MENU_UPDATE) # make sure the menu shows the newly-added account
config.set(username, 'token_salt', token_salt)
config.set(username, 'access_token', OAuth2Helper.encrypt(fernet, access_token))
Expand All @@ -695,7 +716,7 @@ def get_account_with_catch_all_fallback(option):
Log.info('Warning: no refresh token returned for', username, '- you will need to re-authenticate',
'each time the access token expires (does your `oauth2_scope` value allow `offline` use?)')

if AppConfig.globals().getboolean('encrypt_client_secret_on_first_use', fallback=False):
if AppConfig.get_global('encrypt_client_secret_on_first_use', fallback=False):
if client_secret:
# note: save to the `username` entry even if `user_domain` exists, avoiding conflicts when using
# incompatible `encrypt_client_secret_on_first_use` and `allow_catch_all_accounts` options
Expand All @@ -712,8 +733,8 @@ def get_account_with_catch_all_fallback(option):
except InvalidToken as e:
# if invalid details are the reason for failure we remove our cached version and re-authenticate - this can
# be disabled by a configuration setting, but note that we always remove credentials on 400 Bad Request
if e.args == (400, APP_PACKAGE) or AppConfig.globals().getboolean('delete_account_token_on_password_error',
fallback=True):
if e.args == (400, APP_PACKAGE) or AppConfig.get_global('delete_account_token_on_password_error',
fallback=True):
config.remove_option(username, 'token_salt')
config.remove_option(username, 'access_token')
config.remove_option(username, 'access_token_expiry')
Expand Down Expand Up @@ -2290,11 +2311,11 @@ def macos_nsworkspace_notification_listener_(self, notification):
Log.info('Received power off notification; exiting', APP_NAME)
self.exit(self.icon)

# noinspection PyDeprecation
def create_icon(self):
# temporary fix for pystray <= 0.19.4 incompatibility with PIL 10.0.0+; fixed once pystray PR #147 is released
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
# noinspection PyDeprecation
pystray_version = pkg_resources.get_distribution('pystray').version
pillow_version = pkg_resources.get_distribution('pillow').version
if pkg_resources.parse_version(pystray_version) <= pkg_resources.parse_version('0.19.4') and \
Expand Down Expand Up @@ -2382,7 +2403,7 @@ def create_config_menu(self):
if len(config_accounts) <= 0:
items.append(pystray.MenuItem(' No accounts configured', None, enabled=False))
else:
catch_all_enabled = AppConfig.globals().getboolean('allow_catch_all_accounts', fallback=False)
catch_all_enabled = AppConfig.get_global('allow_catch_all_accounts', fallback=False)
catch_all_accounts = []
for account in config_accounts:
if account.startswith('@') and catch_all_enabled:
Expand Down Expand Up @@ -2772,7 +2793,9 @@ def load_and_start_servers(self, icon=None, reload=True):
# we allow reloading, so must first stop any existing servers
self.stop_servers()
Log.info('Initialising', APP_NAME, '(version %s)' % __version__, 'from config file', CONFIG_FILE_PATH)
config = AppConfig.reload() if reload else AppConfig.get()
if reload:
AppConfig.unload()
config = AppConfig.get()

# load server types and configurations
server_load_error = False
Expand Down

0 comments on commit 733278f

Please sign in to comment.