Skip to content

Commit

Permalink
allow loading decoupled config from url
Browse files Browse the repository at this point in the history
  • Loading branch information
piax93 committed Jan 24, 2024
1 parent a032bec commit ca07b86
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 20 deletions.
12 changes: 9 additions & 3 deletions pidtree_bcc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from pidtree_bcc.utils import never_crash
from pidtree_bcc.utils import self_restart
from pidtree_bcc.utils import StopFlagWrapper
from pidtree_bcc.yaml_loader import FileIncludeLoader


Expand Down Expand Up @@ -62,14 +63,18 @@ def _drop_namespaces(names: Iterable[str]):
staticconf.config.configuration_namespaces.pop(name, None)


def parse_config(config_file: str, watch_config: bool = False) -> List[str]:
def parse_config(
config_file: str,
watch_config: bool = False,
stop_flag: Optional[StopFlagWrapper] = None,
) -> List[str]:
""" Parses yaml config file (if indicated)
:param str config_file: config file path
:param bool watch_config: perform necessary setup to enable configuration hot swaps
:return: list of all files loaded
"""
loader, included_files = FileIncludeLoader.get_loader_instance()
loader, included_files = FileIncludeLoader.get_loader_instance(stop_flag)
with open(config_file) as f:
config_data = yaml.load(f, Loader=loader)
included_files = sorted({config_file, *included_files})
Expand Down Expand Up @@ -112,6 +117,7 @@ def setup_config(
config_file: str,
watch_config: bool = False,
min_watch_interval: int = 60,
stop_flag: Optional[StopFlagWrapper] = None,
) -> Optional[ConfigurationWatcher]:
""" Load and setup configuration file
Expand All @@ -121,7 +127,7 @@ def setup_config(
:return: if `watch_config` is set, the configuration watcher object, None otherwise.
"""
logging.getLogger('staticconf.config').setLevel(logging.WARN)
config_loader = partial(parse_config, config_file, watch_config)
config_loader = partial(parse_config, config_file, watch_config, stop_flag=stop_flag)
filenames = config_loader()
watcher = ConfigurationWatcher(
config_loader=config_loader,
Expand Down
12 changes: 3 additions & 9 deletions pidtree_bcc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pidtree_bcc.probes import load_probes
from pidtree_bcc.utils import self_restart
from pidtree_bcc.utils import smart_open
from pidtree_bcc.utils import StopFlagWrapper


EXIT_CODE = 0
Expand All @@ -33,14 +34,6 @@ class RestartSignal(BaseException):
pass


class StopFlagWrapper:
def __init__(self):
self.do_stop = False

def stop(self):
self.do_stop = True


def parse_args() -> argparse.Namespace:
""" Parses command line arguments """
program_name = 'pidtree-bcc'
Expand Down Expand Up @@ -165,6 +158,7 @@ def health_and_config_watchdog(
def main(args: argparse.Namespace):
global EXIT_CODE
probe_workers = []
stop_wrapper = StopFlagWrapper()
logging.basicConfig(
stream=sys.stderr,
level=logging.INFO,
Expand All @@ -177,6 +171,7 @@ def main(args: argparse.Namespace):
args.config,
watch_config=args.watch_config,
min_watch_interval=args.health_check_period,
stop_flag=stop_wrapper,
)
out = smart_open(args.output_file, mode='w')
output_queue = SimpleQueue()
Expand All @@ -196,7 +191,6 @@ def main(args: argparse.Namespace):
for probe in probes.values():
probe_workers.append(Process(target=deregister_signals(probe.start_polling)))
probe_workers[-1].start()
stop_wrapper = StopFlagWrapper()
watchdog_thread = Thread(
target=health_and_config_watchdog,
args=(probe_workers, out, stop_wrapper, config_watcher, args.health_check_period),
Expand Down
8 changes: 8 additions & 0 deletions pidtree_bcc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,11 @@ def round_nearest_multiple(value: int, factor: int, headroom: int = 0) -> int:
:return: rounded value
"""
return factor * ((value + headroom) // factor + 1)


class StopFlagWrapper:
    """ Mutable boolean flag used to ask long-running loops to stop.

    One instance is handed to every party involved: the owner calls
    `stop()`, the workers poll the `do_stop` attribute.
    """

    def __init__(self) -> None:
        # stays False until `stop` is called; nothing in this class resets it
        self.do_stop: bool = False

    def stop(self) -> None:
        """ Raise the flag; calling it again has no further effect """
        self.do_stop = True
144 changes: 136 additions & 8 deletions pidtree_bcc/yaml_loader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
import os.path
import hashlib
import logging
import os
import re
import shutil
import sys
import tempfile
from functools import partial
from threading import Condition
from threading import Thread
from typing import Any
from typing import AnyStr
from typing import Dict
from typing import IO
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from urllib import request

import yaml

from pidtree_bcc.utils import never_crash
from pidtree_bcc.utils import StopFlagWrapper


class FileIncludeLoader(yaml.SafeLoader):
""" Custom YAML loader which allows including data from separate files, e.g.:
Expand All @@ -19,7 +32,20 @@ class FileIncludeLoader(yaml.SafeLoader):
```
"""

def __init__(self, stream: Union[AnyStr, IO], included_files: List[str]):
REMOTE_FETCH_INTERVAL_SECONDS = 60 * 60
REMOTE_FETCH_MAX_WAIT_SECONDS = 20

remote_fetcher: Optional[Thread] = None
remote_fetcher_outdir: Optional[str] = None
remote_fetcher_fence: Optional[Condition] = None
remote_fetch_workload: Dict[str, Tuple[str, Condition]] = {}

def __init__(
self,
stream: Union[AnyStr, IO],
included_files: List[str],
stop_flag: Optional[StopFlagWrapper] = None,
):
""" Constructor
:param Union[AnyStr, IO] stream: input data
Expand All @@ -28,6 +54,7 @@ def __init__(self, stream: Union[AnyStr, IO], included_files: List[str]):
super().__init__(stream)
self.add_constructor('!include', self.include_file)
self.included_files = included_files
self.stop_flag = stop_flag

def include_file(self, loader: yaml.Loader, node: yaml.Node) -> Any:
""" Constructs a yaml node from a separate file.
Expand All @@ -38,9 +65,13 @@ def include_file(self, loader: yaml.Loader, node: yaml.Node) -> Any:
"""
name = loader.construct_scalar(node)
filepath = (
os.path.join(os.path.dirname(loader.name), name)
if not os.path.isabs(name)
else name
self.include_remote(name)
if re.match(r'^https?://', name)
else (
os.path.join(os.path.dirname(loader.name), name)
if not os.path.isabs(name)
else name
)
)
try:
with open(filepath) as f:
Expand All @@ -51,8 +82,105 @@ def include_file(self, loader: yaml.Loader, node: yaml.Node) -> Any:
_, value, traceback = sys.exc_info()
raise yaml.YAMLError(value).with_traceback(traceback)

def include_remote(self, url: str) -> str:
    """ Load remote configuration data

    The download itself is performed by a single background fetcher thread,
    started on first use, which then keeps refreshing every registered url.

    :param str url: resource url
    :return: local filepath where data is stored
    :raises ValueError: if the fetcher does not signal completion
                        within REMOTE_FETCH_MAX_WAIT_SECONDS
    """
    # Fetcher state must be written on the class, not on `self`: assigning through
    # the instance would shadow the class attributes, so every new loader instance
    # would see `remote_fetcher is None` and spawn yet another thread and temp dir
    # while still sharing the class-level workload dict.
    cls = FileIncludeLoader
    if cls.remote_fetcher is None or not cls.remote_fetcher.is_alive():
        cls.remote_fetcher_fence = Condition()
        cls.remote_fetcher_outdir = tempfile.mkdtemp(prefix='pidtree-bcc-conf')
        cls.remote_fetcher = Thread(
            target=fetch_remote_configurations,
            args=(
                self.REMOTE_FETCH_INTERVAL_SECONDS, cls.remote_fetcher_fence,
                cls.remote_fetch_workload, self.stop_flag,
            ),
            daemon=True,
        )
        cls.remote_fetcher.start()
    logging.info(f'Loading remote configuration from {url}')
    ready = Condition()
    # file name is derived from the url, so the same url always maps to the same file
    url_sha = hashlib.sha256(url.encode()).hexdigest()
    output_path = os.path.join(cls.remote_fetcher_outdir, f'{url_sha}.yaml')
    # `ready` is held from before the request is published until wait() releases it.
    # The fetcher needs that same lock to notify, so it cannot signal completion
    # before we are actually listening (which would turn into a spurious timeout).
    with ready:
        cls.remote_fetch_workload[url] = (output_path, ready)
        with cls.remote_fetcher_fence:
            # wake the fetcher up in case it is sleeping between refresh rounds
            cls.remote_fetcher_fence.notify()
        if not ready.wait(timeout=self.REMOTE_FETCH_MAX_WAIT_SECONDS):
            raise ValueError(f'Failed to load configuration at {url}')
    return output_path

@classmethod
def get_loader_instance(cls) -> Tuple[partial, List[str]]:
""" Get loader and callback list of included files """
def get_loader_instance(cls, stop_flag: Optional[StopFlagWrapper] = None) -> Tuple[partial, List[str]]:
""" Get loader and callback list of included files
:param StopFlagWrapper stop_flag: signal for background threads to stop
:return: loader and callback list of included files
"""
included_files = []
return partial(cls, included_files=included_files), included_files
return partial(cls, included_files=included_files, stop_flag=stop_flag), included_files


@never_crash
def fetch_remote_configurations(
    interval: int,
    fence: Condition,
    workload: Dict[str, Tuple[str, Condition]],
    stop_flag: Optional[StopFlagWrapper] = None,
):
    """ Periodically sync to disc remote configurations

    :param int interval: seconds to wait between each check
    :param Condition fence: condition object used to wake the loop up before `interval` expires
    :param Dict[str, Tuple[str, Condition]] workload: set of resources to fetch (format: url -> (output_file, ready))
    :param StopFlagWrapper stop_flag: signal thread to stop
    """
    while not (stop_flag and stop_flag.do_stop):
        # work on a snapshot, other threads may register new entries during the loop
        snapshot = dict(workload)
        for url, (output_path, ready) in snapshot.items():
            try:
                # download without holding `ready`, so a waiter is never blocked
                # on the lock itself for the whole duration of the transfer
                _fetch_remote_configuration_impl(url, output_path)
            except Exception:
                # a single unreachable resource must not stop the others from being
                # refreshed; `ready` is deliberately not notified, so whoever is
                # waiting for this url runs into its own timeout
                logging.exception(f'Failed to fetch remote configuration from {url}')
                continue
            with ready:
                ready.notify()
        with fence:
            # Entries registered while we were busy fetching had nobody to wake up.
            # Checking under the fence lock closes that gap: either the change is
            # visible here and we loop again right away, or the notification comes
            # after we started waiting.
            if workload == snapshot:
                fence.wait(timeout=interval)


def _fetch_remote_configuration_impl(url: str, output_path: str):
    """ Downloads remote configuration to file, if changed
    compared to current output path.

    :param str url: remote config url
    :param str output_path: output file path
    """
    checksum = _md5sum(output_path) if os.path.exists(output_path) else ''
    if checksum and '.s3.amazonaws.' in url:
        # special case for AWS S3, which can give us a checksum in the header
        req = request.Request(url=url, method='HEAD')
        with request.urlopen(req) as response:
            # the header may be missing, in that case fall through to a full download
            response_etag = (response.headers.get('ETag') or '').strip('"').lower()
        if response_etag == checksum:
            return
    # store data to different path and rename, so eventual replacement is atomic;
    # the temporary file is created next to the destination so that the rename
    # never has to cross a filesystem boundary
    tmpfd, tmppath = tempfile.mkstemp(suffix='.tmp', dir=os.path.dirname(output_path) or None)
    try:
        with os.fdopen(tmpfd, 'wb') as tmp, request.urlopen(url) as response:
            shutil.copyfileobj(response, tmp)
        # no previous copy means there is nothing to compare against
        if not checksum or _md5sum(tmppath) != checksum:
            # unlike os.rename, os.replace overwrites the destination on every platform
            os.replace(tmppath, output_path)
            tmppath = None
    finally:
        # unchanged content or failed download: do not leave the temporary file behind
        if tmppath is not None:
            try:
                os.remove(tmppath)
            except FileNotFoundError:
                pass


def _md5sum(filepath: str, chunk_size: int = 4096) -> str:
    """ Compute MD5 checksum for file

    The digest is only meant for change detection, not for security purposes.

    :param str filepath: path to read data from
    :param int chunk_size: number of bytes read from the file at a time
    :return: hex encoded checksum string
    :raises ValueError: if `chunk_size` is not a positive number
    """
    if chunk_size < 1:
        # read(0) would end the loop immediately and yield the digest of no data
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    hash_md5 = hashlib.md5()
    with open(filepath, 'rb') as f:
        # read in chunks so memory usage does not depend on the file size
        for chunk in iter(lambda: f.read(chunk_size), b''):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
4 changes: 4 additions & 0 deletions tests/fixtures/remote_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
foo: !include https://raw.githubusercontent.com/Yelp/pidtree-bcc/master/tests/fixtures/child_config.yaml
bar:
fizz: buzz
35 changes: 35 additions & 0 deletions tests/yaml_loader_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import os
from unittest.mock import patch

import yaml

from pidtree_bcc.utils import StopFlagWrapper
from pidtree_bcc.yaml_loader import fetch_remote_configurations
from pidtree_bcc.yaml_loader import FileIncludeLoader


Expand All @@ -12,3 +17,33 @@ def test_file_include_loader():
'bar': {'fizz': 'buzz'},
}
assert included_files == ['tests/fixtures/child_config.yaml']


# Checks that a config with a URL `!include` is loaded through the background fetcher.
# Both `tempfile` and `request` are replaced inside pidtree_bcc.yaml_loader, so no
# network access happens and every file ends up under pytest's `tmp_path`.
@patch('pidtree_bcc.yaml_loader.tempfile')
@patch('pidtree_bcc.yaml_loader.request')
def test_file_include_remote(mock_request, mock_tempfile, tmp_path):
stop_flag = StopFlagWrapper()
# test could technically work with a real network request, but we mock anyway for better isolation
mock_request.urlopen.return_value = open('tests/fixtures/child_config.yaml', 'rb')
# directory the loader will use as output location for fetched files
mock_tempfile.mkdtemp.return_value = tmp_path.absolute().as_posix()
# pre-create the temporary file handed out by the mocked mkstemp (descriptor + path)
tmpout = (tmp_path / 'tmp.yaml').absolute().as_posix()
mock_tempfile.mkstemp.return_value = (
os.open(tmpout, os.O_WRONLY | os.O_CREAT | os.O_EXCL),
tmpout,
)
# this self-referring patch ensures mocks are propagated to the fetcher thread
with patch('pidtree_bcc.yaml_loader.fetch_remote_configurations', fetch_remote_configurations):
loader, included_files = FileIncludeLoader.get_loader_instance(stop_flag)
with open('tests/fixtures/remote_config.yaml') as f:
data = yaml.load(f, Loader=loader)
# raise the flag so the fetcher loop exits the next time it checks it
stop_flag.stop()
assert data == {
'foo': [1, {'a': 2, 'b': 3}, 4],
'bar': {'fizz': 'buzz'},
}
# the local file name is the sha256 hex digest of the included url plus '.yaml'
assert included_files == [
(tmp_path / '72e7a811f0c6baf6b49f9ddd2300d252a3eba7eb370f502cb834faa018ab26b9.yaml').absolute().as_posix(),
]
# single call: no HEAD request is attempted, only the download itself
mock_request.urlopen.assert_called_once_with(
'https://raw.githubusercontent.com/Yelp/pidtree-bcc/master/tests/fixtures/child_config.yaml',
)

0 comments on commit ca07b86

Please sign in to comment.