diff --git a/custom_engines.py b/custom_engines.py index b78a30c..d782e01 100644 --- a/custom_engines.py +++ b/custom_engines.py @@ -14,13 +14,14 @@ import logging.config import time +from datetime import datetime from papermill.clientwrap import PapermillNotebookClient from papermill.engines import NBClientEngine, NotebookExecutionManager, PapermillEngines from papermill.utils import remove_args, merge_kwargs, logger -class MetadataKey: +class EngineKey: def __init__(self, client_id, notebook_file): self.client_id = client_id self.notebook_file = notebook_file @@ -30,7 +31,7 @@ def __hash__(self): return hash((self.client_id, self.notebook_file)) def __eq__(self, other): - if isinstance(other, MetadataKey): + if isinstance(other, EngineKey): return self.client_id == other.client_id and self.notebook_file == other.notebook_file return False @@ -41,9 +42,53 @@ def __str__(self): return f"{self.client_id}:{self.notebook_file}" -class EngineMetadata: - client: PapermillNotebookClient = None - last_used_time: float = time.time() +class EngineHolder: + _key: EngineKey + _client: PapermillNotebookClient + _last_used_time: float + _busy: bool = False + + def __init__(self, key: EngineKey, client: PapermillNotebookClient): + self._key = key + self._client = client + self._last_used_time = time.time() + + def __str__(self): + return f"Engine(key={self._key}, last_used_time={self._last_used_time}, is_busy={self._busy})" + + async def async_execute(self, nb_man): + if self._busy: + raise EngineBusyError( + f"Notebook client related to '{self._key}' has been busy since {self._get_last_used_date_time()}") + + try: + self._busy = True + # accept new notebook into (possibly) existing client + self._client.nb_man = nb_man + self._client.nb = nb_man.nb + # reuse client connection to existing kernel + output = await self._client.async_execute(cleanup_kc=False) + # renumber executions + for i, cell in enumerate(nb_man.nb.cells): + if 'execution_count' in cell: + cell['execution_count'] = i + 1 + + return output + finally: + self._busy = False + + def get_last_used_time(self) -> float: + return self._last_used_time + + def close(self): + self._client = None + + def _get_last_used_date_time(self): + return datetime.fromtimestamp(self._last_used_time) + + +class EngineBusyError(RuntimeError): + pass class CustomEngine(NBClientEngine): @@ -51,12 +96,6 @@ class CustomEngine(NBClientEngine): metadata_dict: dict = {} logger: logging.Logger - @classmethod - def renumber_executions(cls, nb): - for i, cell in enumerate(nb.cells): - if 'execution_count' in cell: - cell['execution_count'] = i + 1 - # The code of this method is derived from https://github.com/nteract/papermill/blob/2.6.0 under the BSD License. # Original license follows: # @@ -153,35 +192,28 @@ async def async_execute_managed_notebook( execution_timeout (int): Duration to wait before failing execution (default: never). """ - # Exclude parameters that named differently downstream - safe_kwargs = remove_args(['timeout', 'startup_timeout'], **kwargs) + def create_client(): # TODO: should be static + # Exclude parameters that named differently downstream + safe_kwargs = remove_args(['timeout', 'startup_timeout'], **kwargs) - # Nicely handle preprocessor arguments prioritizing values set by engine - final_kwargs = merge_kwargs( - safe_kwargs, - timeout=execution_timeout if execution_timeout else kwargs.get('timeout'), - startup_timeout=start_timeout, - kernel_name=kernel_name, - log=logger, - log_output=log_output, - stdout_file=stdout_file, - stderr_file=stderr_file, - ) - # TODO: pass client_id - key = MetadataKey("", nb_man.nb['metadata']['papermill']['input_path']) - metadata = cls.get_engine_metadata(key) - if metadata.client is None: - metadata.client = PapermillNotebookClient(nb_man, **final_kwargs) + # Nicely handle preprocessor arguments prioritizing values set by engine + final_kwargs = merge_kwargs( + safe_kwargs, + timeout=execution_timeout if execution_timeout else kwargs.get('timeout'), + startup_timeout=start_timeout, + kernel_name=kernel_name, + log=logger, + log_output=log_output, + stdout_file=stdout_file, + stderr_file=stderr_file, + ) cls.logger.info(f"Created papermill notebook client for {key}") + return PapermillNotebookClient(nb_man, **final_kwargs) - # accept new notebook into (possibly) existing client - metadata.client.nb_man = nb_man - metadata.client.nb = nb_man.nb - # reuse client connection to existing kernel - output = await metadata.client.async_execute(cleanup_kc=False) - cls.renumber_executions(nb_man.nb) - - return output + # TODO: pass client_id + key = EngineKey("", nb_man.nb['metadata']['papermill']['input_path']) + engine_holder: EngineHolder = cls.get_or_create_engine_metadata(key, create_client) + return await engine_holder.async_execute(nb_man) @classmethod def create_logger(cls): @@ -192,27 +224,27 @@ def set_out_of_use_engine_time(cls, value: int): cls.out_of_use_engine_time = value @classmethod - def get_engine_metadata(cls, key: MetadataKey): + def get_or_create_engine_metadata(cls, key: EngineKey, func): cls.remove_out_of_date_engines(key) - metadata: EngineMetadata - if key not in cls.metadata_dict: - metadata = EngineMetadata() - cls.metadata_dict[key] = metadata - else: - metadata = cls.metadata_dict[key] - return metadata + + engine_holder: EngineHolder = cls.metadata_dict.get(key) + if engine_holder is None: + engine_holder = EngineHolder(key, func()) + cls.metadata_dict[key] = engine_holder + + return engine_holder @classmethod - def remove_out_of_date_engines(cls, exclude_key: MetadataKey): + def remove_out_of_date_engines(cls, exclude_key: EngineKey): now = time.time() dead_line = now - cls.out_of_use_engine_time out_of_use_engines = [key for key, metadata in cls.metadata_dict.items() if - key != exclude_key and metadata.last_used_time < dead_line] + key != exclude_key and metadata.get_last_used_time() < dead_line] for key in out_of_use_engines: - metadata: EngineMetadata = cls.metadata_dict.pop(key) - metadata.client = None + engine_holder: EngineHolder = cls.metadata_dict.pop(key) + engine_holder.close() cls.logger.info( - f"Unregistered '{key}' papermill engine, last used time {now - metadata.last_used_time} sec ago") + f"unregistered '{key}' papermill engine, last used time {now - engine_holder.get_last_used_time()} sec ago") class CustomEngines(PapermillEngines): diff --git a/server.py b/server.py index 1e6b127..b7a16ad 100644 --- a/server.py +++ b/server.py @@ -31,7 +31,7 @@ from aiojobs import Job from aiojobs.aiohttp import setup -from custom_engines import CustomEngine +from custom_engines import CustomEngine, EngineBusyError from log_configuratior import configure_logging os.system('pip list') @@ -407,6 +407,10 @@ async def launch_notebook(input_path, arguments: dict, file_name, task_metadata: task_metadata.status = TaskStatus.SUCCESS task_metadata.result = arguments.get('output_path') task_metadata.customization = arguments.get('customization_path') + except EngineBusyError as error: + logger.warning(error.args) + task_metadata.status = TaskStatus.FAILED + task_metadata.result = error except Exception as error: logger.error(f'failed to launch notebook {input_path}', error) task_metadata.status = TaskStatus.FAILED