Skip to content

Commit

Permalink
Add locking wrapper to ConfigParser
Browse files Browse the repository at this point in the history
  • Loading branch information
simonrob committed Aug 16, 2023
1 parent bb93f2e commit c9f54ce
Showing 1 changed file with 75 additions and 46 deletions.
121 changes: 75 additions & 46 deletions emailproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__author__ = 'Simon Robinson'
__copyright__ = 'Copyright (c) 2023 Simon Robinson'
__license__ = 'Apache 2.0'
__version__ = '2023-08-11' # ISO 8601 (YYYY-MM-DD)
__version__ = '2023-08-16' # ISO 8601 (YYYY-MM-DD)

import abc
import argparse
Expand Down Expand Up @@ -405,16 +405,61 @@ def save(store_id, config_dict, create_secret=True):
Log.error('Unable to get AWS SDK client; cannot cache credentials to AWS Secrets Manager')


class ConcurrentConfigParser:
"""Add locking to a ConfigParser object"""

def __init__(self):
self.config = configparser.ConfigParser()
self.lock = threading.Lock()

def read(self, filename):
with self.lock:
self.config.read(filename)

def sections(self):
with self.lock:
return self.config.sections()

def add_section(self, section):
with self.lock:
self.config.add_section(section)

def get(self, section, option, fallback=None):
with self.lock:
if not self.config.has_section(section):
return fallback
return self.config.get(section, option, fallback=fallback)

def getint(self, section, option, fallback=None):
with self.lock:
return self.config.getint(section, option, fallback=fallback)

def getboolean(self, section, option, fallback=None):
with self.lock:
return self.config.getboolean(section, option, fallback=fallback)

def set(self, section, option, value):
with self.lock:
if not self.config.has_section(section):
self.config.add_section(section)
self.config.set(section, option, value)

def remove_option(self, section, option):
with self.lock:
if self.config.has_option(section, option):
self.config.remove_option(section, option)

def write(self, file):
with self.lock:
self.config.write(file)


class AppConfig:
"""Helper wrapper around ConfigParser to cache servers/accounts, and avoid writing to the file until necessary"""

_PARSER = None
_LOADED = False

_GLOBALS = None
_SERVERS = []
_ACCOUNTS = []

# note: removing the unencrypted version of `client_secret_encrypted` is not automatic with --cache-store (see docs)
_CACHED_OPTION_KEYS = ['token_salt', 'access_token', 'access_token_expiry', 'refresh_token', 'last_activity',
'client_secret_encrypted']
Expand All @@ -425,20 +470,15 @@ class AppConfig:
@staticmethod
def _load():
AppConfig.unload()
AppConfig._PARSER = configparser.ConfigParser()
AppConfig._PARSER = ConcurrentConfigParser()
AppConfig._PARSER.read(CONFIG_FILE_PATH)

config_sections = AppConfig._PARSER.sections()
if APP_SHORT_NAME in config_sections:
AppConfig._GLOBALS = AppConfig._PARSER[APP_SHORT_NAME]
else:
AppConfig._GLOBALS = configparser.SectionProxy(AppConfig._PARSER, APP_SHORT_NAME)

# cached account credentials can be stored in the configuration file (default) or, via `--cache-store`, a
# separate local file or external service (such as a secrets manager) - we combine these sources at load time
if CACHE_STORE != CONFIG_FILE_PATH:
# it would be cleaner to avoid specific options here, but best to load unexpected sections only when enabled
allow_catch_all_accounts = AppConfig._GLOBALS.getboolean('allow_catch_all_accounts', fallback=False)
allow_catch_all_accounts = AppConfig._PARSER.getboolean(APP_SHORT_NAME, 'allow_catch_all_accounts',
fallback=False)

cache_file_parser = AppConfig._load_cache(CACHE_STORE)
cache_file_accounts = [s for s in cache_file_parser.sections() if '@' in s]
Expand All @@ -449,12 +489,6 @@ def _load():
if option in AppConfig._CACHED_OPTION_KEYS:
AppConfig._PARSER.set(account, option, cache_file_parser.get(account, option))

if allow_catch_all_accounts:
config_sections = AppConfig._PARSER.sections() # new sections may have been added

AppConfig._SERVERS = [s for s in config_sections if CONFIG_SERVER_MATCHER.match(s)]
AppConfig._ACCOUNTS = [s for s in config_sections if '@' in s]

AppConfig._LOADED = True

@staticmethod
Expand All @@ -478,53 +512,47 @@ def unload():
AppConfig._PARSER = None
AppConfig._LOADED = False

AppConfig._GLOBALS = None
AppConfig._SERVERS = []
AppConfig._ACCOUNTS = []

@staticmethod
def reload():
AppConfig.unload()
return AppConfig.get()

@staticmethod
def globals():
def get_global(name, fallback):
AppConfig.get() # make sure config is loaded
return AppConfig._GLOBALS
return AppConfig._PARSER.getboolean(APP_SHORT_NAME, name, fallback)

@staticmethod
def servers():
AppConfig.get() # make sure config is loaded
return AppConfig._SERVERS
return [s for s in AppConfig._PARSER.sections() if CONFIG_SERVER_MATCHER.match(s)]

@staticmethod
def accounts():
AppConfig.get() # make sure config is loaded
return AppConfig._ACCOUNTS

@staticmethod
def add_account(username):
AppConfig._PARSER.add_section(username)
AppConfig._ACCOUNTS = [s for s in AppConfig._PARSER.sections() if '@' in s]
return [s for s in AppConfig._PARSER.sections() if '@' in s]

@staticmethod
def save():
if AppConfig._LOADED:
if CACHE_STORE != CONFIG_FILE_PATH:
# in `--cache-store` mode we ignore everything except _CACHED_OPTION_KEYS (OAuth 2.0 tokens, etc)
output_config_parser = configparser.ConfigParser()
output_config_parser.read_dict(AppConfig._PARSER) # a deep copy of the current configuration
with AppConfig._PARSER.lock:
output_config_parser.read_dict(AppConfig._PARSER) # a deep copy of the current configuration
config_accounts = AppConfig.accounts()

for account in AppConfig._ACCOUNTS:
for account in config_accounts:
for option in output_config_parser.options(account):
if option not in AppConfig._CACHED_OPTION_KEYS:
output_config_parser.remove_option(account, option)

for section in output_config_parser.sections():
if section not in AppConfig._ACCOUNTS or len(output_config_parser.options(section)) <= 0:
if section not in config_accounts or len(output_config_parser.options(section)) <= 0:
output_config_parser.remove_section(section)

AppConfig._save_cache(CACHE_STORE, output_config_parser)
with AppConfig._PARSER.lock:
AppConfig._save_cache(CACHE_STORE, output_config_parser)

else:
# by default we cache to the local configuration file, and rewrite all values each time
Expand Down Expand Up @@ -557,10 +585,11 @@ def get_oauth2_credentials(username, password, recurse_retries=True):
if invalid). Returns either (True, '[OAuth2 string for authentication]') or (False, '[Error message]')"""

# we support broader catch-all account names (e.g., `@domain.com` / `@`) if enabled
valid_accounts = [username in AppConfig.accounts()]
if AppConfig.globals().getboolean('allow_catch_all_accounts', fallback=False):
config_accounts = AppConfig.accounts()
valid_accounts = [username in config_accounts]
if AppConfig.get_global('allow_catch_all_accounts', fallback=False):
user_domain = '@%s' % username.split('@')[-1]
valid_accounts.extend([account in AppConfig.accounts() for account in [user_domain, '@']])
valid_accounts.extend([account in config_accounts for account in [user_domain, '@']])

if not any(valid_accounts):
Log.error('Proxy config file entry missing for account', username, '- aborting login')
Expand All @@ -572,7 +601,7 @@ def get_oauth2_credentials(username, password, recurse_retries=True):

def get_account_with_catch_all_fallback(option):
fallback = None
if AppConfig.globals().getboolean('allow_catch_all_accounts', fallback=False):
if AppConfig.get_global('allow_catch_all_accounts', fallback=False):
fallback = config.get(user_domain, option, fallback=config.get('@', option, fallback=None))
return config.get(username, option, fallback=fallback)

Expand Down Expand Up @@ -682,8 +711,8 @@ def get_account_with_catch_all_fallback(option):
oauth2_flow, username, password)

access_token = response['access_token']
if not config.has_section(username):
AppConfig.add_account(username) # in wildcard mode the section may not yet exist
if username not in config.sections():
config.add_section(username) # in wildcard mode the section may not yet exist
REQUEST_QUEUE.put(MENU_UPDATE) # make sure the menu shows the newly-added account
config.set(username, 'token_salt', token_salt)
config.set(username, 'access_token', OAuth2Helper.encrypt(fernet, access_token))
Expand All @@ -695,7 +724,7 @@ def get_account_with_catch_all_fallback(option):
Log.info('Warning: no refresh token returned for', username, '- you will need to re-authenticate',
'each time the access token expires (does your `oauth2_scope` value allow `offline` use?)')

if AppConfig.globals().getboolean('encrypt_client_secret_on_first_use', fallback=False):
if AppConfig.get_global('encrypt_client_secret_on_first_use', fallback=False):
if client_secret:
# note: save to the `username` entry even if `user_domain` exists, avoiding conflicts when using
# incompatible `encrypt_client_secret_on_first_use` and `allow_catch_all_accounts` options
Expand All @@ -712,8 +741,8 @@ def get_account_with_catch_all_fallback(option):
except InvalidToken as e:
# if invalid details are the reason for failure we remove our cached version and re-authenticate - this can
# be disabled by a configuration setting, but note that we always remove credentials on 400 Bad Request
if e.args == (400, APP_PACKAGE) or AppConfig.globals().getboolean('delete_account_token_on_password_error',
fallback=True):
if e.args == (400, APP_PACKAGE) or AppConfig.get_global('delete_account_token_on_password_error',
fallback=True):
config.remove_option(username, 'token_salt')
config.remove_option(username, 'access_token')
config.remove_option(username, 'access_token_expiry')
Expand Down Expand Up @@ -2360,7 +2389,7 @@ def create_config_menu(self):
if len(config_accounts) <= 0:
items.append(pystray.MenuItem(' No accounts configured', None, enabled=False))
else:
catch_all_enabled = AppConfig.globals().getboolean('allow_catch_all_accounts', fallback=False)
catch_all_enabled = AppConfig.get_global('allow_catch_all_accounts', fallback=False)
catch_all_accounts = []
for account in config_accounts:
if account.startswith('@') and catch_all_enabled:
Expand Down

0 comments on commit c9f54ce

Please sign in to comment.