diff --git a/docs/conf.py b/docs/conf.py index 00ddfde6..78d1a2fe 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -80,7 +80,7 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'UPDATE.md'] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" +pygments_style = 'sphinx' # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -90,14 +90,14 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = "furo" +html_theme = 'furo' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # html_theme_options = { - "sidebar_hide_name": True, + 'sidebar_hide_name': True, } # Add any paths that contain custom static files (such as style sheets) here, @@ -105,7 +105,7 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] -html_logo = "_static/images/papermill.png" +html_logo = '_static/images/papermill.png' # -- Options for HTMLHelp output ------------------------------------------ diff --git a/papermill/__init__.py b/papermill/__init__.py index af32a9d3..e3b98fb6 100644 --- a/papermill/__init__.py +++ b/papermill/__init__.py @@ -1,5 +1,4 @@ -from .version import version as __version__ - from .exceptions import PapermillException, PapermillExecutionError from .execute import execute_notebook from .inspection import inspect_notebook +from .version import version as __version__ diff --git a/papermill/__main__.py b/papermill/__main__.py index 1f08dacb..c386c2ff 100644 --- a/papermill/__main__.py +++ b/papermill/__main__.py @@ -1,4 +1,4 @@ from papermill.cli import papermill -if __name__ == "__main__": +if __name__ == '__main__': papermill() diff --git a/papermill/abs.py b/papermill/abs.py index 2c5d4a45..6e138053 100644 --- a/papermill/abs.py +++ b/papermill/abs.py @@ -1,9 +1,9 @@ """Utilities for working with Azure blob storage""" -import re import io +import re -from azure.storage.blob import BlobServiceClient from azure.identity import EnvironmentCredential +from azure.storage.blob import BlobServiceClient class AzureBlobStore: @@ -20,7 +20,7 @@ class AzureBlobStore: def _blob_service_client(self, account_name, sas_token=None): blob_service_client = BlobServiceClient( - account_url=f"{account_name}.blob.core.windows.net", + account_url=f'{account_name}.blob.core.windows.net', credential=sas_token or EnvironmentCredential(), ) @@ -33,16 +33,16 @@ def _split_url(self, url): abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken """ match = re.match( - r"abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$", url + r'abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$', url ) if not match: raise Exception(f"Invalid azure blob url '{url}'") else: params = { - "account": match.group(1), - "container": match.group(2), - "blob": match.group(3), - "sas_token": match.group(4), + 'account': match.group(1), + 'container': match.group(2), + 'blob': match.group(3), + 'sas_token': match.group(4), } return params @@ -51,31 +51,31 @@ def read(self, url): params = self._split_url(url) output_stream = io.BytesIO() blob_service_client = self._blob_service_client( - params["account"], params["sas_token"] + params['account'], params['sas_token'] ) blob_client = blob_service_client.get_blob_client( - params["container"], params["blob"] + params['container'], params['blob'] ) blob_client.download_blob().readinto(output_stream) output_stream.seek(0) - return [line.decode("utf-8") for line in output_stream] + return [line.decode('utf-8') for line in output_stream] def listdir(self, url): """Returns a list of the files under the specified path""" params = self._split_url(url) blob_service_client = self._blob_service_client( - params["account"], params["sas_token"] + params['account'], params['sas_token'] ) - container_client = blob_service_client.get_container_client(params["container"]) - return list(container_client.list_blobs(params["blob"])) + container_client = blob_service_client.get_container_client(params['container']) + return list(container_client.list_blobs(params['blob'])) def write(self, buf, url): """Write buffer to storage at a given url""" params = self._split_url(url) blob_service_client = self._blob_service_client( - params["account"], params["sas_token"] + params['account'], params['sas_token'] ) blob_client = blob_service_client.get_blob_client( - params["container"], params["blob"] + params['container'], params['blob'] ) blob_client.upload_blob(data=buf, overwrite=True) diff --git a/papermill/adl.py b/papermill/adl.py index de7b64cb..73e73efd 100644 --- a/papermill/adl.py +++ b/papermill/adl.py @@ -21,7 +21,7 @@ def __init__(self): @classmethod def _split_url(cls, url): - match = re.match(r"adl://(.*)\.azuredatalakestore\.net\/(.*)$", url) + match = re.match(r'adl://(.*)\.azuredatalakestore\.net\/(.*)$', url) if not match: raise Exception(f"Invalid ADL url '{url}'") else: @@ -40,9 +40,7 @@ def listdir(self, url): (store_name, path) = self._split_url(url) adapter = self._create_adapter(store_name) return [ - "adl://{store_name}.azuredatalakestore.net/{path_to_child}".format( - store_name=store_name, path_to_child=path_to_child - ) + f'adl://{store_name}.azuredatalakestore.net/{path_to_child}' for path_to_child in adapter.ls(path) ] @@ -60,5 +58,5 @@ def write(self, buf, url): """Write buffer to storage at a given url""" (store_name, path) = self._split_url(url) adapter = self._create_adapter(store_name) - with adapter.open(path, "wb") as f: + with adapter.open(path, 'wb') as f: f.write(buf.encode()) diff --git a/papermill/cli.py b/papermill/cli.py index 3b76b00e..8338e67c 100755 --- a/papermill/cli.py +++ b/papermill/cli.py @@ -1,23 +1,21 @@ """Main `papermill` interface.""" +import base64 +import logging import os +import platform import sys -from stat import S_ISFIFO -import nbclient import traceback - -import base64 -import logging +from stat import S_ISFIFO import click - +import nbclient import yaml -import platform +from . import __version__ as papermill_version from .execute import execute_notebook -from .iorw import read_yaml_file, NoDatesSafeLoader from .inspection import display_notebook_help -from . import __version__ as papermill_version +from .iorw import NoDatesSafeLoader, read_yaml_file click.disable_unicode_literals_warning = True @@ -28,155 +26,151 @@ def print_papermill_version(ctx, param, value): if not value: return - print( - "{version} from {path} ({pyver})".format( - version=papermill_version, path=__file__, pyver=platform.python_version() - ) - ) + print(f'{papermill_version} from {__file__} ({platform.python_version()})') ctx.exit() -@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.command(context_settings=dict(help_option_names=['-h', '--help'])) @click.pass_context -@click.argument("notebook_path", required=not INPUT_PIPED) -@click.argument("output_path", default="") +@click.argument('notebook_path', required=not INPUT_PIPED) +@click.argument('output_path', default='') @click.option( - "--help-notebook", + '--help-notebook', is_flag=True, default=False, - help="Display parameters information for the given notebook path.", + help='Display parameters information for the given notebook path.', ) @click.option( - "--parameters", - "-p", + '--parameters', + '-p', nargs=2, multiple=True, - help="Parameters to pass to the parameters cell.", + help='Parameters to pass to the parameters cell.', ) @click.option( - "--parameters_raw", - "-r", + '--parameters_raw', + '-r', nargs=2, multiple=True, - help="Parameters to be read as raw string.", + help='Parameters to be read as raw string.', ) @click.option( - "--parameters_file", - "-f", + '--parameters_file', + '-f', multiple=True, - help="Path to YAML file containing parameters.", + help='Path to YAML file containing parameters.', ) @click.option( - "--parameters_yaml", - "-y", + '--parameters_yaml', + '-y', multiple=True, - help="YAML string to be used as parameters.", + help='YAML string to be used as parameters.', ) @click.option( - "--parameters_base64", - "-b", + '--parameters_base64', + '-b', multiple=True, - help="Base64 encoded YAML string as parameters.", + help='Base64 encoded YAML string as parameters.', ) @click.option( - "--inject-input-path", + '--inject-input-path', is_flag=True, default=False, - help="Insert the path of the input notebook as PAPERMILL_INPUT_PATH as a notebook parameter.", + help='Insert the path of the input notebook as PAPERMILL_INPUT_PATH as a notebook parameter.', ) @click.option( - "--inject-output-path", + '--inject-output-path', is_flag=True, default=False, - help="Insert the path of the output notebook as PAPERMILL_OUTPUT_PATH as a notebook parameter.", + help='Insert the path of the output notebook as PAPERMILL_OUTPUT_PATH as a notebook parameter.', ) @click.option( - "--inject-paths", + '--inject-paths', is_flag=True, default=False, help=( - "Insert the paths of input/output notebooks as PAPERMILL_INPUT_PATH/PAPERMILL_OUTPUT_PATH" - " as notebook parameters." + 'Insert the paths of input/output notebooks as PAPERMILL_INPUT_PATH/PAPERMILL_OUTPUT_PATH' + ' as notebook parameters.' ), ) @click.option( - "--engine", help="The execution engine name to use in evaluating the notebook." + '--engine', help='The execution engine name to use in evaluating the notebook.' ) @click.option( - "--request-save-on-cell-execute/--no-request-save-on-cell-execute", + '--request-save-on-cell-execute/--no-request-save-on-cell-execute', default=True, - help="Request save notebook after each cell execution", + help='Request save notebook after each cell execution', ) @click.option( - "--autosave-cell-every", + '--autosave-cell-every', default=30, type=int, - help="How often in seconds to autosave the notebook during long cell executions (0 to disable)", + help='How often in seconds to autosave the notebook during long cell executions (0 to disable)', ) @click.option( - "--prepare-only/--prepare-execute", + '--prepare-only/--prepare-execute', default=False, - help="Flag for outputting the notebook without execution, but with parameters applied.", + help='Flag for outputting the notebook without execution, but with parameters applied.', ) @click.option( - "--kernel", - "-k", - help="Name of kernel to run. Ignores kernel name in the notebook document metadata.", + '--kernel', + '-k', + help='Name of kernel to run. Ignores kernel name in the notebook document metadata.', ) @click.option( - "--language", - "-l", - help="Language for notebook execution. Ignores language in the notebook document metadata.", + '--language', + '-l', + help='Language for notebook execution. Ignores language in the notebook document metadata.', ) -@click.option("--cwd", default=None, help="Working directory to run notebook in.") +@click.option('--cwd', default=None, help='Working directory to run notebook in.') @click.option( - "--progress-bar/--no-progress-bar", + '--progress-bar/--no-progress-bar', default=None, - help="Flag for turning on the progress bar.", + help='Flag for turning on the progress bar.', ) @click.option( - "--log-output/--no-log-output", + '--log-output/--no-log-output', default=False, - help="Flag for writing notebook output to the configured logger.", + help='Flag for writing notebook output to the configured logger.', ) @click.option( - "--stdout-file", - type=click.File(mode="w", encoding="utf-8"), - help="File to write notebook stdout output to.", + '--stdout-file', + type=click.File(mode='w', encoding='utf-8'), + help='File to write notebook stdout output to.', ) @click.option( - "--stderr-file", - type=click.File(mode="w", encoding="utf-8"), - help="File to write notebook stderr output to.", + '--stderr-file', + type=click.File(mode='w', encoding='utf-8'), + help='File to write notebook stderr output to.', ) @click.option( - "--log-level", - type=click.Choice(["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), - default="INFO", - help="Set log level", + '--log-level', + type=click.Choice(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']), + default='INFO', + help='Set log level', ) @click.option( - "--start-timeout", - "--start_timeout", # Backwards compatible naming + '--start-timeout', + '--start_timeout', # Backwards compatible naming type=int, default=60, - help="Time in seconds to wait for kernel to start.", + help='Time in seconds to wait for kernel to start.', ) @click.option( - "--execution-timeout", + '--execution-timeout', type=int, - help="Time in seconds to wait for each cell before failing execution (default: forever)", + help='Time in seconds to wait for each cell before failing execution (default: forever)', ) @click.option( - "--report-mode/--no-report-mode", default=False, help="Flag for hiding input." + '--report-mode/--no-report-mode', default=False, help='Flag for hiding input.' ) @click.option( - "--version", + '--version', is_flag=True, callback=print_papermill_version, expose_value=False, is_eager=True, - help="Flag for displaying the version.", + help='Flag for displaying the version.', ) def papermill( click_ctx, @@ -224,8 +218,8 @@ def papermill( """ # Jupyter deps use frozen modules, so we disable the python 3.11+ warning about debugger if running the CLI - if "PYDEVD_DISABLE_FILE_VALIDATION" not in os.environ: - os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + if 'PYDEVD_DISABLE_FILE_VALIDATION' not in os.environ: + os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1' if not help_notebook: required_output_path = not (INPUT_PIPED or OUTPUT_PIPED) @@ -233,31 +227,31 @@ def papermill( raise click.UsageError("Missing argument 'OUTPUT_PATH'") if INPUT_PIPED and notebook_path and not output_path: - input_path = "-" + input_path = '-' output_path = notebook_path else: - input_path = notebook_path or "-" - output_path = output_path or "-" + input_path = notebook_path or '-' + output_path = output_path or '-' - if output_path == "-": + if output_path == '-': # Save notebook to stdout just once request_save_on_cell_execute = False # Reduce default log level if we pipe to stdout - if log_level == "INFO": - log_level = "ERROR" + if log_level == 'INFO': + log_level = 'ERROR' elif progress_bar is None: progress_bar = not log_output - logging.basicConfig(level=log_level, format="%(message)s") + logging.basicConfig(level=log_level, format='%(message)s') # Read in Parameters parameters_final = {} if inject_input_path or inject_paths: - parameters_final["PAPERMILL_INPUT_PATH"] = input_path + parameters_final['PAPERMILL_INPUT_PATH'] = input_path if inject_output_path or inject_paths: - parameters_final["PAPERMILL_OUTPUT_PATH"] = output_path + parameters_final['PAPERMILL_OUTPUT_PATH'] = output_path for params in parameters_base64 or []: parameters_final.update( yaml.load(base64.b64decode(params), Loader=NoDatesSafeLoader) or {} @@ -301,11 +295,11 @@ def papermill( def _resolve_type(value): - if value == "True": + if value == 'True': return True - elif value == "False": + elif value == 'False': return False - elif value == "None": + elif value == 'None': return None elif _is_int(value): return int(value) diff --git a/papermill/clientwrap.py b/papermill/clientwrap.py index b6718a2f..f8bdd050 100644 --- a/papermill/clientwrap.py +++ b/papermill/clientwrap.py @@ -1,5 +1,5 @@ -import sys import asyncio +import sys from nbclient import NotebookClient from nbclient.exceptions import CellExecutionError @@ -42,15 +42,15 @@ def execute(self, **kwargs): if ( sys.version_info[0] == 3 and sys.version_info[1] >= 8 - and sys.platform.startswith("win") + and sys.platform.startswith('win') ): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) with self.setup_kernel(**kwargs): - self.log.info("Executing notebook with kernel: %s" % self.kernel_name) + self.log.info('Executing notebook with kernel: %s' % self.kernel_name) self.papermill_execute_cells() info_msg = self.wait_for_reply(self.kc.kernel_info()) - self.nb.metadata["language_info"] = info_msg["content"]["language_info"] + self.nb.metadata['language_info'] = info_msg['content']['language_info'] self.set_widgets_metadata() return self.nb @@ -92,23 +92,23 @@ def log_output_message(self, output): :param output: nbformat.notebooknode.NotebookNode :return: """ - if output.output_type == "stream": - content = "".join(output.text) - if output.name == "stdout": + if output.output_type == 'stream': + content = ''.join(output.text) + if output.name == 'stdout': if self.log_output: self.log.info(content) if self.stdout_file: self.stdout_file.write(content) self.stdout_file.flush() - elif output.name == "stderr": + elif output.name == 'stderr': if self.log_output: # In case users want to redirect stderr differently, pipe to warning self.log.warning(content) if self.stderr_file: self.stderr_file.write(content) self.stderr_file.flush() - elif self.log_output and ("data" in output and "text/plain" in output.data): - self.log.info("".join(output.data["text/plain"])) + elif self.log_output and ('data' in output and 'text/plain' in output.data): + self.log.info(''.join(output.data['text/plain'])) def process_message(self, *arg, **kwargs): output = super().process_message(*arg, **kwargs) diff --git a/papermill/engines.py b/papermill/engines.py index 5200ff7d..28779f11 100644 --- a/papermill/engines.py +++ b/papermill/engines.py @@ -1,16 +1,16 @@ """Engines to perform different roles""" -import sys import datetime -import dateutil - +import sys from functools import wraps + +import dateutil import entrypoints -from .log import logger -from .exceptions import PapermillException from .clientwrap import PapermillNotebookClient +from .exceptions import PapermillException from .iorw import write_ipynb -from .utils import merge_kwargs, remove_args, nb_kernel_name, nb_language +from .log import logger +from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args class PapermillEngines: @@ -33,7 +33,7 @@ def register_entry_points(self): Load handlers provided by other packages """ - for entrypoint in entrypoints.get_group_all("papermill.engine"): + for entrypoint in entrypoints.get_group_all('papermill.engine'): self.register(entrypoint.name, entrypoint.load()) def get_engine(self, name=None): @@ -69,7 +69,7 @@ def catch_nb_assignment(func): @wraps(func) def wrapper(self, *args, **kwargs): - nb = kwargs.get("nb") + nb = kwargs.get('nb') if nb: # Reassign if executing notebook object was replaced self.nb = nb @@ -90,10 +90,10 @@ class NotebookExecutionManager: shared manner. """ - PENDING = "pending" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" + PENDING = 'pending' + RUNNING = 'running' + COMPLETED = 'completed' + FAILED = 'failed' def __init__( self, @@ -118,7 +118,7 @@ def __init__( # lazy import due to implict slow ipython import from tqdm.auto import tqdm - self.pbar = tqdm(total=len(self.nb.cells), unit="cell", desc="Executing") + self.pbar = tqdm(total=len(self.nb.cells), unit='cell', desc='Executing') def now(self): """Helper to return current UTC time""" @@ -169,7 +169,7 @@ def autosave_cell(self): # Autosave is taking too long, so exponentially back off. self.autosave_cell_every *= 2 logger.warning( - "Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec".format( + 'Autosave too slow: {:.2f} sec, over {}% limit. Backing off to {} sec'.format( save_elapsed, self.max_autosave_pct, self.autosave_cell_every ) ) @@ -187,14 +187,14 @@ def notebook_start(self, **kwargs): """ self.set_timer() - self.nb.metadata.papermill["start_time"] = self.start_time.isoformat() - self.nb.metadata.papermill["end_time"] = None - self.nb.metadata.papermill["duration"] = None - self.nb.metadata.papermill["exception"] = None + self.nb.metadata.papermill['start_time'] = self.start_time.isoformat() + self.nb.metadata.papermill['end_time'] = None + self.nb.metadata.papermill['duration'] = None + self.nb.metadata.papermill['exception'] = None for cell in self.nb.cells: # Reset the cell execution counts. - if cell.get("cell_type") == "code": + if cell.get('cell_type') == 'code': cell.execution_count = None # Clear out the papermill metadata for each cell. @@ -205,7 +205,7 @@ def notebook_start(self, **kwargs): duration=None, status=self.PENDING, # pending, running, completed ) - if cell.get("cell_type") == "code": + if cell.get('cell_type') == 'code': cell.outputs = [] self.save() @@ -219,17 +219,17 @@ def cell_start(self, cell, cell_index=None, **kwargs): metadata for a cell and save the notebook to the output path. """ if self.log_output: - ceel_num = cell_index + 1 if cell_index is not None else "" - logger.info(f"Executing Cell {ceel_num:-<40}") + ceel_num = cell_index + 1 if cell_index is not None else '' + logger.info(f'Executing Cell {ceel_num:-<40}') - cell.metadata.papermill["start_time"] = self.now().isoformat() - cell.metadata.papermill["status"] = self.RUNNING - cell.metadata.papermill["exception"] = False + cell.metadata.papermill['start_time'] = self.now().isoformat() + cell.metadata.papermill['status'] = self.RUNNING + cell.metadata.papermill['exception'] = False # injects optional description of the current cell directly in the tqdm cell_description = self.get_cell_description(cell) - if cell_description is not None and hasattr(self, "pbar") and self.pbar: - self.pbar.set_description(f"Executing {cell_description}") + if cell_description is not None and hasattr(self, 'pbar') and self.pbar: + self.pbar.set_description(f'Executing {cell_description}') self.save() @@ -242,9 +242,9 @@ def cell_exception(self, cell, cell_index=None, **kwargs): set the metadata on the notebook indicating the location of the failure. """ - cell.metadata.papermill["exception"] = True - cell.metadata.papermill["status"] = self.FAILED - self.nb.metadata.papermill["exception"] = True + cell.metadata.papermill['exception'] = True + cell.metadata.papermill['status'] = self.FAILED + self.nb.metadata.papermill['exception'] = True @catch_nb_assignment def cell_complete(self, cell, cell_index=None, **kwargs): @@ -257,20 +257,20 @@ def cell_complete(self, cell, cell_index=None, **kwargs): end_time = self.now() if self.log_output: - ceel_num = cell_index + 1 if cell_index is not None else "" - logger.info(f"Ending Cell {ceel_num:-<43}") + ceel_num = cell_index + 1 if cell_index is not None else '' + logger.info(f'Ending Cell {ceel_num:-<43}') # Ensure our last cell messages are not buffered by python sys.stdout.flush() sys.stderr.flush() - cell.metadata.papermill["end_time"] = end_time.isoformat() - if cell.metadata.papermill.get("start_time"): - start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"]) - cell.metadata.papermill["duration"] = ( + cell.metadata.papermill['end_time'] = end_time.isoformat() + if cell.metadata.papermill.get('start_time'): + start_time = dateutil.parser.parse(cell.metadata.papermill['start_time']) + cell.metadata.papermill['duration'] = ( end_time - start_time ).total_seconds() - if cell.metadata.papermill["status"] != self.FAILED: - cell.metadata.papermill["status"] = self.COMPLETED + if cell.metadata.papermill['status'] != self.FAILED: + cell.metadata.papermill['status'] = self.COMPLETED self.save() if self.pbar: @@ -285,18 +285,18 @@ def notebook_complete(self, **kwargs): Called by Engine when execution concludes, regardless of exceptions. """ self.end_time = self.now() - self.nb.metadata.papermill["end_time"] = self.end_time.isoformat() - if self.nb.metadata.papermill.get("start_time"): - self.nb.metadata.papermill["duration"] = ( + self.nb.metadata.papermill['end_time'] = self.end_time.isoformat() + if self.nb.metadata.papermill.get('start_time'): + self.nb.metadata.papermill['duration'] = ( self.end_time - self.start_time ).total_seconds() # Cleanup cell statuses in case callbacks were never called for cell in self.nb.cells: - if cell.metadata.papermill["status"] == self.FAILED: + if cell.metadata.papermill['status'] == self.FAILED: break - elif cell.metadata.papermill["status"] == self.PENDING: - cell.metadata.papermill["status"] = self.COMPLETED + elif cell.metadata.papermill['status'] == self.PENDING: + cell.metadata.papermill['status'] = self.COMPLETED self.complete_pbar() self.cleanup_pbar() @@ -304,12 +304,12 @@ def notebook_complete(self, **kwargs): # Force a final sync self.save() - def get_cell_description(self, cell, escape_str="papermill_description="): + def get_cell_description(self, cell, escape_str='papermill_description='): """Fetches cell description if present""" if cell is None: return None - cell_code = cell["source"] + cell_code = cell['source'] if cell_code is None or escape_str not in cell_code: return None @@ -317,13 +317,13 @@ def get_cell_description(self, cell, escape_str="papermill_description="): def complete_pbar(self): """Refresh progress bar""" - if hasattr(self, "pbar") and self.pbar: + if hasattr(self, 'pbar') and self.pbar: self.pbar.n = len(self.nb.cells) self.pbar.refresh() def cleanup_pbar(self): """Clean up a progress bar""" - if hasattr(self, "pbar") and self.pbar: + if hasattr(self, 'pbar') and self.pbar: self.pbar.close() self.pbar = None @@ -431,12 +431,12 @@ def execute_managed_notebook( """ # Exclude parameters that named differently downstream - safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs) + safe_kwargs = remove_args(['timeout', 'startup_timeout'], **kwargs) # Nicely handle preprocessor arguments prioritizing values set by engine final_kwargs = merge_kwargs( safe_kwargs, - timeout=execution_timeout if execution_timeout else kwargs.get("timeout"), + timeout=execution_timeout if execution_timeout else kwargs.get('timeout'), startup_timeout=start_timeout, kernel_name=kernel_name, log=logger, @@ -450,5 +450,5 @@ def execute_managed_notebook( # Instantiate a PapermillEngines instance, register Handlers and entrypoints papermill_engines = PapermillEngines() papermill_engines.register(None, NBClientEngine) -papermill_engines.register("nbclient", NBClientEngine) +papermill_engines.register('nbclient', NBClientEngine) papermill_engines.register_entry_points() diff --git a/papermill/exceptions.py b/papermill/exceptions.py index 38aab7e8..f78f95f7 100644 --- a/papermill/exceptions.py +++ b/papermill/exceptions.py @@ -33,10 +33,10 @@ def __str__(self): # when called with str(). In order to maintain compatability with previous versions which # passed only the message to the superclass constructor, __str__ method is implemented to # provide the same result as was produced in the past. - message = "\n" + 75 * "-" + "\n" + message = '\n' + 75 * '-' + '\n' message += 'Exception encountered at "In [%s]":\n' % str(self.exec_count) - message += "\n".join(self.traceback) - message += "\n" + message += '\n'.join(self.traceback) + message += '\n' return message @@ -59,10 +59,8 @@ class PapermillParameterOverwriteWarning(PapermillWarning): def missing_dependency_generator(package, dep): def missing_dep(): raise PapermillOptionalDependencyException( - "The {package} optional dependency is missing. " - "Please run pip install papermill[{dep}] to install this dependency".format( - package=package, dep=dep - ) + f'The {package} optional dependency is missing. ' + f'Please run pip install papermill[{dep}] to install this dependency' ) return missing_dep @@ -71,11 +69,9 @@ def missing_dep(): def missing_environment_variable_generator(package, env_key): def missing_dep(): raise PapermillOptionalDependencyException( - "The {package} optional dependency is present, but the environment " - "variable {env_key} is not set. Please set this variable as " - "required by {package} on your platform.".format( - package=package, env_key=env_key - ) + f'The {package} optional dependency is present, but the environment ' + f'variable {env_key} is not set. Please set this variable as ' + f'required by {package} on your platform.' ) return missing_dep diff --git a/papermill/execute.py b/papermill/execute.py index 3d0d23ae..74f577f6 100644 --- a/papermill/execute.py +++ b/papermill/execute.py @@ -1,17 +1,18 @@ -import nbformat from pathlib import Path -from .log import logger -from .exceptions import PapermillExecutionError -from .iorw import get_pretty_path, local_file_io_cwd, load_notebook_node, write_ipynb +import nbformat + from .engines import papermill_engines -from .utils import chdir +from .exceptions import PapermillExecutionError +from .inspection import _infer_parameters +from .iorw import get_pretty_path, load_notebook_node, local_file_io_cwd, write_ipynb +from .log import logger from .parameterize import ( add_builtin_parameters, parameterize_notebook, parameterize_path, ) -from .inspection import _infer_parameters +from .utils import chdir def execute_notebook( @@ -83,11 +84,11 @@ def execute_notebook( input_path = parameterize_path(input_path, path_parameters) output_path = parameterize_path(output_path, path_parameters) - logger.info("Input Notebook: %s" % get_pretty_path(input_path)) - logger.info("Output Notebook: %s" % get_pretty_path(output_path)) + logger.info('Input Notebook: %s' % get_pretty_path(input_path)) + logger.info('Output Notebook: %s' % get_pretty_path(output_path)) with local_file_io_cwd(): if cwd is not None: - logger.info(f"Working directory: {get_pretty_path(cwd)}") + logger.info(f'Working directory: {get_pretty_path(cwd)}') nb = load_notebook_node(input_path) @@ -99,7 +100,7 @@ def execute_notebook( parameter_predefined = {p.name for p in parameter_predefined} for p in parameters: if p not in parameter_predefined: - logger.warning(f"Passed unknown parameter: {p}") + logger.warning(f'Passed unknown parameter: {p}') nb = parameterize_notebook( nb, parameters, @@ -160,31 +161,31 @@ def prepare_notebook_metadata(nb, input_path, output_path, report_mode=False): # Hide input if report-mode is set to True. if report_mode: for cell in nb.cells: - if cell.cell_type == "code": - cell.metadata["jupyter"] = cell.get("jupyter", {}) - cell.metadata["jupyter"]["source_hidden"] = True + if cell.cell_type == 'code': + cell.metadata['jupyter'] = cell.get('jupyter', {}) + cell.metadata['jupyter']['source_hidden'] = True # Record specified environment variable values. - nb.metadata.papermill["input_path"] = input_path - nb.metadata.papermill["output_path"] = output_path + nb.metadata.papermill['input_path'] = input_path + nb.metadata.papermill['output_path'] = output_path return nb -ERROR_MARKER_TAG = "papermill-error-cell-tag" +ERROR_MARKER_TAG = 'papermill-error-cell-tag' ERROR_STYLE = 'style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;"' ERROR_MESSAGE_TEMPLATE = ( - "" - "An Exception was encountered at 'In [%s]'." - "" + '' + 'An Exception was encountered at \'In [%s]\'.' + '' ) ERROR_ANCHOR_MSG = ( - '" - "Execution using papermill encountered an exception here and stopped:" - "" + '' + 'Execution using papermill encountered an exception here and stopped:' + '' ) @@ -192,7 +193,7 @@ def remove_error_markers(nb): nb.cells = [ cell for cell in nb.cells - if ERROR_MARKER_TAG not in cell.metadata.get("tags", []) + if ERROR_MARKER_TAG not in cell.metadata.get('tags', []) ] return nb @@ -209,13 +210,13 @@ def raise_for_execution_errors(nb, output_path): """ error = None for index, cell in enumerate(nb.cells): - if cell.get("outputs") is None: + if cell.get('outputs') is None: continue for output in cell.outputs: - if output.output_type == "error": - if output.ename == "SystemExit" and ( - output.evalue == "" or output.evalue == "0" + if output.output_type == 'error': + if output.ename == 'SystemExit' and ( + output.evalue == '' or output.evalue == '0' ): continue error = PapermillExecutionError( @@ -233,9 +234,9 @@ def raise_for_execution_errors(nb, output_path): # the relevant cell (by adding a note just before the failure with an HTML anchor) error_msg = ERROR_MESSAGE_TEMPLATE % str(error.exec_count) error_msg_cell = nbformat.v4.new_markdown_cell(error_msg) - error_msg_cell.metadata["tags"] = [ERROR_MARKER_TAG] + error_msg_cell.metadata['tags'] = [ERROR_MARKER_TAG] error_anchor_cell = nbformat.v4.new_markdown_cell(ERROR_ANCHOR_MSG) - error_anchor_cell.metadata["tags"] = [ERROR_MARKER_TAG] + error_anchor_cell.metadata['tags'] = [ERROR_MARKER_TAG] # Upgrade the Notebook to the latest v4 before writing into it nb = nbformat.v4.upgrade(nb) diff --git a/papermill/inspection.py b/papermill/inspection.py index b1ec68f7..75ca765e 100644 --- a/papermill/inspection.py +++ b/papermill/inspection.py @@ -1,7 +1,8 @@ """Deduce parameters of a notebook from the parameters cell.""" -import click from pathlib import Path +import click + from .iorw import get_pretty_path, load_notebook_node, local_file_io_cwd from .log import logger from .parameterize import add_builtin_parameters, parameterize_path @@ -17,7 +18,7 @@ def _open_notebook(notebook_path, parameters): path_parameters = add_builtin_parameters(parameters) input_path = parameterize_path(notebook_path, path_parameters) - logger.info("Input Notebook: %s" % get_pretty_path(input_path)) + logger.info('Input Notebook: %s' % get_pretty_path(input_path)) with local_file_io_cwd(): return load_notebook_node(input_path) @@ -38,7 +39,7 @@ def _infer_parameters(nb, name=None, language=None): """ params = [] - parameter_cell_idx = find_first_tagged_cell_index(nb, "parameters") + parameter_cell_idx = find_first_tagged_cell_index(nb, 'parameters') if parameter_cell_idx < 0: return params parameter_cell = nb.cells[parameter_cell_idx] @@ -51,9 +52,7 @@ def _infer_parameters(nb, name=None, language=None): params = translator.inspect(parameter_cell) except NotImplementedError: logger.warning( - "Translator for '{}' language does not support parameter introspection.".format( - language - ) + f"Translator for '{language}' language does not support parameter introspection." ) return params @@ -74,7 +73,7 @@ def display_notebook_help(ctx, notebook_path, parameters): pretty_path = get_pretty_path(notebook_path) click.echo(f"\nParameters inferred for notebook '{pretty_path}':") - if not any_tagged_cell(nb, "parameters"): + if not any_tagged_cell(nb, 'parameters'): click.echo("\n No cell tagged 'parameters'") return 1 @@ -82,25 +81,25 @@ def display_notebook_help(ctx, notebook_path, parameters): if params: for param in params: p = param._asdict() - type_repr = p["inferred_type_name"] - if type_repr == "None": - type_repr = "Unknown type" + type_repr = p['inferred_type_name'] + if type_repr == 'None': + type_repr = 'Unknown type' - definition = " {}: {} (default {})".format( - p["name"], type_repr, p["default"] + definition = ' {}: {} (default {})'.format( + p['name'], type_repr, p['default'] ) if len(definition) > 30: - if len(p["help"]): - param_help = "".join((definition, "\n", 34 * " ", p["help"])) + if len(p['help']): + param_help = ''.join((definition, '\n', 34 * ' ', p['help'])) else: param_help = definition else: - param_help = "{:<34}{}".format(definition, p["help"]) + param_help = '{:<34}{}'.format(definition, p['help']) click.echo(param_help) else: click.echo( "\n Can't infer anything about this notebook's parameters. " - "It may not have any parameter defined." + 'It may not have any parameter defined.' ) return 0 diff --git a/papermill/iorw.py b/papermill/iorw.py index 961ee207..d698bd6b 100644 --- a/papermill/iorw.py +++ b/papermill/iorw.py @@ -1,15 +1,14 @@ +import fnmatch +import json import os import sys -import json -import yaml -import fnmatch -import nbformat -import requests import warnings -import entrypoints - from contextlib import contextmanager +import entrypoints +import nbformat +import requests +import yaml from tenacity import ( retry, retry_if_exception_type, @@ -30,37 +29,37 @@ try: from .s3 import S3 except ImportError: - S3 = missing_dependency_generator("boto3", "s3") + S3 = missing_dependency_generator('boto3', 's3') try: from .adl import ADL except ImportError: - ADL = missing_dependency_generator("azure.datalake.store", "azure") + ADL = missing_dependency_generator('azure.datalake.store', 'azure') except KeyError as exc: - if exc.args[0] == "APPDATA": - ADL = missing_environment_variable_generator("azure.datalake.store", "APPDATA") + if exc.args[0] == 'APPDATA': + ADL = missing_environment_variable_generator('azure.datalake.store', 'APPDATA') else: raise try: from .abs import AzureBlobStore except ImportError: - AzureBlobStore = missing_dependency_generator("azure.storage.blob", "azure") + AzureBlobStore = missing_dependency_generator('azure.storage.blob', 'azure') try: from gcsfs import GCSFileSystem except ImportError: - GCSFileSystem = missing_dependency_generator("gcsfs", "gcs") + GCSFileSystem = missing_dependency_generator('gcsfs', 'gcs') try: - from pyarrow.fs import HadoopFileSystem, FileSelector + from pyarrow.fs import FileSelector, HadoopFileSystem except ImportError: - HadoopFileSystem = missing_dependency_generator("pyarrow", "hdfs") + HadoopFileSystem = missing_dependency_generator('pyarrow', 'hdfs') try: from github import Github except ImportError: - Github = missing_dependency_generator("pygithub", "github") + Github = missing_dependency_generator('pygithub', 'github') def fallback_gs_is_retriable(e): @@ -97,14 +96,14 @@ class PapermillIO: def __init__(self): self.reset() - def read(self, path, extensions=[".ipynb", ".json"]): + def read(self, path, extensions=['.ipynb', '.json']): # Handle https://github.com/nteract/papermill/issues/317 notebook_metadata = self.get_handler(path, extensions).read(path) if isinstance(notebook_metadata, (bytes, bytearray)): - return notebook_metadata.decode("utf-8") + return notebook_metadata.decode('utf-8') return notebook_metadata - def write(self, buf, path, extensions=[".ipynb", ".json"]): + def write(self, buf, path, extensions=['.ipynb', '.json']): return self.get_handler(path, extensions).write(buf, path) def listdir(self, path): @@ -122,7 +121,7 @@ def register(self, scheme, handler): def register_entry_points(self): # Load handlers provided by other packages - for entrypoint in entrypoints.get_group_all("papermill.io"): + for entrypoint in entrypoints.get_group_all('papermill.io'): self.register(entrypoint.name, entrypoint.load()) def get_handler(self, path, extensions=None): @@ -151,22 +150,22 @@ def get_handler(self, path, extensions=None): return NotebookNodeHandler() if extensions: - if not fnmatch.fnmatch(os.path.basename(path).split("?")[0], "*.*"): + if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'): warnings.warn( - "the file is not specified with any extension : " + 'the file is not specified with any extension : ' + os.path.basename(path) ) elif not any( - fnmatch.fnmatch(os.path.basename(path).split("?")[0], "*" + ext) + fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions ): warnings.warn( - f"The specified file ({path}) does not end in one of {extensions}" + f'The specified file ({path}) does not end in one of {extensions}' ) local_handler = None for scheme, handler in self._handlers: - if scheme == "local": + if scheme == 'local': local_handler = handler if path.startswith(scheme): @@ -174,7 +173,7 @@ def get_handler(self, path, extensions=None): if local_handler is None: raise PapermillException( - f"Could not find a registered schema handler for: {path}" + f'Could not find a registered schema handler for: {path}' ) return local_handler @@ -183,11 +182,11 @@ def get_handler(self, path, extensions=None): class HttpHandler: @classmethod def read(cls, path): - return requests.get(path, headers={"Accept": "application/json"}).text + return requests.get(path, headers={'Accept': 'application/json'}).text @classmethod def listdir(cls, path): - raise PapermillException("listdir is not supported by HttpHandler") + raise PapermillException('listdir is not supported by HttpHandler') @classmethod def write(cls, buf, path): @@ -206,7 +205,7 @@ def __init__(self): def read(self, path): try: with chdir(self._cwd): - with open(path, encoding="utf-8") as f: + with open(path, encoding='utf-8') as f: return f.read() except OSError as e: try: @@ -227,7 +226,7 @@ def write(self, buf, path): dirname = os.path.dirname(path) if dirname and not os.path.exists(dirname): raise FileNotFoundError(f"output folder {dirname} doesn't exist.") - with open(path, "w", encoding="utf-8") as f: + with open(path, 'w', encoding='utf-8') as f: f.write(buf) def pretty_path(self, path): @@ -243,7 +242,7 @@ def cwd(self, new_path): class S3Handler: @classmethod def read(cls, path): - return "\n".join(S3().read(path)) + return '\n'.join(S3().read(path)) @classmethod def listdir(cls, path): @@ -269,7 +268,7 @@ def _get_client(self): def read(self, path): lines = self._get_client().read(path) - return "\n".join(lines) + return '\n'.join(lines) def listdir(self, path): return self._get_client().listdir(path) @@ -292,7 +291,7 @@ def _get_client(self): def read(self, path): lines = self._get_client().read(path) - return "\n".join(lines) + return '\n'.join(lines) def listdir(self, path): return self._get_client().listdir(path) @@ -339,13 +338,13 @@ def write(self, buf, path): ) def retry_write(): try: - with self._get_client().open(path, "w") as f: + with self._get_client().open(path, 'w') as f: return f.write(buf) except Exception as e: try: message = e.message except AttributeError: - message = f"Generic exception {type(e)} raised" + message = f'Generic exception {type(e)} raised' if gs_is_retriable(e): raise PapermillRateLimitException(message) # Reraise the original exception without retries @@ -363,7 +362,7 @@ def __init__(self): def _get_client(self): if self._client is None: - self._client = HadoopFileSystem(host="default") + self._client = HadoopFileSystem(host='default') return self._client def read(self, path): @@ -387,7 +386,7 @@ def __init__(self): def _get_client(self): if self._client is None: - token = os.environ.get("GITHUB_ACCESS_TOKEN", None) + token = os.environ.get('GITHUB_ACCESS_TOKEN', None) if token: self._client = Github(token) else: @@ -395,20 +394,20 @@ def _get_client(self): return self._client def read(self, path): - splits = path.split("/") + splits = path.split('/') org_id = splits[3] repo_id = splits[4] ref_id = splits[6] - sub_path = "/".join(splits[7:]) - repo = self._get_client().get_repo(org_id + "/" + repo_id) + sub_path = '/'.join(splits[7:]) + repo = self._get_client().get_repo(org_id + '/' + repo_id) content = repo.get_contents(sub_path, ref=ref_id) return content.decoded_content def listdir(self, path): - raise PapermillException("listdir is not supported by GithubHandler") + raise PapermillException('listdir is not supported by GithubHandler') def write(self, buf, path): - raise PapermillException("write is not supported by GithubHandler") + raise PapermillException('write is not supported by GithubHandler') def pretty_path(self, path): return path @@ -421,15 +420,15 @@ def read(self, path): return sys.stdin.read() def listdir(self, path): - raise PapermillException("listdir is not supported by Stream Handler") + raise PapermillException('listdir is not supported by Stream Handler') def write(self, buf, path): try: - return sys.stdout.buffer.write(buf.encode("utf-8")) + return sys.stdout.buffer.write(buf.encode('utf-8')) except AttributeError: # Originally required by https://github.com/nteract/papermill/issues/420 # Support Buffer.io objects - return sys.stdout.write(buf.encode("utf-8")) + return sys.stdout.write(buf.encode('utf-8')) def pretty_path(self, path): return path @@ -442,60 +441,60 @@ def read(self, path): return nbformat.writes(path) def listdir(self, path): - raise PapermillException("listdir is not supported by NotebookNode Handler") + raise PapermillException('listdir is not supported by NotebookNode Handler') def write(self, buf, path): - raise PapermillException("write is not supported by NotebookNode Handler") + raise PapermillException('write is not supported by NotebookNode Handler') def pretty_path(self, path): - return "NotebookNode object" + return 'NotebookNode object' class NoIOHandler: """Handler for output_path of None - intended to not write anything""" def read(self, path): - raise PapermillException("read is not supported by NoIOHandler") + raise PapermillException('read is not supported by NoIOHandler') def listdir(self, path): - raise PapermillException("listdir is not supported by NoIOHandler") + raise PapermillException('listdir is not supported by NoIOHandler') def write(self, buf, path): return def pretty_path(self, path): - return "Notebook will not be saved" + return 'Notebook will not be saved' # Hack to make YAML loader not auto-convert datetimes # https://stackoverflow.com/a/52312810 class NoDatesSafeLoader(yaml.SafeLoader): yaml_implicit_resolvers = { - k: [r for r in v if r[0] != "tag:yaml.org,2002:timestamp"] + k: [r for r in v if r[0] != 'tag:yaml.org,2002:timestamp'] for k, v in yaml.SafeLoader.yaml_implicit_resolvers.items() } # Instantiate a PapermillIO instance and register Handlers. papermill_io = PapermillIO() -papermill_io.register("local", LocalHandler()) -papermill_io.register("s3://", S3Handler) -papermill_io.register("adl://", ADLHandler()) -papermill_io.register("abs://", ABSHandler()) -papermill_io.register("http://", HttpHandler) -papermill_io.register("https://", HttpHandler) -papermill_io.register("gs://", GCSHandler()) -papermill_io.register("hdfs://", HDFSHandler()) -papermill_io.register("http://github.com/", GithubHandler()) -papermill_io.register("https://github.com/", GithubHandler()) -papermill_io.register("-", StreamHandler()) +papermill_io.register('local', LocalHandler()) +papermill_io.register('s3://', S3Handler) +papermill_io.register('adl://', ADLHandler()) +papermill_io.register('abs://', ABSHandler()) +papermill_io.register('http://', HttpHandler) +papermill_io.register('https://', HttpHandler) +papermill_io.register('gs://', GCSHandler()) +papermill_io.register('hdfs://', HDFSHandler()) +papermill_io.register('http://github.com/', GithubHandler()) +papermill_io.register('https://github.com/', GithubHandler()) +papermill_io.register('-', StreamHandler()) papermill_io.register_entry_points() def read_yaml_file(path): """Reads a YAML file from the location specified at 'path'.""" return yaml.load( - papermill_io.read(path, [".json", ".yaml", ".yml"]), Loader=NoDatesSafeLoader + papermill_io.read(path, ['.json', '.yaml', '.yml']), Loader=NoDatesSafeLoader ) @@ -523,27 +522,27 @@ def load_notebook_node(notebook_path): if nb_upgraded is not None: nb = nb_upgraded - if not hasattr(nb.metadata, "papermill"): - nb.metadata["papermill"] = { - "default_parameters": dict(), - "parameters": dict(), - "environment_variables": dict(), - "version": __version__, + if not hasattr(nb.metadata, 'papermill'): + nb.metadata['papermill'] = { + 'default_parameters': dict(), + 'parameters': dict(), + 'environment_variables': dict(), + 'version': __version__, } for cell in nb.cells: - if not hasattr(cell.metadata, "tags"): - cell.metadata["tags"] = [] # Create tags attr if one doesn't exist. + if not hasattr(cell.metadata, 'tags'): + cell.metadata['tags'] = [] # Create tags attr if one doesn't exist. - if not hasattr(cell.metadata, "papermill"): - cell.metadata["papermill"] = dict() + if not hasattr(cell.metadata, 'papermill'): + cell.metadata['papermill'] = dict() return nb def list_notebook_files(path): """Returns a list of all the notebook files in a directory.""" - return [p for p in papermill_io.listdir(path) if p.endswith(".ipynb")] + return [p for p in papermill_io.listdir(path) if p.endswith('.ipynb')] def get_pretty_path(path): @@ -553,14 +552,14 @@ def get_pretty_path(path): @contextmanager def local_file_io_cwd(path=None): try: - local_handler = papermill_io.get_handler("local") + local_handler = papermill_io.get_handler('local') except PapermillException: - logger.warning("No local file handler detected") + logger.warning('No local file handler detected') else: try: old_cwd = local_handler.cwd(path or os.getcwd()) except AttributeError: - logger.warning("Local file handler does not support cwd assignment") + logger.warning('Local file handler does not support cwd assignment') else: try: yield diff --git a/papermill/log.py b/papermill/log.py index 273bc8f3..b90225d2 100644 --- a/papermill/log.py +++ b/papermill/log.py @@ -1,4 +1,4 @@ """Sets up a logger""" import logging -logger = logging.getLogger("papermill") +logger = logging.getLogger('papermill') diff --git a/papermill/models.py b/papermill/models.py index fcbb627f..35c077e5 100644 --- a/papermill/models.py +++ b/papermill/models.py @@ -2,11 +2,11 @@ from collections import namedtuple Parameter = namedtuple( - "Parameter", + 'Parameter', [ - "name", - "inferred_type_name", # string of type - "default", # string representing the default value - "help", + 'name', + 'inferred_type_name', # string of type + 'default', # string representing the default value + 'help', ], ) diff --git a/papermill/parameterize.py b/papermill/parameterize.py index db3ac837..a210f26e 100644 --- a/papermill/parameterize.py +++ b/papermill/parameterize.py @@ -1,15 +1,15 @@ +from datetime import datetime +from uuid import uuid4 + import nbformat from .engines import papermill_engines -from .log import logger from .exceptions import PapermillMissingParameterException from .iorw import read_yaml_file +from .log import logger from .translators import translate_parameters from .utils import find_first_tagged_cell_index -from uuid import uuid4 -from datetime import datetime - def add_builtin_parameters(parameters): """Add built-in parameters to a dictionary of parameters @@ -20,10 +20,10 @@ def add_builtin_parameters(parameters): Dictionary of parameters provided by the user """ with_builtin_parameters = { - "pm": { - "run_uuid": str(uuid4()), - "current_datetime_local": datetime.now(), - "current_datetime_utc": datetime.utcnow(), + 'pm': { + 'run_uuid': str(uuid4()), + 'current_datetime_local': datetime.now(), + 'current_datetime_utc': datetime.utcnow(), } } @@ -53,14 +53,14 @@ def parameterize_path(path, parameters): try: return path.format(**parameters) except KeyError as key_error: - raise PapermillMissingParameterException(f"Missing parameter {key_error}") + raise PapermillMissingParameterException(f'Missing parameter {key_error}') def parameterize_notebook( nb, parameters, report_mode=False, - comment="Parameters", + comment='Parameters', kernel_name=None, language=None, engine_name=None, @@ -93,14 +93,14 @@ def parameterize_notebook( nb = nbformat.v4.upgrade(nb) newcell = nbformat.v4.new_code_cell(source=param_content) - newcell.metadata["tags"] = ["injected-parameters"] + newcell.metadata['tags'] = ['injected-parameters'] if report_mode: - newcell.metadata["jupyter"] = newcell.get("jupyter", {}) - newcell.metadata["jupyter"]["source_hidden"] = True + newcell.metadata['jupyter'] = newcell.get('jupyter', {}) + newcell.metadata['jupyter']['source_hidden'] = True - param_cell_index = find_first_tagged_cell_index(nb, "parameters") - injected_cell_index = find_first_tagged_cell_index(nb, "injected-parameters") + param_cell_index = find_first_tagged_cell_index(nb, 'parameters') + injected_cell_index = find_first_tagged_cell_index(nb, 'injected-parameters') if injected_cell_index >= 0: # Replace the injected cell with a new version before = nb.cells[:injected_cell_index] @@ -116,6 +116,6 @@ def parameterize_notebook( after = nb.cells nb.cells = before + [newcell] + after - nb.metadata.papermill["parameters"] = parameters + nb.metadata.papermill['parameters'] = parameters return nb diff --git a/papermill/s3.py b/papermill/s3.py index ccd2141a..af11571e 100644 --- a/papermill/s3.py +++ b/papermill/s3.py @@ -1,8 +1,7 @@ """Utilities for working with S3.""" -import os - import logging +import os import threading import zlib @@ -11,8 +10,7 @@ from .exceptions import AwsError from .utils import retry - -logger = logging.getLogger("papermill.s3") +logger = logging.getLogger('papermill.s3') class Bucket: @@ -32,7 +30,7 @@ def __init__(self, name, service=None): self.name = name self.service = service - def list(self, prefix="", delimiter=None): + def list(self, prefix='', delimiter=None): """Limits a list of Bucket's objects based on prefix and delimiter.""" return self.service._list( bucket=self.name, prefix=prefix, delimiter=delimiter, objects=True @@ -61,7 +59,7 @@ def __init__(self, bucket, name, service=None): self.service = service def __str__(self): - return f"s3://{self.bucket.name}/{self.name}" + return f's3://{self.bucket.name}/{self.name}' def __repr__(self): return self.__str__() @@ -106,7 +104,7 @@ def __init__( self.etag = etag if last_modified: try: - self.last_modified = last_modified.isoformat().split("+")[0] + ".000Z" + self.last_modified = last_modified.isoformat().split('+')[0] + '.000Z' except ValueError: self.last_modified = last_modified self.storage_class = storage_class @@ -114,7 +112,7 @@ def __init__( self.service = service def __str__(self): - return f"s3://{self.bucket.name}/{self.name}" + return f's3://{self.bucket.name}/{self.name}' def __repr__(self): return self.__str__() @@ -146,30 +144,30 @@ def __init__(self, keyname=None, *args, **kwargs): with self.lock: if not all(S3.s3_session): session = Session() - client = session.client("s3") + client = session.client('s3') session_params = {} - endpoint_url = os.environ.get("BOTO3_ENDPOINT_URL", None) + endpoint_url = os.environ.get('BOTO3_ENDPOINT_URL', None) if endpoint_url: - session_params["endpoint_url"] = endpoint_url + session_params['endpoint_url'] = endpoint_url - s3 = session.resource("s3", **session_params) + s3 = session.resource('s3', **session_params) S3.s3_session = (session, client, s3) (self.session, self.client, self.s3) = S3.s3_session def _bucket_name(self, bucket): - return self._clean(bucket).split("/", 1)[0] + return self._clean(bucket).split('/', 1)[0] def _clean(self, name): - if name.startswith("s3n:"): - name = "s3:" + name[4:] + if name.startswith('s3n:'): + name = 's3:' + name[4:] if self._is_s3(name): return name[5:] return name def _clean_s3(self, name): - return "s3:" + name[4:] if name.startswith("s3n:") else name + return 's3:' + name[4:] if name.startswith('s3n:') else name def _get_key(self, name): if isinstance(name, Key): @@ -180,13 +178,13 @@ def _get_key(self, name): ) def _key_name(self, name): - cleaned = self._clean(name).split("/", 1) + cleaned = self._clean(name).split('/', 1) return cleaned[1] if len(cleaned) > 1 else None @retry(3) def _list( self, - prefix="", + prefix='', bucket=None, delimiter=None, keys=False, @@ -194,55 +192,55 @@ def _list( page_size=1000, **kwargs, ): - assert bucket is not None, "You must specify a bucket to list" + assert bucket is not None, 'You must specify a bucket to list' bucket = self._bucket_name(bucket) - paginator = self.client.get_paginator("list_objects_v2") + paginator = self.client.get_paginator('list_objects_v2') operation_parameters = { - "Bucket": bucket, - "Prefix": prefix, - "PaginationConfig": {"PageSize": page_size}, + 'Bucket': bucket, + 'Prefix': prefix, + 'PaginationConfig': {'PageSize': page_size}, } if delimiter: - operation_parameters["Delimiter"] = delimiter + operation_parameters['Delimiter'] = delimiter page_iterator = paginator.paginate(**operation_parameters) def sort(item): - if "Key" in item: - return item["Key"] - return item["Prefix"] + if 'Key' in item: + return item['Key'] + return item['Prefix'] for page in page_iterator: locations = sorted( - [i for i in page.get("Contents", []) + page.get("CommonPrefixes", [])], + [i for i in page.get('Contents', []) + page.get('CommonPrefixes', [])], key=sort, ) for item in locations: if objects or keys: - if "Key" in item: + if 'Key' in item: yield Key( bucket, - item["Key"], - size=item.get("Size"), - etag=item.get("ETag"), - last_modified=item.get("LastModified"), - storage_class=item.get("StorageClass"), + item['Key'], + size=item.get('Size'), + etag=item.get('ETag'), + last_modified=item.get('LastModified'), + storage_class=item.get('StorageClass'), service=self, ) elif objects: - yield Prefix(bucket, item["Prefix"], service=self) + yield Prefix(bucket, item['Prefix'], service=self) else: - prefix = item["Key"] if "Key" in item else item["Prefix"] - yield f"s3://{bucket}/{prefix}" + prefix = item['Key'] if 'Key' in item else item['Prefix'] + yield f's3://{bucket}/{prefix}' def _put( self, source, dest, num_callbacks=10, - policy="bucket-owner-full-control", + policy='bucket-owner-full-control', **kwargs, ): key = self._get_key(dest) @@ -251,9 +249,9 @@ def _put( # support passing in open file obj. Why did we do this in the past? if not isinstance(source, str): - obj.upload_fileobj(source, ExtraArgs={"ACL": policy}) + obj.upload_fileobj(source, ExtraArgs={'ACL': policy}) else: - obj.upload_file(source, ExtraArgs={"ACL": policy}) + obj.upload_file(source, ExtraArgs={'ACL': policy}) return key def _put_string( @@ -261,14 +259,14 @@ def _put_string( source, dest, num_callbacks=10, - policy="bucket-owner-full-control", + policy='bucket-owner-full-control', **kwargs, ): key = self._get_key(dest) obj = self.s3.Object(key.bucket.name, key.name) if isinstance(source, str): - source = source.encode("utf-8") + source = source.encode('utf-8') obj.put(Body=source, ACL=policy) return key @@ -278,7 +276,7 @@ def _is_s3(self, name): return False name = self._clean_s3(name) - return "s3://" in name + return 's3://' in name def cat( self, @@ -286,7 +284,7 @@ def cat( buffersize=None, memsize=2**24, compressed=False, - encoding="UTF-8", + encoding='UTF-8', raw=False, ): """ @@ -298,17 +296,17 @@ def cat( """ assert self._is_s3(source) or isinstance( source, Key - ), "source must be a valid s3 path" + ), 'source must be a valid s3 path' key = self._get_key(source) if not isinstance(source, Key) else source - compressed = (compressed or key.name.endswith(".gz")) and not raw + compressed = (compressed or key.name.endswith('.gz')) and not raw if compressed: decompress = zlib.decompressobj(16 + zlib.MAX_WBITS) size = 0 bytes_read = 0 err = None - undecoded = "" + undecoded = '' if key: # try to read the file multiple times for i in range(100): @@ -318,7 +316,7 @@ def cat( if not size: size = obj.content_length elif size != obj.content_length: - raise AwsError("key size unexpectedly changed while reading") + raise AwsError('key size unexpectedly changed while reading') # For an empty file, 0 (first-bytes-pos) is equal to the length of the object # hence the range is "unsatisfiable", and botocore correctly handles it by @@ -326,16 +324,16 @@ def cat( if size == 0: break - r = obj.get(Range=f"bytes={bytes_read}-") + r = obj.get(Range=f'bytes={bytes_read}-') try: while bytes_read < size: # this making this weird check because this call is # about 100 times slower if the amt is too high if size - bytes_read > buffersize: - bytes = r["Body"].read(amt=buffersize) + bytes = r['Body'].read(amt=buffersize) else: - bytes = r["Body"].read() + bytes = r['Body'].read() if compressed: s = decompress.decompress(bytes) else: @@ -344,7 +342,7 @@ def cat( if encoding and not raw: try: decoded = undecoded + s.decode(encoding) - undecoded = "" + undecoded = '' yield decoded except UnicodeDecodeError: undecoded += s @@ -356,7 +354,7 @@ def cat( bytes_read += len(bytes) except zlib.error: - logger.error("Error while decompressing [%s]", key.name) + logger.error('Error while decompressing [%s]', key.name) raise except UnicodeDecodeError: raise @@ -371,7 +369,7 @@ def cat( if err: raise Exception else: - raise AwsError("Failed to fully read [%s]" % source.name) + raise AwsError('Failed to fully read [%s]' % source.name) if undecoded: assert encoding is not None # only time undecoded is set @@ -392,8 +390,8 @@ def cp_string(self, source, dest, **kwargs): the s3 location """ - assert isinstance(source, str), "source must be a string" - assert self._is_s3(dest), "Destination must be s3 location" + assert isinstance(source, str), 'source must be a string' + assert self._is_s3(dest), 'Destination must be s3 location' return self._put_string(source, dest, **kwargs) @@ -416,7 +414,7 @@ def list(self, name, iterator=False, **kwargs): if True return iterator rather than converting to list object """ - assert self._is_s3(name), "name must be in form s3://bucket/key" + assert self._is_s3(name), 'name must be in form s3://bucket/key' it = self._list( bucket=self._bucket_name(name), prefix=self._key_name(name), **kwargs @@ -442,27 +440,27 @@ def listdir(self, name, **kwargs): files or prefixes that are encountered """ - assert self._is_s3(name), "name must be in form s3://bucket/prefix/" + assert self._is_s3(name), 'name must be in form s3://bucket/prefix/' - if not name.endswith("/"): - name += "/" - return self.list(name, delimiter="/", **kwargs) + if not name.endswith('/'): + name += '/' + return self.list(name, delimiter='/', **kwargs) - def read(self, source, compressed=False, encoding="UTF-8"): + def read(self, source, compressed=False, encoding='UTF-8'): """ Iterates over a file in s3 split on newline. Yields a line in file. """ - buf = "" + buf = '' for block in self.cat(source, compressed=compressed, encoding=encoding): buf += block - if "\n" in buf: - ret, buf = buf.rsplit("\n", 1) - yield from ret.split("\n") + if '\n' in buf: + ret, buf = buf.rsplit('\n', 1) + yield from ret.split('\n') - lines = buf.split("\n") + lines = buf.split('\n') yield from lines[:-1] # only yield the last line if the line has content in it diff --git a/papermill/tests/__init__.py b/papermill/tests/__init__.py index 9843f37e..6ef2067e 100644 --- a/papermill/tests/__init__.py +++ b/papermill/tests/__init__.py @@ -1,13 +1,11 @@ import os - from io import StringIO - -kernel_name = "python3" +kernel_name = 'python3' def get_notebook_path(*args): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "notebooks", *args) + return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'notebooks', *args) def get_notebook_dir(*args): diff --git a/papermill/tests/test_abs.py b/papermill/tests/test_abs.py index 7793f4bd..5ffdb026 100644 --- a/papermill/tests/test_abs.py +++ b/papermill/tests/test_abs.py @@ -1,14 +1,15 @@ import os import unittest - from unittest.mock import Mock, patch + from azure.identity import EnvironmentCredential + from ..abs import AzureBlobStore class MockBytesIO: def __init__(self): - self.list = [b"hello", b"world!"] + self.list = [b'hello', b'world!'] def __getitem__(self, index): return self.list[index] @@ -23,7 +24,7 @@ class ABSTest(unittest.TestCase): """ def setUp(self): - self.list_blobs = Mock(return_value=["foo", "bar", "baz"]) + self.list_blobs = Mock(return_value=['foo', 'bar', 'baz']) self.upload_blob = Mock() self.download_blob = Mock() self._container_client = Mock(list_blobs=self.list_blobs) @@ -36,93 +37,93 @@ def setUp(self): ) self.abs = AzureBlobStore() self.abs._blob_service_client = Mock(return_value=self._blob_service_client) - os.environ["AZURE_TENANT_ID"] = "mytenantid" - os.environ["AZURE_CLIENT_ID"] = "myclientid" - os.environ["AZURE_CLIENT_SECRET"] = "myclientsecret" + os.environ['AZURE_TENANT_ID'] = 'mytenantid' + os.environ['AZURE_CLIENT_ID'] = 'myclientid' + os.environ['AZURE_CLIENT_SECRET'] = 'myclientsecret' def test_split_url_raises_exception_on_invalid_url(self): with self.assertRaises(Exception) as context: - AzureBlobStore._split_url("this_is_not_a_valid_url") + AzureBlobStore._split_url('this_is_not_a_valid_url') self.assertTrue( "Invalid azure blob url 'this_is_not_a_valid_url'" in str(context.exception) ) def test_split_url_splits_valid_url(self): params = AzureBlobStore._split_url( - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken" + 'abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken' ) - self.assertEqual(params["account"], "myaccount") - self.assertEqual(params["container"], "sascontainer") - self.assertEqual(params["blob"], "sasblob.txt") - self.assertEqual(params["sas_token"], "sastoken") + self.assertEqual(params['account'], 'myaccount') + self.assertEqual(params['container'], 'sascontainer') + self.assertEqual(params['blob'], 'sasblob.txt') + self.assertEqual(params['sas_token'], 'sastoken') def test_split_url_splits_valid_url_no_sas(self): params = AzureBlobStore._split_url( - "abs://myaccount.blob.core.windows.net/container/blob.txt" + 'abs://myaccount.blob.core.windows.net/container/blob.txt' ) - self.assertEqual(params["account"], "myaccount") - self.assertEqual(params["container"], "container") - self.assertEqual(params["blob"], "blob.txt") - self.assertEqual(params["sas_token"], "") + self.assertEqual(params['account'], 'myaccount') + self.assertEqual(params['container'], 'container') + self.assertEqual(params['blob'], 'blob.txt') + self.assertEqual(params['sas_token'], '') def test_split_url_splits_valid_url_with_prefix(self): params = AzureBlobStore._split_url( - "abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken" + 'abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken' ) - self.assertEqual(params["account"], "myaccount") - self.assertEqual(params["container"], "sascontainer") - self.assertEqual(params["blob"], "A/B/sasblob.txt") - self.assertEqual(params["sas_token"], "sastoken") + self.assertEqual(params['account'], 'myaccount') + self.assertEqual(params['container'], 'sascontainer') + self.assertEqual(params['blob'], 'A/B/sasblob.txt') + self.assertEqual(params['sas_token'], 'sastoken') def test_listdir_calls(self): self.assertEqual( self.abs.listdir( - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken" + 'abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken' ), - ["foo", "bar", "baz"], + ['foo', 'bar', 'baz'], ) self._blob_service_client.get_container_client.assert_called_once_with( - "sascontainer" + 'sascontainer' ) - self.list_blobs.assert_called_once_with("sasblob.txt") + self.list_blobs.assert_called_once_with('sasblob.txt') - @patch("papermill.abs.io.BytesIO", side_effect=MockBytesIO) + @patch('papermill.abs.io.BytesIO', side_effect=MockBytesIO) def test_reads_file(self, mockBytesIO): self.assertEqual( self.abs.read( - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken" + 'abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken' ), - ["hello", "world!"], + ['hello', 'world!'], ) self._blob_service_client.get_blob_client.assert_called_once_with( - "sascontainer", "sasblob.txt" + 'sascontainer', 'sasblob.txt' ) self.download_blob.assert_called_once_with() def test_write_file(self): self.abs.write( - "hello world", - "abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken", + 'hello world', + 'abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken', ) self._blob_service_client.get_blob_client.assert_called_once_with( - "sascontainer", "sasblob.txt" + 'sascontainer', 'sasblob.txt' ) - self.upload_blob.assert_called_once_with(data="hello world", overwrite=True) + self.upload_blob.assert_called_once_with(data='hello world', overwrite=True) def test_blob_service_client(self): abs = AzureBlobStore() - blob = abs._blob_service_client(account_name="myaccount", sas_token="sastoken") - self.assertEqual(blob.account_name, "myaccount") + blob = abs._blob_service_client(account_name='myaccount', sas_token='sastoken') + self.assertEqual(blob.account_name, 'myaccount') # Credentials gets funky with v12.0.0, so I comment this out # self.assertEqual(blob.credential, "sastoken") def test_blob_service_client_environment_credentials(self): abs = AzureBlobStore() - blob = abs._blob_service_client(account_name="myaccount", sas_token="") - self.assertEqual(blob.account_name, "myaccount") + blob = abs._blob_service_client(account_name='myaccount', sas_token='') + self.assertEqual(blob.account_name, 'myaccount') self.assertIsInstance(blob.credential, EnvironmentCredential) - self.assertEqual(blob.credential._credential._tenant_id, "mytenantid") - self.assertEqual(blob.credential._credential._client_id, "myclientid") + self.assertEqual(blob.credential._credential._tenant_id, 'mytenantid') + self.assertEqual(blob.credential._credential._client_id, 'myclientid') self.assertEqual( - blob.credential._credential._client_credential, "myclientsecret" + blob.credential._credential._client_credential, 'myclientsecret' ) diff --git a/papermill/tests/test_adl.py b/papermill/tests/test_adl.py index 6db76be3..3c32f5d4 100644 --- a/papermill/tests/test_adl.py +++ b/papermill/tests/test_adl.py @@ -1,8 +1,9 @@ import unittest +from unittest.mock import MagicMock, Mock, patch -from unittest.mock import Mock, MagicMock, patch - -from ..adl import ADL, core as adl_core, lib as adl_lib +from ..adl import ADL +from ..adl import core as adl_core +from ..adl import lib as adl_lib class ADLTest(unittest.TestCase): @@ -13,13 +14,13 @@ class ADLTest(unittest.TestCase): def setUp(self): self.ls = Mock( return_value=[ - "path/to/directory/foo", - "path/to/directory/bar", - "path/to/directory/baz", + 'path/to/directory/foo', + 'path/to/directory/bar', + 'path/to/directory/baz', ] ) self.fakeFile = MagicMock() - self.fakeFile.__iter__.return_value = [b"a", b"b", b"c"] + self.fakeFile.__iter__.return_value = [b'a', b'b', b'c'] self.fakeFile.__enter__.return_value = self.fakeFile self.open = Mock(return_value=self.fakeFile) self.fakeAdapter = Mock(open=self.open, ls=self.ls) @@ -28,49 +29,49 @@ def setUp(self): def test_split_url_raises_exception_on_invalid_url(self): with self.assertRaises(Exception) as context: - ADL._split_url("this_is_not_a_valid_url") + ADL._split_url('this_is_not_a_valid_url') self.assertTrue( "Invalid ADL url 'this_is_not_a_valid_url'" in str(context.exception) ) def test_split_url_splits_valid_url(self): - (store_name, path) = ADL._split_url("adl://foo.azuredatalakestore.net/bar/baz") - self.assertEqual(store_name, "foo") - self.assertEqual(path, "bar/baz") + (store_name, path) = ADL._split_url('adl://foo.azuredatalakestore.net/bar/baz') + self.assertEqual(store_name, 'foo') + self.assertEqual(path, 'bar/baz') def test_listdir_calls_ls_on_adl_adapter(self): self.assertEqual( self.adl.listdir( - "adl://foo_store.azuredatalakestore.net/path/to/directory" + 'adl://foo_store.azuredatalakestore.net/path/to/directory' ), [ - "adl://foo_store.azuredatalakestore.net/path/to/directory/foo", - "adl://foo_store.azuredatalakestore.net/path/to/directory/bar", - "adl://foo_store.azuredatalakestore.net/path/to/directory/baz", + 'adl://foo_store.azuredatalakestore.net/path/to/directory/foo', + 'adl://foo_store.azuredatalakestore.net/path/to/directory/bar', + 'adl://foo_store.azuredatalakestore.net/path/to/directory/baz', ], ) - self.ls.assert_called_once_with("path/to/directory") + self.ls.assert_called_once_with('path/to/directory') def test_read_opens_and_reads_file(self): self.assertEqual( - self.adl.read("adl://foo_store.azuredatalakestore.net/path/to/file"), - ["a", "b", "c"], + self.adl.read('adl://foo_store.azuredatalakestore.net/path/to/file'), + ['a', 'b', 'c'], ) self.fakeFile.__iter__.assert_called_once_with() def test_write_opens_file_and_writes_to_it(self): self.adl.write( - "hello world", "adl://foo_store.azuredatalakestore.net/path/to/file" + 'hello world', 'adl://foo_store.azuredatalakestore.net/path/to/file' ) - self.fakeFile.write.assert_called_once_with(b"hello world") + self.fakeFile.write.assert_called_once_with(b'hello world') - @patch.object(adl_lib, "auth", return_value="my_token") - @patch.object(adl_core, "AzureDLFileSystem", return_value="my_adapter") + @patch.object(adl_lib, 'auth', return_value='my_token') + @patch.object(adl_core, 'AzureDLFileSystem', return_value='my_adapter') def test_create_adapter(self, azure_dl_filesystem_mock, auth_mock): sut = ADL() - actual = sut._create_adapter("my_store_name") - assert actual == "my_adapter" + actual = sut._create_adapter('my_store_name') + assert actual == 'my_adapter' auth_mock.assert_called_once_with() azure_dl_filesystem_mock.assert_called_once_with( - "my_token", store_name="my_store_name" + 'my_token', store_name='my_store_name' ) diff --git a/papermill/tests/test_autosave.py b/papermill/tests/test_autosave.py index b234c29a..c2d77420 100644 --- a/papermill/tests/test_autosave.py +++ b/papermill/tests/test_autosave.py @@ -1,28 +1,28 @@ -import nbformat import os import tempfile import time import unittest from unittest.mock import patch -from . import get_notebook_path +import nbformat from .. import engines from ..engines import NotebookExecutionManager from ..execute import execute_notebook +from . import get_notebook_path class TestMidCellAutosave(unittest.TestCase): def setUp(self): - self.notebook_name = "test_autosave.ipynb" + self.notebook_name = 'test_autosave.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = nbformat.read(self.notebook_path, as_version=4) def test_autosave_not_too_fast(self): nb_man = NotebookExecutionManager( - self.nb, output_path="test.ipynb", autosave_cell_every=0.5 + self.nb, output_path='test.ipynb', autosave_cell_every=0.5 ) - with patch.object(engines, "write_ipynb") as write_mock: + with patch.object(engines, 'write_ipynb') as write_mock: write_mock.reset_mock() assert write_mock.call_count == 0 # check that the mock is sane nb_man.autosave_cell() # First call to autosave shouldn't trigger save @@ -35,9 +35,9 @@ def test_autosave_not_too_fast(self): def test_autosave_disable(self): nb_man = NotebookExecutionManager( - self.nb, output_path="test.ipynb", autosave_cell_every=0 + self.nb, output_path='test.ipynb', autosave_cell_every=0 ) - with patch.object(engines, "write_ipynb") as write_mock: + with patch.object(engines, 'write_ipynb') as write_mock: write_mock.reset_mock() assert write_mock.call_count == 0 # check that the mock is sane nb_man.autosave_cell() # First call to autosave shouldn't trigger save @@ -52,17 +52,17 @@ def test_autosave_disable(self): def test_end2end_autosave_slow_notebook(self): test_dir = tempfile.mkdtemp() - nb_test_executed_fname = os.path.join(test_dir, f"output_{self.notebook_name}") + nb_test_executed_fname = os.path.join(test_dir, f'output_{self.notebook_name}') # Count how many times it writes the file w/o autosave - with patch.object(engines, "write_ipynb") as write_mock: + with patch.object(engines, 'write_ipynb') as write_mock: execute_notebook( self.notebook_path, nb_test_executed_fname, autosave_cell_every=0 ) default_write_count = write_mock.call_count # Turn on autosave and see how many more times it gets saved. - with patch.object(engines, "write_ipynb") as write_mock: + with patch.object(engines, 'write_ipynb') as write_mock: execute_notebook( self.notebook_path, nb_test_executed_fname, autosave_cell_every=1 ) diff --git a/papermill/tests/test_cli.py b/papermill/tests/test_cli.py index 7381fd24..7d1f6e11 100755 --- a/papermill/tests/test_cli.py +++ b/papermill/tests/test_cli.py @@ -2,35 +2,34 @@ """ Test the command line interface """ import os -from pathlib import Path -import sys import subprocess +import sys import tempfile -import uuid -import nbclient - -import nbformat import unittest +import uuid +from pathlib import Path from unittest.mock import patch +import nbclient +import nbformat import pytest from click.testing import CliRunner -from . import get_notebook_path, kernel_name from .. import cli -from ..cli import papermill, _is_int, _is_float, _resolve_type +from ..cli import _is_float, _is_int, _resolve_type, papermill +from . import get_notebook_path, kernel_name @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("True", True), - ("False", False), - ("None", None), - ("12.51", 12.51), - ("10", 10), - ("hello world", "hello world"), - ("😍", "😍"), + ('True', True), + ('False', False), + ('None', None), + ('12.51', 12.51), + ('10', 10), + ('hello world', 'hello world'), + ('😍', '😍'), ], ) def test_resolve_type(test_input, expected): @@ -38,17 +37,17 @@ def test_resolve_type(test_input, expected): @pytest.mark.parametrize( - "value,expected", + 'value,expected', [ (13.71, True), - ("False", False), - ("None", False), + ('False', False), + ('None', False), (-8.2, True), (10, True), - ("10", True), - ("12.31", True), - ("hello world", False), - ("😍", False), + ('10', True), + ('12.31', True), + ('hello world', False), + ('😍', False), ], ) def test_is_float(value, expected): @@ -56,17 +55,17 @@ def test_is_float(value, expected): @pytest.mark.parametrize( - "value,expected", + 'value,expected', [ (13.71, True), - ("False", False), - ("None", False), + ('False', False), + ('None', False), (-8.2, True), - ("-23.2", False), + ('-23.2', False), (10, True), - ("13", True), - ("hello world", False), - ("😍", False), + ('13', True), + ('hello world', False), + ('😍', False), ], ) def test_is_int(value, expected): @@ -75,8 +74,8 @@ def test_is_int(value, expected): class TestCLI(unittest.TestCase): default_execute_kwargs = dict( - input_path="input.ipynb", - output_path="output.ipynb", + input_path='input.ipynb', + output_path='output.ipynb', parameters={}, engine_name=None, request_save_on_cell_execute=True, @@ -97,14 +96,14 @@ class TestCLI(unittest.TestCase): def setUp(self): self.runner = CliRunner() self.default_args = [ - self.default_execute_kwargs["input_path"], - self.default_execute_kwargs["output_path"], + self.default_execute_kwargs['input_path'], + self.default_execute_kwargs['output_path'], ] self.sample_yaml_file = os.path.join( - os.path.dirname(__file__), "parameters", "example.yaml" + os.path.dirname(__file__), 'parameters', 'example.yaml' ) self.sample_json_file = os.path.join( - os.path.dirname(__file__), "parameters", "example.json" + os.path.dirname(__file__), 'parameters', 'example.json' ) def augment_execute_kwargs(self, **new_kwargs): @@ -112,32 +111,32 @@ def augment_execute_kwargs(self, **new_kwargs): kwargs.update(new_kwargs) return kwargs - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters(self, execute_patch): self.runner.invoke( papermill, - self.default_args + ["-p", "foo", "bar", "--parameters", "baz", "42"], + self.default_args + ['-p', 'foo', 'bar', '--parameters', 'baz', '42'], ) execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": "bar", "baz": 42}) + **self.augment_execute_kwargs(parameters={'foo': 'bar', 'baz': 42}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_raw(self, execute_patch): self.runner.invoke( papermill, - self.default_args + ["-r", "foo", "bar", "--parameters_raw", "baz", "42"], + self.default_args + ['-r', 'foo', 'bar', '--parameters_raw', 'baz', '42'], ) execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": "bar", "baz": "42"}) + **self.augment_execute_kwargs(parameters={'foo': 'bar', 'baz': '42'}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_file(self, execute_patch): extra_args = [ - "-f", + '-f', self.sample_yaml_file, - "--parameters_file", + '--parameters_file', self.sample_json_file, ] self.runner.invoke(papermill, self.default_args + extra_args) @@ -145,45 +144,45 @@ def test_parameters_file(self, execute_patch): **self.augment_execute_kwargs( # Last input wins dict update parameters={ - "foo": 54321, - "bar": "value", - "baz": {"k2": "v2", "k1": "v1"}, - "a_date": "2019-01-01", + 'foo': 54321, + 'bar': 'value', + 'baz': {'k2': 'v2', 'k1': 'v1'}, + 'a_date': '2019-01-01', } ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_yaml(self, execute_patch): self.runner.invoke( papermill, self.default_args - + ["-y", '{"foo": "bar"}', "--parameters_yaml", '{"foo2": ["baz"]}'], + + ['-y', '{"foo": "bar"}', '--parameters_yaml', '{"foo2": ["baz"]}'], ) execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": "bar", "foo2": ["baz"]}) + **self.augment_execute_kwargs(parameters={'foo': 'bar', 'foo2': ['baz']}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_yaml_date(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["-y", "a_date: 2019-01-01"]) + self.runner.invoke(papermill, self.default_args + ['-y', 'a_date: 2019-01-01']) execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"a_date": "2019-01-01"}) + **self.augment_execute_kwargs(parameters={'a_date': '2019-01-01'}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_empty(self, execute_patch): # "#empty" ---base64--> "I2VtcHR5" with tempfile.TemporaryDirectory() as tmpdir: - empty_yaml = Path(tmpdir) / "empty.yaml" - empty_yaml.write_text("#empty") + empty_yaml = Path(tmpdir) / 'empty.yaml' + empty_yaml.write_text('#empty') extra_args = [ - "--parameters_file", + '--parameters_file', str(empty_yaml), - "--parameters_yaml", - "#empty", - "--parameters_base64", - "I2VtcHR5", + '--parameters_yaml', + '#empty', + '--parameters_base64', + 'I2VtcHR5', ] self.runner.invoke( papermill, @@ -196,139 +195,139 @@ def test_parameters_empty(self, execute_patch): ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_yaml_override(self, execute_patch): self.runner.invoke( papermill, self.default_args - + ["--parameters_yaml", '{"foo": "bar"}', "-y", '{"foo": ["baz"]}'], + + ['--parameters_yaml', '{"foo": "bar"}', '-y', '{"foo": ["baz"]}'], ) execute_patch.assert_called_with( **self.augment_execute_kwargs( # Last input wins dict update - parameters={"foo": ["baz"]} + parameters={'foo': ['baz']} ) ) @patch( - cli.__name__ + ".execute_notebook", - side_effect=nbclient.exceptions.DeadKernelError("Fake"), + cli.__name__ + '.execute_notebook', + side_effect=nbclient.exceptions.DeadKernelError('Fake'), ) def test_parameters_dead_kernel(self, execute_patch): result = self.runner.invoke( papermill, self.default_args - + ["--parameters_yaml", '{"foo": "bar"}', "-y", '{"foo": ["baz"]}'], + + ['--parameters_yaml', '{"foo": "bar"}', '-y', '{"foo": ["baz"]}'], ) assert result.exit_code == 138 - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_base64(self, execute_patch): extra_args = [ - "--parameters_base64", - "eyJmb28iOiAicmVwbGFjZWQiLCAiYmFyIjogMn0=", - "-b", - "eydmb28nOiAxfQ==", + '--parameters_base64', + 'eyJmb28iOiAicmVwbGFjZWQiLCAiYmFyIjogMn0=', + '-b', + 'eydmb28nOiAxfQ==', ] self.runner.invoke(papermill, self.default_args + extra_args) execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"foo": 1, "bar": 2}) + **self.augment_execute_kwargs(parameters={'foo': 1, 'bar': 2}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_parameters_base64_date(self, execute_patch): self.runner.invoke( papermill, - self.default_args + ["--parameters_base64", "YV9kYXRlOiAyMDE5LTAxLTAx"], + self.default_args + ['--parameters_base64', 'YV9kYXRlOiAyMDE5LTAxLTAx'], ) execute_patch.assert_called_with( - **self.augment_execute_kwargs(parameters={"a_date": "2019-01-01"}) + **self.augment_execute_kwargs(parameters={'a_date': '2019-01-01'}) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_inject_input_path(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--inject-input-path"]) + self.runner.invoke(papermill, self.default_args + ['--inject-input-path']) execute_patch.assert_called_with( **self.augment_execute_kwargs( - parameters={"PAPERMILL_INPUT_PATH": "input.ipynb"} + parameters={'PAPERMILL_INPUT_PATH': 'input.ipynb'} ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_inject_output_path(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--inject-output-path"]) + self.runner.invoke(papermill, self.default_args + ['--inject-output-path']) execute_patch.assert_called_with( **self.augment_execute_kwargs( - parameters={"PAPERMILL_OUTPUT_PATH": "output.ipynb"} + parameters={'PAPERMILL_OUTPUT_PATH': 'output.ipynb'} ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_inject_paths(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--inject-paths"]) + self.runner.invoke(papermill, self.default_args + ['--inject-paths']) execute_patch.assert_called_with( **self.augment_execute_kwargs( parameters={ - "PAPERMILL_INPUT_PATH": "input.ipynb", - "PAPERMILL_OUTPUT_PATH": "output.ipynb", + 'PAPERMILL_INPUT_PATH': 'input.ipynb', + 'PAPERMILL_OUTPUT_PATH': 'output.ipynb', } ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_engine(self, execute_patch): self.runner.invoke( - papermill, self.default_args + ["--engine", "engine-that-could"] + papermill, self.default_args + ['--engine', 'engine-that-could'] ) execute_patch.assert_called_with( - **self.augment_execute_kwargs(engine_name="engine-that-could") + **self.augment_execute_kwargs(engine_name='engine-that-could') ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_prepare_only(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--prepare-only"]) + self.runner.invoke(papermill, self.default_args + ['--prepare-only']) execute_patch.assert_called_with( **self.augment_execute_kwargs(prepare_only=True) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_kernel(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["-k", "python3"]) + self.runner.invoke(papermill, self.default_args + ['-k', 'python3']) execute_patch.assert_called_with( - **self.augment_execute_kwargs(kernel_name="python3") + **self.augment_execute_kwargs(kernel_name='python3') ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_language(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["-l", "python"]) + self.runner.invoke(papermill, self.default_args + ['-l', 'python']) execute_patch.assert_called_with( - **self.augment_execute_kwargs(language="python") + **self.augment_execute_kwargs(language='python') ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_set_cwd(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--cwd", "a/path/here"]) + self.runner.invoke(papermill, self.default_args + ['--cwd', 'a/path/here']) execute_patch.assert_called_with( - **self.augment_execute_kwargs(cwd="a/path/here") + **self.augment_execute_kwargs(cwd='a/path/here') ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_progress_bar(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--progress-bar"]) + self.runner.invoke(papermill, self.default_args + ['--progress-bar']) execute_patch.assert_called_with( **self.augment_execute_kwargs(progress_bar=True) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_no_progress_bar(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--no-progress-bar"]) + self.runner.invoke(papermill, self.default_args + ['--no-progress-bar']) execute_patch.assert_called_with( **self.augment_execute_kwargs(progress_bar=False) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_log_output(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--log-output"]) + self.runner.invoke(papermill, self.default_args + ['--log-output']) execute_patch.assert_called_with( **self.augment_execute_kwargs( log_output=True, @@ -336,107 +335,107 @@ def test_log_output(self, execute_patch): ) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_log_output_plus_progress(self, execute_patch): self.runner.invoke( - papermill, self.default_args + ["--log-output", "--progress-bar"] + papermill, self.default_args + ['--log-output', '--progress-bar'] ) execute_patch.assert_called_with( **self.augment_execute_kwargs(log_output=True, progress_bar=True) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_no_log_output(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--no-log-output"]) + self.runner.invoke(papermill, self.default_args + ['--no-log-output']) execute_patch.assert_called_with( **self.augment_execute_kwargs(log_output=False) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_log_level(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--log-level", "WARNING"]) + self.runner.invoke(papermill, self.default_args + ['--log-level', 'WARNING']) # TODO: this does not actually test log-level being set execute_patch.assert_called_with(**self.augment_execute_kwargs()) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_start_timeout(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--start-timeout", "123"]) + self.runner.invoke(papermill, self.default_args + ['--start-timeout', '123']) execute_patch.assert_called_with( **self.augment_execute_kwargs(start_timeout=123) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_start_timeout_backwards_compatibility(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--start_timeout", "123"]) + self.runner.invoke(papermill, self.default_args + ['--start_timeout', '123']) execute_patch.assert_called_with( **self.augment_execute_kwargs(start_timeout=123) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_execution_timeout(self, execute_patch): self.runner.invoke( - papermill, self.default_args + ["--execution-timeout", "123"] + papermill, self.default_args + ['--execution-timeout', '123'] ) execute_patch.assert_called_with( **self.augment_execute_kwargs(execution_timeout=123) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_report_mode(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--report-mode"]) + self.runner.invoke(papermill, self.default_args + ['--report-mode']) execute_patch.assert_called_with( **self.augment_execute_kwargs(report_mode=True) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_no_report_mode(self, execute_patch): - self.runner.invoke(papermill, self.default_args + ["--no-report-mode"]) + self.runner.invoke(papermill, self.default_args + ['--no-report-mode']) execute_patch.assert_called_with( **self.augment_execute_kwargs(report_mode=False) ) - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_version(self, execute_patch): - self.runner.invoke(papermill, ["--version"]) + self.runner.invoke(papermill, ['--version']) execute_patch.assert_not_called() - @patch(cli.__name__ + ".execute_notebook") - @patch(cli.__name__ + ".display_notebook_help") + @patch(cli.__name__ + '.execute_notebook') + @patch(cli.__name__ + '.display_notebook_help') def test_help_notebook(self, display_notebook_help, execute_path): - self.runner.invoke(papermill, ["--help-notebook", "input_path.ipynb"]) + self.runner.invoke(papermill, ['--help-notebook', 'input_path.ipynb']) execute_path.assert_not_called() assert display_notebook_help.call_count == 1 - assert display_notebook_help.call_args[0][1] == "input_path.ipynb" + assert display_notebook_help.call_args[0][1] == 'input_path.ipynb' - @patch(cli.__name__ + ".execute_notebook") + @patch(cli.__name__ + '.execute_notebook') def test_many_args(self, execute_patch): extra_args = [ - "-f", + '-f', self.sample_yaml_file, - "-y", + '-y', '{"yaml_foo": {"yaml_bar": "yaml_baz"}}', - "-b", - "eyJiYXNlNjRfZm9vIjogImJhc2U2NF9iYXIifQ==", - "-p", - "baz", - "replace", - "-r", - "foo", - "54321", - "--kernel", - "R", - "--engine", - "engine-that-could", - "--prepare-only", - "--log-output", - "--autosave-cell-every", - "17", - "--no-progress-bar", - "--start-timeout", - "321", - "--execution-timeout", - "654", - "--report-mode", + '-b', + 'eyJiYXNlNjRfZm9vIjogImJhc2U2NF9iYXIifQ==', + '-p', + 'baz', + 'replace', + '-r', + 'foo', + '54321', + '--kernel', + 'R', + '--engine', + 'engine-that-could', + '--prepare-only', + '--log-output', + '--autosave-cell-every', + '17', + '--no-progress-bar', + '--start-timeout', + '321', + '--execution-timeout', + '654', + '--report-mode', ] self.runner.invoke( papermill, @@ -445,18 +444,18 @@ def test_many_args(self, execute_patch): execute_patch.assert_called_with( **self.augment_execute_kwargs( parameters={ - "foo": "54321", - "bar": "value", - "baz": "replace", - "yaml_foo": {"yaml_bar": "yaml_baz"}, - "base64_foo": "base64_bar", - "a_date": "2019-01-01", + 'foo': '54321', + 'bar': 'value', + 'baz': 'replace', + 'yaml_foo': {'yaml_bar': 'yaml_baz'}, + 'base64_foo': 'base64_bar', + 'a_date': '2019-01-01', }, - engine_name="engine-that-could", + engine_name='engine-that-could', request_save_on_cell_execute=True, autosave_cell_every=17, prepare_only=True, - kernel_name="R", + kernel_name='R', log_output=True, progress_bar=False, start_timeout=321, @@ -468,7 +467,7 @@ def test_many_args(self, execute_patch): def papermill_cli(papermill_args=None, **kwargs): - cmd = [sys.executable, "-m", "papermill"] + cmd = [sys.executable, '-m', 'papermill'] if papermill_args: cmd.extend(papermill_args) return subprocess.Popen(cmd, **kwargs) @@ -476,11 +475,11 @@ def papermill_cli(papermill_args=None, **kwargs): def papermill_version(): try: - proc = papermill_cli(["--version"], stdout=subprocess.PIPE) + proc = papermill_cli(['--version'], stdout=subprocess.PIPE) out, _ = proc.communicate() if proc.returncode: return None - return out.decode("utf-8") + return out.decode('utf-8') except (OSError, SystemExit): # pragma: no cover return None @@ -488,22 +487,22 @@ def papermill_version(): @pytest.fixture() def notebook(): metadata = { - "kernelspec": { - "name": "python3", - "language": "python", - "display_name": "python3", + 'kernelspec': { + 'name': 'python3', + 'language': 'python', + 'display_name': 'python3', } } return nbformat.v4.new_notebook( metadata=metadata, cells=[ - nbformat.v4.new_markdown_cell("This is a notebook with kernel: python3") + nbformat.v4.new_markdown_cell('This is a notebook with kernel: python3') ], ) require_papermill_installed = pytest.mark.skipif( - not papermill_version(), reason="papermill is not installed" + not papermill_version(), reason='papermill is not installed' ) @@ -511,31 +510,31 @@ def notebook(): def test_pipe_in_out_auto(notebook): process = papermill_cli(stdout=subprocess.PIPE, stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, err = process.communicate(input=text.encode("utf-8")) + out, err = process.communicate(input=text.encode('utf-8')) # Test no message on std error assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_in_out_explicit(notebook): - process = papermill_cli(["-", "-"], stdout=subprocess.PIPE, stdin=subprocess.PIPE) + process = papermill_cli(['-', '-'], stdout=subprocess.PIPE, stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, err = process.communicate(input=text.encode("utf-8")) + out, err = process.communicate(input=text.encode('utf-8')) # Test no message on std error assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_out_auto(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') nb_file.write(nbformat.writes(notebook)) process = papermill_cli([str(nb_file)], stdout=subprocess.PIPE) @@ -545,31 +544,31 @@ def test_pipe_out_auto(tmpdir, notebook): assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_out_explicit(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') nb_file.write(nbformat.writes(notebook)) - process = papermill_cli([str(nb_file), "-"], stdout=subprocess.PIPE) + process = papermill_cli([str(nb_file), '-'], stdout=subprocess.PIPE) out, err = process.communicate() # Test no message on std error assert not err # Test that output is a valid notebook - nbformat.reads(out.decode("utf-8"), as_version=4) + nbformat.reads(out.decode('utf-8'), as_version=4) @require_papermill_installed def test_pipe_in_auto(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') process = papermill_cli([str(nb_file)], stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, _ = process.communicate(input=text.encode("utf-8")) + out, _ = process.communicate(input=text.encode('utf-8')) # Nothing on stdout assert not out @@ -581,11 +580,11 @@ def test_pipe_in_auto(tmpdir, notebook): @require_papermill_installed def test_pipe_in_explicit(tmpdir, notebook): - nb_file = tmpdir.join("notebook.ipynb") + nb_file = tmpdir.join('notebook.ipynb') - process = papermill_cli(["-", str(nb_file)], stdin=subprocess.PIPE) + process = papermill_cli(['-', str(nb_file)], stdin=subprocess.PIPE) text = nbformat.writes(notebook) - out, _ = process.communicate(input=text.encode("utf-8")) + out, _ = process.communicate(input=text.encode('utf-8')) # Nothing on stdout assert not out @@ -597,20 +596,20 @@ def test_pipe_in_explicit(tmpdir, notebook): @require_papermill_installed def test_stdout_file(tmpdir): - nb_file = tmpdir.join("notebook.ipynb") - stdout_file = tmpdir.join("notebook.stdout") + nb_file = tmpdir.join('notebook.ipynb') + stdout_file = tmpdir.join('notebook.stdout') secret = str(uuid.uuid4()) process = papermill_cli( [ - get_notebook_path("simple_execute.ipynb"), + get_notebook_path('simple_execute.ipynb'), str(nb_file), - "-k", + '-k', kernel_name, - "-p", - "msg", + '-p', + 'msg', secret, - "--stdout-file", + '--stdout-file', str(stdout_file), ] ) @@ -620,4 +619,4 @@ def test_stdout_file(tmpdir): assert not err with open(str(stdout_file)) as fp: - assert fp.read() == secret + "\n" + assert fp.read() == secret + '\n' diff --git a/papermill/tests/test_clientwrap.py b/papermill/tests/test_clientwrap.py index deeb29a1..32309cf6 100644 --- a/papermill/tests/test_clientwrap.py +++ b/papermill/tests/test_clientwrap.py @@ -1,40 +1,39 @@ -import nbformat import unittest - from unittest.mock import call, patch -from . import get_notebook_path +import nbformat -from ..log import logger -from ..engines import NotebookExecutionManager from ..clientwrap import PapermillNotebookClient +from ..engines import NotebookExecutionManager +from ..log import logger +from . import get_notebook_path class TestPapermillClientWrapper(unittest.TestCase): def setUp(self): - self.nb = nbformat.read(get_notebook_path("test_logging.ipynb"), as_version=4) + self.nb = nbformat.read(get_notebook_path('test_logging.ipynb'), as_version=4) self.nb_man = NotebookExecutionManager(self.nb) self.client = PapermillNotebookClient(self.nb_man, log=logger, log_output=True) def test_logging_stderr_msg(self): - with patch.object(logger, "warning") as warning_mock: - for output in self.nb.cells[0].get("outputs", []): + with patch.object(logger, 'warning') as warning_mock: + for output in self.nb.cells[0].get('outputs', []): self.client.log_output_message(output) - warning_mock.assert_called_once_with("INFO:test:test text\n") + warning_mock.assert_called_once_with('INFO:test:test text\n') def test_logging_stdout_msg(self): - with patch.object(logger, "info") as info_mock: - for output in self.nb.cells[1].get("outputs", []): + with patch.object(logger, 'info') as info_mock: + for output in self.nb.cells[1].get('outputs', []): self.client.log_output_message(output) - info_mock.assert_called_once_with("hello world\n") + info_mock.assert_called_once_with('hello world\n') def test_logging_data_msg(self): - with patch.object(logger, "info") as info_mock: - for output in self.nb.cells[2].get("outputs", []): + with patch.object(logger, 'info') as info_mock: + for output in self.nb.cells[2].get('outputs', []): self.client.log_output_message(output) info_mock.assert_has_calls( [ - call(""), - call(""), + call(''), + call(''), ] ) diff --git a/papermill/tests/test_engines.py b/papermill/tests/test_engines.py index e635a6f9..ec25b376 100644 --- a/papermill/tests/test_engines.py +++ b/papermill/tests/test_engines.py @@ -1,17 +1,16 @@ import copy -import dateutil import unittest - from abc import ABCMeta -from unittest.mock import Mock, patch, call -from nbformat.notebooknode import NotebookNode +from unittest.mock import Mock, call, patch -from . import get_notebook_path +import dateutil +from nbformat.notebooknode import NotebookNode from .. import engines, exceptions -from ..log import logger +from ..engines import Engine, NBClientEngine, NotebookExecutionManager from ..iorw import load_notebook_node -from ..engines import NotebookExecutionManager, Engine, NBClientEngine +from ..log import logger +from . import get_notebook_path def AnyMock(cls): @@ -30,11 +29,11 @@ def __eq__(self, other): class TestNotebookExecutionManager(unittest.TestCase): def setUp(self): - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = load_notebook_node(self.notebook_path) self.foo_nb = copy.deepcopy(self.nb) - self.foo_nb.metadata["foo"] = "bar" + self.foo_nb.metadata['foo'] = 'bar' def test_basic_pbar(self): nb_man = NotebookExecutionManager(self.nb) @@ -51,73 +50,73 @@ def test_set_timer(self): nb_man = NotebookExecutionManager(self.nb) now = nb_man.now() - with patch.object(nb_man, "now", return_value=now): + with patch.object(nb_man, 'now', return_value=now): nb_man.set_timer() self.assertEqual(nb_man.start_time, now) self.assertIsNone(nb_man.end_time) def test_save(self): - nb_man = NotebookExecutionManager(self.nb, output_path="test.ipynb") - with patch.object(engines, "write_ipynb") as write_mock: + nb_man = NotebookExecutionManager(self.nb, output_path='test.ipynb') + with patch.object(engines, 'write_ipynb') as write_mock: nb_man.save() - write_mock.assert_called_with(self.nb, "test.ipynb") + write_mock.assert_called_with(self.nb, 'test.ipynb') def test_save_no_output(self): nb_man = NotebookExecutionManager(self.nb) - with patch.object(engines, "write_ipynb") as write_mock: + with patch.object(engines, 'write_ipynb') as write_mock: nb_man.save() write_mock.assert_not_called() def test_save_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.save(nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_get_cell_description(self): nb_man = NotebookExecutionManager(self.nb) self.assertIsNone(nb_man.get_cell_description(nb_man.nb.cells[0])) - self.assertEqual(nb_man.get_cell_description(nb_man.nb.cells[1]), "DESC") + self.assertEqual(nb_man.get_cell_description(nb_man.nb.cells[1]), 'DESC') def test_notebook_start(self): nb_man = NotebookExecutionManager(self.nb) - nb_man.nb.metadata["foo"] = "bar" + nb_man.nb.metadata['foo'] = 'bar' nb_man.save = Mock() nb_man.notebook_start() self.assertEqual( - nb_man.nb.metadata.papermill["start_time"], nb_man.start_time.isoformat() + nb_man.nb.metadata.papermill['start_time'], nb_man.start_time.isoformat() ) - self.assertIsNone(nb_man.nb.metadata.papermill["end_time"]) - self.assertIsNone(nb_man.nb.metadata.papermill["duration"]) - self.assertIsNone(nb_man.nb.metadata.papermill["exception"]) + self.assertIsNone(nb_man.nb.metadata.papermill['end_time']) + self.assertIsNone(nb_man.nb.metadata.papermill['duration']) + self.assertIsNone(nb_man.nb.metadata.papermill['exception']) for cell in nb_man.nb.cells: - self.assertIsNone(cell.metadata.papermill["start_time"]) - self.assertIsNone(cell.metadata.papermill["end_time"]) - self.assertIsNone(cell.metadata.papermill["duration"]) - self.assertIsNone(cell.metadata.papermill["exception"]) + self.assertIsNone(cell.metadata.papermill['start_time']) + self.assertIsNone(cell.metadata.papermill['end_time']) + self.assertIsNone(cell.metadata.papermill['duration']) + self.assertIsNone(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.PENDING + cell.metadata.papermill['status'], NotebookExecutionManager.PENDING ) - self.assertIsNone(cell.get("execution_count")) - if cell.cell_type == "code": - self.assertEqual(cell.get("outputs"), []) + self.assertIsNone(cell.get('execution_count')) + if cell.cell_type == 'code': + self.assertEqual(cell.get('outputs'), []) else: - self.assertIsNone(cell.get("outputs")) + self.assertIsNone(cell.get('outputs')) nb_man.save.assert_called_once() def test_notebook_start_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start(nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_notebook_start_markdown_code(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start(nb=self.foo_nb) - self.assertNotIn("execution_count", nb_man.nb.cells[-1]) - self.assertNotIn("outputs", nb_man.nb.cells[-1]) + self.assertNotIn('execution_count', nb_man.nb.cells[-1]) + self.assertNotIn('outputs', nb_man.nb.cells[-1]) def test_cell_start(self): nb_man = NotebookExecutionManager(self.nb) @@ -129,10 +128,10 @@ def test_cell_start(self): nb_man.save = Mock() nb_man.cell_start(cell) - self.assertEqual(cell.metadata.papermill["start_time"], fixed_now.isoformat()) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertEqual(cell.metadata.papermill['start_time'], fixed_now.isoformat()) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.RUNNING + cell.metadata.papermill['status'], NotebookExecutionManager.RUNNING ) nb_man.save.assert_called_once() @@ -140,7 +139,7 @@ def test_cell_start(self): def test_cell_start_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.cell_start(self.foo_nb.cells[0], nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_cell_exception(self): nb_man = NotebookExecutionManager(self.nb) @@ -148,16 +147,16 @@ def test_cell_exception(self): cell = nb_man.nb.cells[0] nb_man.cell_exception(cell) - self.assertEqual(nb_man.nb.metadata.papermill["exception"], True) - self.assertEqual(cell.metadata.papermill["exception"], True) + self.assertEqual(nb_man.nb.metadata.papermill['exception'], True) + self.assertEqual(cell.metadata.papermill['exception'], True) self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.FAILED + cell.metadata.papermill['status'], NotebookExecutionManager.FAILED ) def test_cell_exception_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.cell_exception(self.foo_nb.cells[0], nb=self.foo_nb) - self.assertEqual(nb_man.nb.metadata["foo"], "bar") + self.assertEqual(nb_man.nb.metadata['foo'], 'bar') def test_cell_complete_after_cell_start(self): nb_man = NotebookExecutionManager(self.nb) @@ -173,17 +172,17 @@ def test_cell_complete_after_cell_start(self): nb_man.pbar = Mock() nb_man.cell_complete(cell) - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + start_time = dateutil.parser.parse(cell.metadata.papermill['start_time']) - self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat()) + self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat()) self.assertEqual( - cell.metadata.papermill["duration"], + cell.metadata.papermill['duration'], (fixed_now - start_time).total_seconds(), ) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED ) nb_man.save.assert_called_once() @@ -202,11 +201,11 @@ def test_cell_complete_without_cell_start(self): nb_man.pbar = Mock() nb_man.cell_complete(cell) - self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat()) - self.assertIsNone(cell.metadata.papermill["duration"]) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat()) + self.assertIsNone(cell.metadata.papermill['duration']) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED ) nb_man.save.assert_called_once() @@ -227,17 +226,17 @@ def test_cell_complete_after_cell_exception(self): nb_man.pbar = Mock() nb_man.cell_complete(cell) - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - start_time = dateutil.parser.parse(cell.metadata.papermill["start_time"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + start_time = dateutil.parser.parse(cell.metadata.papermill['start_time']) - self.assertEqual(cell.metadata.papermill["end_time"], fixed_now.isoformat()) + self.assertEqual(cell.metadata.papermill['end_time'], fixed_now.isoformat()) self.assertEqual( - cell.metadata.papermill["duration"], + cell.metadata.papermill['duration'], (fixed_now - start_time).total_seconds(), ) - self.assertTrue(cell.metadata.papermill["exception"]) + self.assertTrue(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.FAILED + cell.metadata.papermill['status'], NotebookExecutionManager.FAILED ) nb_man.save.assert_called_once() @@ -247,9 +246,9 @@ def test_cell_complete_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start() baz_nb = copy.deepcopy(nb_man.nb) - baz_nb.metadata["baz"] = "buz" + baz_nb.metadata['baz'] = 'buz' nb_man.cell_complete(baz_nb.cells[0], nb=baz_nb) - self.assertEqual(nb_man.nb.metadata["baz"], "buz") + self.assertEqual(nb_man.nb.metadata['baz'], 'buz') def test_notebook_complete(self): nb_man = NotebookExecutionManager(self.nb) @@ -264,17 +263,17 @@ def test_notebook_complete(self): nb_man.notebook_complete() - self.assertIsNotNone(nb_man.nb.metadata.papermill["start_time"]) - start_time = dateutil.parser.parse(nb_man.nb.metadata.papermill["start_time"]) + self.assertIsNotNone(nb_man.nb.metadata.papermill['start_time']) + start_time = dateutil.parser.parse(nb_man.nb.metadata.papermill['start_time']) self.assertEqual( - nb_man.nb.metadata.papermill["end_time"], fixed_now.isoformat() + nb_man.nb.metadata.papermill['end_time'], fixed_now.isoformat() ) self.assertEqual( - nb_man.nb.metadata.papermill["duration"], + nb_man.nb.metadata.papermill['duration'], (fixed_now - start_time).total_seconds(), ) - self.assertFalse(nb_man.nb.metadata.papermill["exception"]) + self.assertFalse(nb_man.nb.metadata.papermill['exception']) nb_man.save.assert_called_once() nb_man.cleanup_pbar.assert_called_once() @@ -283,9 +282,9 @@ def test_notebook_complete_new_nb(self): nb_man = NotebookExecutionManager(self.nb) nb_man.notebook_start() baz_nb = copy.deepcopy(nb_man.nb) - baz_nb.metadata["baz"] = "buz" + baz_nb.metadata['baz'] = 'buz' nb_man.notebook_complete(nb=baz_nb) - self.assertEqual(nb_man.nb.metadata["baz"], "buz") + self.assertEqual(nb_man.nb.metadata['baz'], 'buz') def test_notebook_complete_cell_status_completed(self): nb_man = NotebookExecutionManager(self.nb) @@ -293,7 +292,7 @@ def test_notebook_complete_cell_status_completed(self): nb_man.notebook_complete() for cell in nb_man.nb.cells: self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.COMPLETED + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED ) def test_notebook_complete_cell_status_with_failed(self): @@ -302,22 +301,22 @@ def test_notebook_complete_cell_status_with_failed(self): nb_man.cell_exception(nb_man.nb.cells[1]) nb_man.notebook_complete() self.assertEqual( - nb_man.nb.cells[0].metadata.papermill["status"], + nb_man.nb.cells[0].metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) self.assertEqual( - nb_man.nb.cells[1].metadata.papermill["status"], + nb_man.nb.cells[1].metadata.papermill['status'], NotebookExecutionManager.FAILED, ) for cell in nb_man.nb.cells[2:]: self.assertEqual( - cell.metadata.papermill["status"], NotebookExecutionManager.PENDING + cell.metadata.papermill['status'], NotebookExecutionManager.PENDING ) class TestEngineBase(unittest.TestCase): def setUp(self): - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = load_notebook_node(self.notebook_path) @@ -326,27 +325,27 @@ def test_wrap_and_execute_notebook(self): Mocks each wrapped call and proves the correct inputs get applied to the correct underlying calls for execute_notebook. """ - with patch.object(Engine, "execute_managed_notebook") as exec_mock: - with patch.object(engines, "NotebookExecutionManager") as wrap_mock: + with patch.object(Engine, 'execute_managed_notebook') as exec_mock: + with patch.object(engines, 'NotebookExecutionManager') as wrap_mock: Engine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, - bar="baz", + bar='baz', ) wrap_mock.assert_called_once_with( self.nb, - output_path="foo.ipynb", + output_path='foo.ipynb', progress_bar=False, log_output=True, autosave_cell_every=30, ) wrap_mock.return_value.notebook_start.assert_called_once() exec_mock.assert_called_once_with( - wrap_mock.return_value, "python", log_output=True, bar="baz" + wrap_mock.return_value, 'python', log_output=True, bar='baz' ) wrap_mock.return_value.notebook_complete.assert_called_once() wrap_mock.return_value.cleanup_pbar.assert_called_once() @@ -359,9 +358,9 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): nb_man.cell_start(cell) nb_man.cell_complete(cell) - with patch.object(NotebookExecutionManager, "save") as save_mock: + with patch.object(NotebookExecutionManager, 'save') as save_mock: nb = CellCallbackEngine.execute_notebook( - copy.deepcopy(self.nb), "python", output_path="foo.ipynb" + copy.deepcopy(self.nb), 'python', output_path='foo.ipynb' ) self.assertEqual(nb, AnyMock(NotebookNode)) @@ -369,18 +368,18 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): self.assertEqual(save_mock.call_count, 8) - self.assertIsNotNone(nb.metadata.papermill["start_time"]) - self.assertIsNotNone(nb.metadata.papermill["end_time"]) - self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(nb.metadata.papermill["exception"]) + self.assertIsNotNone(nb.metadata.papermill['start_time']) + self.assertIsNotNone(nb.metadata.papermill['end_time']) + self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(nb.metadata.papermill['exception']) for cell in nb.cells: - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - self.assertIsNotNone(cell.metadata.papermill["end_time"]) - self.assertEqual(cell.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + self.assertIsNotNone(cell.metadata.papermill['end_time']) + self.assertEqual(cell.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) @@ -390,12 +389,12 @@ class NoCellCallbackEngine(Engine): def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): pass - with patch.object(NotebookExecutionManager, "save") as save_mock: + with patch.object(NotebookExecutionManager, 'save') as save_mock: with patch.object( - NotebookExecutionManager, "complete_pbar" + NotebookExecutionManager, 'complete_pbar' ) as pbar_comp_mock: nb = NoCellCallbackEngine.execute_notebook( - copy.deepcopy(self.nb), "python", output_path="foo.ipynb" + copy.deepcopy(self.nb), 'python', output_path='foo.ipynb' ) self.assertEqual(nb, AnyMock(NotebookNode)) @@ -404,38 +403,38 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): self.assertEqual(save_mock.call_count, 2) pbar_comp_mock.assert_called_once() - self.assertIsNotNone(nb.metadata.papermill["start_time"]) - self.assertIsNotNone(nb.metadata.papermill["end_time"]) - self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(nb.metadata.papermill["exception"]) + self.assertIsNotNone(nb.metadata.papermill['start_time']) + self.assertIsNotNone(nb.metadata.papermill['end_time']) + self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(nb.metadata.papermill['exception']) for cell in nb.cells: - self.assertIsNone(cell.metadata.papermill["start_time"]) - self.assertIsNone(cell.metadata.papermill["end_time"]) - self.assertIsNone(cell.metadata.papermill["duration"]) - self.assertIsNone(cell.metadata.papermill["exception"]) + self.assertIsNone(cell.metadata.papermill['start_time']) + self.assertIsNone(cell.metadata.papermill['end_time']) + self.assertIsNone(cell.metadata.papermill['duration']) + self.assertIsNone(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) class TestNBClientEngine(unittest.TestCase): def setUp(self): - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb = load_notebook_node(self.notebook_path) def test_nb_convert_engine(self): - with patch.object(engines, "PapermillNotebookClient") as client_mock: - with patch.object(NotebookExecutionManager, "save") as save_mock: + with patch.object(engines, 'PapermillNotebookClient') as client_mock: + with patch.object(NotebookExecutionManager, 'save') as save_mock: nb = NBClientEngine.execute_notebook( copy.deepcopy(self.nb), - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, - bar="baz", + bar='baz', start_timeout=30, execution_timeout=1000, ) @@ -447,15 +446,15 @@ def test_nb_convert_engine(self): args, kwargs = client_mock.call_args expected = [ - ("timeout", 1000), - ("startup_timeout", 30), - ("kernel_name", "python"), - ("log", logger), - ("log_output", True), + ('timeout', 1000), + ('startup_timeout', 30), + ('kernel_name', 'python'), + ('log', logger), + ('log_output', True), ] actual = {(key, kwargs[key]) for key in kwargs} msg = ( - f"Expected arguments {expected} are not a subset of actual {actual}" + f'Expected arguments {expected} are not a subset of actual {actual}' ) self.assertTrue(set(expected).issubset(actual), msg=msg) @@ -464,71 +463,71 @@ def test_nb_convert_engine(self): self.assertEqual(save_mock.call_count, 2) def test_nb_convert_engine_execute(self): - with patch.object(NotebookExecutionManager, "save") as save_mock: + with patch.object(NotebookExecutionManager, 'save') as save_mock: nb = NBClientEngine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, ) self.assertEqual(save_mock.call_count, 8) self.assertEqual(nb, AnyMock(NotebookNode)) - self.assertIsNotNone(nb.metadata.papermill["start_time"]) - self.assertIsNotNone(nb.metadata.papermill["end_time"]) - self.assertEqual(nb.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(nb.metadata.papermill["exception"]) + self.assertIsNotNone(nb.metadata.papermill['start_time']) + self.assertIsNotNone(nb.metadata.papermill['end_time']) + self.assertEqual(nb.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(nb.metadata.papermill['exception']) for cell in nb.cells: - self.assertIsNotNone(cell.metadata.papermill["start_time"]) - self.assertIsNotNone(cell.metadata.papermill["end_time"]) - self.assertEqual(cell.metadata.papermill["duration"], AnyMock(float)) - self.assertFalse(cell.metadata.papermill["exception"]) + self.assertIsNotNone(cell.metadata.papermill['start_time']) + self.assertIsNotNone(cell.metadata.papermill['end_time']) + self.assertEqual(cell.metadata.papermill['duration'], AnyMock(float)) + self.assertFalse(cell.metadata.papermill['exception']) self.assertEqual( - cell.metadata.papermill["status"], + cell.metadata.papermill['status'], NotebookExecutionManager.COMPLETED, ) def test_nb_convert_log_outputs(self): - with patch.object(logger, "info") as info_mock: - with patch.object(logger, "warning") as warning_mock: - with patch.object(NotebookExecutionManager, "save"): + with patch.object(logger, 'info') as info_mock: + with patch.object(logger, 'warning') as warning_mock: + with patch.object(NotebookExecutionManager, 'save'): NBClientEngine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=True, ) info_mock.assert_has_calls( [ - call("Executing notebook with kernel: python"), + call('Executing notebook with kernel: python'), call( - "Executing Cell 1---------------------------------------" + 'Executing Cell 1---------------------------------------' ), call( - "Ending Cell 1------------------------------------------" + 'Ending Cell 1------------------------------------------' ), call( - "Executing Cell 2---------------------------------------" + 'Executing Cell 2---------------------------------------' ), - call("None\n"), + call('None\n'), call( - "Ending Cell 2------------------------------------------" + 'Ending Cell 2------------------------------------------' ), ] ) warning_mock.is_not_called() def test_nb_convert_no_log_outputs(self): - with patch.object(logger, "info") as info_mock: - with patch.object(logger, "warning") as warning_mock: - with patch.object(NotebookExecutionManager, "save"): + with patch.object(logger, 'info') as info_mock: + with patch.object(logger, 'warning') as warning_mock: + with patch.object(NotebookExecutionManager, 'save'): NBClientEngine.execute_notebook( self.nb, - "python", - output_path="foo.ipynb", + 'python', + output_path='foo.ipynb', progress_bar=False, log_output=False, ) @@ -542,33 +541,33 @@ def setUp(self): def test_registration(self): mock_engine = Mock() - self.papermill_engines.register("mock_engine", mock_engine) - self.assertIn("mock_engine", self.papermill_engines._engines) - self.assertIs(mock_engine, self.papermill_engines._engines["mock_engine"]) + self.papermill_engines.register('mock_engine', mock_engine) + self.assertIn('mock_engine', self.papermill_engines._engines) + self.assertIs(mock_engine, self.papermill_engines._engines['mock_engine']) def test_getting(self): mock_engine = Mock() - self.papermill_engines.register("mock_engine", mock_engine) + self.papermill_engines.register('mock_engine', mock_engine) # test retrieving an engine works - retrieved_engine = self.papermill_engines.get_engine("mock_engine") + retrieved_engine = self.papermill_engines.get_engine('mock_engine') self.assertIs(mock_engine, retrieved_engine) # test you can't retrieve a non-registered engine self.assertRaises( exceptions.PapermillException, self.papermill_engines.get_engine, - "non-existent", + 'non-existent', ) def test_registering_entry_points(self): fake_entrypoint = Mock(load=Mock()) - fake_entrypoint.name = "fake-engine" + fake_entrypoint.name = 'fake-engine' with patch( - "entrypoints.get_group_all", return_value=[fake_entrypoint] + 'entrypoints.get_group_all', return_value=[fake_entrypoint] ) as mock_get_group_all: self.papermill_engines.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.engine") + mock_get_group_all.assert_called_once_with('papermill.engine') self.assertEqual( - self.papermill_engines.get_engine("fake-engine"), + self.papermill_engines.get_engine('fake-engine'), fake_entrypoint.load.return_value, ) diff --git a/papermill/tests/test_exceptions.py b/papermill/tests/test_exceptions.py index 9c555942..191767fb 100644 --- a/papermill/tests/test_exceptions.py +++ b/papermill/tests/test_exceptions.py @@ -12,29 +12,29 @@ def temp_file(): """NamedTemporaryFile must be set in wb mode, closed without delete, opened with open(file, "rb"), then manually deleted. Otherwise, file fails to be read due to permission error on Windows. """ - with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f: + with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f: yield f os.unlink(f.name) @pytest.mark.parametrize( - "exc,args", + 'exc,args', [ ( exceptions.PapermillExecutionError, - (1, 2, "TestSource", "Exception", Exception(), ["Traceback", "Message"]), + (1, 2, 'TestSource', 'Exception', Exception(), ['Traceback', 'Message']), ), ( exceptions.PapermillMissingParameterException, - ("PapermillMissingParameterException",), + ('PapermillMissingParameterException',), ), - (exceptions.AwsError, ("AwsError",)), - (exceptions.FileExistsError, ("FileExistsError",)), - (exceptions.PapermillException, ("PapermillException",)), - (exceptions.PapermillRateLimitException, ("PapermillRateLimitException",)), + (exceptions.AwsError, ('AwsError',)), + (exceptions.FileExistsError, ('FileExistsError',)), + (exceptions.PapermillException, ('PapermillException',)), + (exceptions.PapermillRateLimitException, ('PapermillRateLimitException',)), ( exceptions.PapermillOptionalDependencyException, - ("PapermillOptionalDependencyException",), + ('PapermillOptionalDependencyException',), ), ], ) @@ -45,7 +45,7 @@ def test_exceptions_are_unpickleable(temp_file, exc, args): temp_file.close() # close to re-open for reading # Read the Pickled File - with open(temp_file.name, "rb") as read_file: + with open(temp_file.name, 'rb') as read_file: read_file.seek(0) data = read_file.read() pickled_err = pickle.loads(data) diff --git a/papermill/tests/test_execute.py b/papermill/tests/test_execute.py index 350d9b0f..daefc33b 100644 --- a/papermill/tests/test_execute.py +++ b/papermill/tests/test_execute.py @@ -3,20 +3,19 @@ import tempfile import unittest from copy import deepcopy -from unittest.mock import patch, ANY - from functools import partial from pathlib import Path +from unittest.mock import ANY, patch import nbformat from nbformat import validate from .. import engines, translators -from ..log import logger +from ..exceptions import PapermillExecutionError +from ..execute import execute_notebook from ..iorw import load_notebook_node +from ..log import logger from ..utils import chdir -from ..execute import execute_notebook -from ..exceptions import PapermillExecutionError from . import get_notebook_path, kernel_name execute_notebook = partial(execute_notebook, kernel_name=kernel_name) @@ -25,132 +24,132 @@ class TestNotebookHelpers(unittest.TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - self.notebook_name = "simple_execute.ipynb" + self.notebook_name = 'simple_execute.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb_test_executed_fname = os.path.join( - self.test_dir, f"output_{self.notebook_name}" + self.test_dir, f'output_{self.notebook_name}' ) def tearDown(self): shutil.rmtree(self.test_dir) - @patch(engines.__name__ + ".PapermillNotebookClient") + @patch(engines.__name__ + '.PapermillNotebookClient') def test_start_timeout(self, preproc_mock): execute_notebook( self.notebook_path, self.nb_test_executed_fname, start_timeout=123 ) args, kwargs = preproc_mock.call_args expected = [ - ("timeout", None), - ("startup_timeout", 123), - ("kernel_name", kernel_name), - ("log", logger), + ('timeout', None), + ('startup_timeout', 123), + ('kernel_name', kernel_name), + ('log', logger), ] actual = {(key, kwargs[key]) for key in kwargs} self.assertTrue( set(expected).issubset(actual), - msg=f"Expected arguments {expected} are not a subset of actual {actual}", + msg=f'Expected arguments {expected} are not a subset of actual {actual}', ) - @patch(engines.__name__ + ".PapermillNotebookClient") + @patch(engines.__name__ + '.PapermillNotebookClient') def test_default_start_timeout(self, preproc_mock): execute_notebook(self.notebook_path, self.nb_test_executed_fname) args, kwargs = preproc_mock.call_args expected = [ - ("timeout", None), - ("startup_timeout", 60), - ("kernel_name", kernel_name), - ("log", logger), + ('timeout', None), + ('startup_timeout', 60), + ('kernel_name', kernel_name), + ('log', logger), ] actual = {(key, kwargs[key]) for key in kwargs} self.assertTrue( set(expected).issubset(actual), - msg=f"Expected arguments {expected} are not a subset of actual {actual}", + msg=f'Expected arguments {expected} are not a subset of actual {actual}', ) def test_cell_insertion(self): execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"msg": "Hello"} + self.notebook_path, self.nb_test_executed_fname, {'msg': 'Hello'} ) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", 'msg = "Hello"', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', 'msg = "Hello"', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'}) def test_no_tags(self): - notebook_name = "no_parameters.ipynb" - nb_test_executed_fname = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'no_parameters.ipynb' + nb_test_executed_fname = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook( - get_notebook_path(notebook_name), nb_test_executed_fname, {"msg": "Hello"} + get_notebook_path(notebook_name), nb_test_executed_fname, {'msg': 'Hello'} ) test_nb = load_notebook_node(nb_test_executed_fname) self.assertListEqual( - test_nb.cells[0].get("source").split("\n"), - ["# Parameters", 'msg = "Hello"', ""], + test_nb.cells[0].get('source').split('\n'), + ['# Parameters', 'msg = "Hello"', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'}) def test_quoted_params(self): execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"msg": '"Hello"'} + self.notebook_path, self.nb_test_executed_fname, {'msg': '"Hello"'} ) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'msg = "\"Hello\""', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'msg = "\"Hello\""', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": '"Hello"'}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': '"Hello"'}) def test_backslash_params(self): execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"foo": r"do\ not\ crash"} + self.notebook_path, self.nb_test_executed_fname, {'foo': r'do\ not\ crash'} ) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'foo = "do\\ not\\ crash"', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'foo = "do\\ not\\ crash"', ''], ) self.assertEqual( - test_nb.metadata.papermill.parameters, {"foo": r"do\ not\ crash"} + test_nb.metadata.papermill.parameters, {'foo': r'do\ not\ crash'} ) def test_backslash_quote_params(self): execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"foo": r"bar=\"baz\""} + self.notebook_path, self.nb_test_executed_fname, {'foo': r'bar=\"baz\"'} ) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'foo = "bar=\\\"baz\\\""', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'foo = "bar=\\\"baz\\\""', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"foo": r"bar=\"baz\""}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'bar=\"baz\"'}) def test_double_backslash_quote_params(self): execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"foo": r'\\"bar\\"'} + self.notebook_path, self.nb_test_executed_fname, {'foo': r'\\"bar\\"'} ) test_nb = load_notebook_node(self.nb_test_executed_fname) self.assertListEqual( - test_nb.cells[1].get("source").split("\n"), - ["# Parameters", r'foo = "\\\\\"bar\\\\\""', ""], + test_nb.cells[1].get('source').split('\n'), + ['# Parameters', r'foo = "\\\\\"bar\\\\\""', ''], ) - self.assertEqual(test_nb.metadata.papermill.parameters, {"foo": r'\\"bar\\"'}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'foo': r'\\"bar\\"'}) def test_prepare_only(self): - for example in ["broken1.ipynb", "keyboard_interrupt.ipynb"]: + for example in ['broken1.ipynb', 'keyboard_interrupt.ipynb']: path = get_notebook_path(example) result_path = os.path.join(self.test_dir, example) # Should not raise as we don't execute the notebook at all execute_notebook( - path, result_path, {"foo": r"do\ not\ crash"}, prepare_only=True + path, result_path, {'foo': r'do\ not\ crash'}, prepare_only=True ) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual( - nb.cells[0].get("source").split("\n"), - ["# Parameters", r'foo = "do\\ not\\ crash"', ""], + nb.cells[0].get('source').split('\n'), + ['# Parameters', r'foo = "do\\ not\\ crash"', ''], ) @@ -162,50 +161,50 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test(self): - path = get_notebook_path("broken1.ipynb") + path = get_notebook_path('broken1.ipynb') # check that the notebook has two existing marker cells, so that this test is sure to be # validating the removal logic (the markers are simulatin an error in the first code cell # that has since been fixed) original_nb = load_notebook_node(path) self.assertEqual( - original_nb.cells[0].metadata["tags"], ["papermill-error-cell-tag"] + original_nb.cells[0].metadata['tags'], ['papermill-error-cell-tag'] ) - self.assertIn("In [1]", original_nb.cells[0].source) + self.assertIn('In [1]', original_nb.cells[0].source) self.assertEqual( - original_nb.cells[2].metadata["tags"], ["papermill-error-cell-tag"] + original_nb.cells[2].metadata['tags'], ['papermill-error-cell-tag'] ) - result_path = os.path.join(self.test_dir, "broken1.ipynb") + result_path = os.path.join(self.test_dir, 'broken1.ipynb') with self.assertRaises(PapermillExecutionError): execute_notebook(path, result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "markdown") + self.assertEqual(nb.cells[0].cell_type, 'markdown') self.assertRegex( nb.cells[0].source, r'^$', ) - self.assertEqual(nb.cells[0].metadata["tags"], ["papermill-error-cell-tag"]) + self.assertEqual(nb.cells[0].metadata['tags'], ['papermill-error-cell-tag']) - self.assertEqual(nb.cells[1].cell_type, "markdown") + self.assertEqual(nb.cells[1].cell_type, 'markdown') self.assertEqual(nb.cells[2].execution_count, 1) - self.assertEqual(nb.cells[3].cell_type, "markdown") - self.assertEqual(nb.cells[4].cell_type, "markdown") + self.assertEqual(nb.cells[3].cell_type, 'markdown') + self.assertEqual(nb.cells[4].cell_type, 'markdown') - self.assertEqual(nb.cells[5].cell_type, "markdown") + self.assertEqual(nb.cells[5].cell_type, 'markdown') self.assertRegex( nb.cells[5].source, '' ) - self.assertEqual(nb.cells[5].metadata["tags"], ["papermill-error-cell-tag"]) + self.assertEqual(nb.cells[5].metadata['tags'], ['papermill-error-cell-tag']) self.assertEqual(nb.cells[6].execution_count, 2) - self.assertEqual(nb.cells[6].outputs[0].output_type, "error") + self.assertEqual(nb.cells[6].outputs[0].output_type, 'error') self.assertEqual(nb.cells[7].execution_count, None) # double check the removal (the new cells above should be the only two tagged ones) self.assertEqual( sum( - "papermill-error-cell-tag" in cell.metadata.get("tags", []) + 'papermill-error-cell-tag' in cell.metadata.get('tags', []) for cell in nb.cells ), 2, @@ -220,25 +219,25 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test(self): - path = get_notebook_path("broken2.ipynb") - result_path = os.path.join(self.test_dir, "broken2.ipynb") + path = get_notebook_path('broken2.ipynb') + result_path = os.path.join(self.test_dir, 'broken2.ipynb') with self.assertRaises(PapermillExecutionError): execute_notebook(path, result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "markdown") + self.assertEqual(nb.cells[0].cell_type, 'markdown') self.assertRegex( nb.cells[0].source, r'^.*In \[2\].*$', ) self.assertEqual(nb.cells[1].execution_count, 1) - self.assertEqual(nb.cells[2].cell_type, "markdown") + self.assertEqual(nb.cells[2].cell_type, 'markdown') self.assertRegex( nb.cells[2].source, '' ) self.assertEqual(nb.cells[3].execution_count, 2) - self.assertEqual(nb.cells[3].outputs[0].output_type, "display_data") - self.assertEqual(nb.cells[3].outputs[1].output_type, "error") + self.assertEqual(nb.cells[3].outputs[0].output_type, 'display_data') + self.assertEqual(nb.cells[3].outputs[1].output_type, 'error') self.assertEqual(nb.cells[4].execution_count, None) @@ -246,10 +245,10 @@ def test(self): class TestReportMode(unittest.TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - self.notebook_name = "report_mode_test.ipynb" + self.notebook_name = 'report_mode_test.ipynb' self.notebook_path = get_notebook_path(self.notebook_name) self.nb_test_executed_fname = os.path.join( - self.test_dir, f"output_{self.notebook_name}" + self.test_dir, f'output_{self.notebook_name}' ) def tearDown(self): @@ -257,12 +256,12 @@ def tearDown(self): def test_report_mode(self): nb = execute_notebook( - self.notebook_path, self.nb_test_executed_fname, {"a": 0}, report_mode=True + self.notebook_path, self.nb_test_executed_fname, {'a': 0}, report_mode=True ) for cell in nb.cells: - if cell.cell_type == "code": + if cell.cell_type == 'code': self.assertEqual( - cell.metadata.get("jupyter", {}).get("source_hidden"), True + cell.metadata.get('jupyter', {}).get('source_hidden'), True ) @@ -270,9 +269,9 @@ class TestOutputPathNone(unittest.TestCase): def test_output_path_of_none(self): """Output path of None should return notebook node obj but not write an ipynb""" nb = execute_notebook( - get_notebook_path("simple_execute.ipynb"), None, {"msg": "Hello"} + get_notebook_path('simple_execute.ipynb'), None, {'msg': 'Hello'} ) - self.assertEqual(nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(nb.metadata.papermill.parameters, {'msg': 'Hello'}) class TestCWD(unittest.TestCase): @@ -280,26 +279,26 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() self.base_test_dir = tempfile.mkdtemp() - self.check_notebook_name = "read_check.ipynb" - self.check_notebook_path = os.path.join(self.base_test_dir, "read_check.ipynb") + self.check_notebook_name = 'read_check.ipynb' + self.check_notebook_path = os.path.join(self.base_test_dir, 'read_check.ipynb') # Setup read paths so base_test_dir has check_notebook_name shutil.copyfile( get_notebook_path(self.check_notebook_name), self.check_notebook_path ) - with open(os.path.join(self.test_dir, "check.txt"), "w", encoding="utf-8") as f: + with open(os.path.join(self.test_dir, 'check.txt'), 'w', encoding='utf-8') as f: # Needed for read_check to pass - f.write("exists") + f.write('exists') - self.simple_notebook_name = "simple_execute.ipynb" + self.simple_notebook_name = 'simple_execute.ipynb' self.simple_notebook_path = os.path.join( - self.base_test_dir, "simple_execute.ipynb" + self.base_test_dir, 'simple_execute.ipynb' ) # Setup read paths so base_test_dir has simple_notebook_name shutil.copyfile( get_notebook_path(self.simple_notebook_name), self.simple_notebook_path ) - self.nb_test_executed_fname = "test_output.ipynb" + self.nb_test_executed_fname = 'test_output.ipynb' def tearDown(self): shutil.rmtree(self.test_dir) @@ -352,64 +351,64 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_sys_exit(self): - notebook_name = "sysexit.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'sysexit.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual(nb.cells[0].execution_count, 1) self.assertEqual(nb.cells[1].execution_count, 2) - self.assertEqual(nb.cells[1].outputs[0].output_type, "error") - self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit") - self.assertEqual(nb.cells[1].outputs[0].evalue, "") + self.assertEqual(nb.cells[1].outputs[0].output_type, 'error') + self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit') + self.assertEqual(nb.cells[1].outputs[0].evalue, '') self.assertEqual(nb.cells[2].execution_count, None) def test_sys_exit0(self): - notebook_name = "sysexit0.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'sysexit0.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual(nb.cells[0].execution_count, 1) self.assertEqual(nb.cells[1].execution_count, 2) - self.assertEqual(nb.cells[1].outputs[0].output_type, "error") - self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit") - self.assertEqual(nb.cells[1].outputs[0].evalue, "0") + self.assertEqual(nb.cells[1].outputs[0].output_type, 'error') + self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit') + self.assertEqual(nb.cells[1].outputs[0].evalue, '0') self.assertEqual(nb.cells[2].execution_count, None) def test_sys_exit1(self): - notebook_name = "sysexit1.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'sysexit1.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') with self.assertRaises(PapermillExecutionError): execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "markdown") + self.assertEqual(nb.cells[0].cell_type, 'markdown') self.assertRegex( nb.cells[0].source, r'^$', ) self.assertEqual(nb.cells[1].execution_count, 1) - self.assertEqual(nb.cells[2].cell_type, "markdown") + self.assertEqual(nb.cells[2].cell_type, 'markdown') self.assertRegex( nb.cells[2].source, '' ) self.assertEqual(nb.cells[3].execution_count, 2) - self.assertEqual(nb.cells[3].outputs[0].output_type, "error") + self.assertEqual(nb.cells[3].outputs[0].output_type, 'error') self.assertEqual(nb.cells[4].execution_count, None) def test_system_exit(self): - notebook_name = "systemexit.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'systemexit.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook(get_notebook_path(notebook_name), result_path) nb = load_notebook_node(result_path) - self.assertEqual(nb.cells[0].cell_type, "code") + self.assertEqual(nb.cells[0].cell_type, 'code') self.assertEqual(nb.cells[0].execution_count, 1) self.assertEqual(nb.cells[1].execution_count, 2) - self.assertEqual(nb.cells[1].outputs[0].output_type, "error") - self.assertEqual(nb.cells[1].outputs[0].ename, "SystemExit") - self.assertEqual(nb.cells[1].outputs[0].evalue, "") + self.assertEqual(nb.cells[1].outputs[0].output_type, 'error') + self.assertEqual(nb.cells[1].outputs[0].ename, 'SystemExit') + self.assertEqual(nb.cells[1].outputs[0].evalue, '') self.assertEqual(nb.cells[2].execution_count, None) @@ -421,10 +420,10 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_from_version_4_4_upgrades(self): - notebook_name = "nb_version_4.4.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'nb_version_4.4.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook( - get_notebook_path(notebook_name), result_path, {"var": "It works"} + get_notebook_path(notebook_name), result_path, {'var': 'It works'} ) nb = load_notebook_node(result_path) validate(nb) @@ -438,10 +437,10 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_no_v3_language_backport(self): - notebook_name = "blank-vscode.ipynb" - result_path = os.path.join(self.test_dir, f"output_{notebook_name}") + notebook_name = 'blank-vscode.ipynb' + result_path = os.path.join(self.test_dir, f'output_{notebook_name}') execute_notebook( - get_notebook_path(notebook_name), result_path, {"var": "It works"} + get_notebook_path(notebook_name), result_path, {'var': 'It works'} ) nb = load_notebook_node(result_path) validate(nb) @@ -455,24 +454,24 @@ def execute_managed_notebook(cls, nb_man, kernel_name, **kwargs): @classmethod def nb_kernel_name(cls, nb, name=None): - return "my_custom_kernel" + return 'my_custom_kernel' @classmethod def nb_language(cls, nb, language=None): - return "my_custom_language" + return 'my_custom_language' def setUp(self): self.test_dir = tempfile.mkdtemp() - self.notebook_path = get_notebook_path("simple_execute.ipynb") + self.notebook_path = get_notebook_path('simple_execute.ipynb') self.nb_test_executed_fname = os.path.join( - self.test_dir, "output_{}".format("simple_execute.ipynb") + self.test_dir, 'output_{}'.format('simple_execute.ipynb') ) self._orig_papermill_engines = deepcopy(engines.papermill_engines) self._orig_translators = deepcopy(translators.papermill_translators) - engines.papermill_engines.register("custom_engine", self.CustomEngine) + engines.papermill_engines.register('custom_engine', self.CustomEngine) translators.papermill_translators.register( - "my_custom_language", translators.PythonTranslator() + 'my_custom_language', translators.PythonTranslator() ) def tearDown(self): @@ -482,11 +481,11 @@ def tearDown(self): @patch.object( CustomEngine, - "execute_managed_notebook", + 'execute_managed_notebook', wraps=CustomEngine.execute_managed_notebook, ) @patch( - "papermill.parameterize.translate_parameters", + 'papermill.parameterize.translate_parameters', wraps=translators.translate_parameters, ) def test_custom_kernel_name_and_language( @@ -498,30 +497,30 @@ def test_custom_kernel_name_and_language( execute_notebook( self.notebook_path, self.nb_test_executed_fname, - engine_name="custom_engine", - parameters={"msg": "fake msg"}, + engine_name='custom_engine', + parameters={'msg': 'fake msg'}, ) self.assertEqual( - execute_managed_notebook.call_args[0], (ANY, "my_custom_kernel") + execute_managed_notebook.call_args[0], (ANY, 'my_custom_kernel') ) self.assertEqual( translate_parameters.call_args[0], - (ANY, "my_custom_language", {"msg": "fake msg"}, ANY), + (ANY, 'my_custom_language', {'msg': 'fake msg'}, ANY), ) class TestNotebookNodeInput(unittest.TestCase): def setUp(self): self.test_dir = tempfile.TemporaryDirectory() - self.result_path = os.path.join(self.test_dir.name, "output.ipynb") + self.result_path = os.path.join(self.test_dir.name, 'output.ipynb') def tearDown(self): self.test_dir.cleanup() def test_notebook_node_input(self): input_nb = nbformat.read( - get_notebook_path("simple_execute.ipynb"), as_version=4 + get_notebook_path('simple_execute.ipynb'), as_version=4 ) - execute_notebook(input_nb, self.result_path, {"msg": "Hello"}) + execute_notebook(input_nb, self.result_path, {'msg': 'Hello'}) test_nb = nbformat.read(self.result_path, as_version=4) - self.assertEqual(test_nb.metadata.papermill.parameters, {"msg": "Hello"}) + self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'}) diff --git a/papermill/tests/test_gcs.py b/papermill/tests/test_gcs.py index 280deb8f..610d8c4c 100644 --- a/papermill/tests/test_gcs.py +++ b/papermill/tests/test_gcs.py @@ -69,124 +69,124 @@ class GCSTest(unittest.TestCase): def setUp(self): self.gcs_handler = GCSHandler() - @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper()) + @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper()) def test_gcs_read(self, mock_gcs_filesystem): client = self.gcs_handler._get_client() - self.assertEqual(self.gcs_handler.read("gs://bucket/test.ipynb"), 1) + self.assertEqual(self.gcs_handler.read('gs://bucket/test.ipynb'), 1) # Check that client is only generated once self.assertIs(client, self.gcs_handler._get_client()) - @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper()) + @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper()) def test_gcs_write(self, mock_gcs_filesystem): client = self.gcs_handler._get_client() self.assertEqual( - self.gcs_handler.write("new value", "gs://bucket/test.ipynb"), 1 + self.gcs_handler.write('new value', 'gs://bucket/test.ipynb'), 1 ) # Check that client is only generated once self.assertIs(client, self.gcs_handler._get_client()) - @patch("papermill.iorw.GCSFileSystem", side_effect=mock_gcs_fs_wrapper()) + @patch('papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper()) def test_gcs_listdir(self, mock_gcs_filesystem): client = self.gcs_handler._get_client() - self.gcs_handler.listdir("testdir") + self.gcs_handler.listdir('testdir') # Check that client is only generated once self.assertIs(client, self.gcs_handler._get_client()) @patch( - "papermill.iorw.GCSFileSystem", + 'papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 429}), 10 + GCSRateLimitException({'message': 'test', 'code': 429}), 10 ), ) def test_gcs_handle_exception(self, mock_gcs_filesystem): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): with self.assertRaises(PapermillRateLimitException): self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" + 'raise_limit_exception', 'gs://bucket/test.ipynb' ) @patch( - "papermill.iorw.GCSFileSystem", + 'papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 429}), 1 + GCSRateLimitException({'message': 'test', 'code': 429}), 1 ), ) def test_gcs_retry(self, mock_gcs_filesystem): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): self.assertEqual( self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" + 'raise_limit_exception', 'gs://bucket/test.ipynb' ), 2, ) @patch( - "papermill.iorw.GCSFileSystem", + 'papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper( - GCSHttpError({"message": "test", "code": 429}), 1 + GCSHttpError({'message': 'test', 'code': 429}), 1 ), ) def test_gcs_retry_older_exception(self, mock_gcs_filesystem): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): self.assertEqual( self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" + 'raise_limit_exception', 'gs://bucket/test.ipynb' ), 2, ) - @patch("papermill.iorw.gs_is_retriable", side_effect=fallback_gs_is_retriable) + @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable) @patch( - "papermill.iorw.GCSFileSystem", + 'papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": None}), 1 + GCSRateLimitException({'message': 'test', 'code': None}), 1 ), ) def test_gcs_fallback_retry_unknown_failure_code( self, mock_gcs_filesystem, mock_gcs_retriable ): - with patch.object(GCSHandler, "RETRY_DELAY", 0): - with patch.object(GCSHandler, "RETRY_MULTIPLIER", 0): - with patch.object(GCSHandler, "RETRY_MAX_DELAY", 0): + with patch.object(GCSHandler, 'RETRY_DELAY', 0): + with patch.object(GCSHandler, 'RETRY_MULTIPLIER', 0): + with patch.object(GCSHandler, 'RETRY_MAX_DELAY', 0): self.assertEqual( self.gcs_handler.write( - "raise_limit_exception", "gs://bucket/test.ipynb" + 'raise_limit_exception', 'gs://bucket/test.ipynb' ), 2, ) - @patch("papermill.iorw.gs_is_retriable", return_value=False) + @patch('papermill.iorw.gs_is_retriable', return_value=False) @patch( - "papermill.iorw.GCSFileSystem", + 'papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 500}), 1 + GCSRateLimitException({'message': 'test', 'code': 500}), 1 ), ) def test_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable): with self.assertRaises(GCSRateLimitException): - self.gcs_handler.write("fatal_exception", "gs://bucket/test.ipynb") + self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb') - @patch("papermill.iorw.gs_is_retriable", side_effect=fallback_gs_is_retriable) + @patch('papermill.iorw.gs_is_retriable', side_effect=fallback_gs_is_retriable) @patch( - "papermill.iorw.GCSFileSystem", + 'papermill.iorw.GCSFileSystem', side_effect=mock_gcs_fs_wrapper( - GCSRateLimitException({"message": "test", "code": 500}), 1 + GCSRateLimitException({'message': 'test', 'code': 500}), 1 ), ) def test_fallback_gcs_invalid_code(self, mock_gcs_filesystem, mock_gcs_retriable): with self.assertRaises(GCSRateLimitException): - self.gcs_handler.write("fatal_exception", "gs://bucket/test.ipynb") + self.gcs_handler.write('fatal_exception', 'gs://bucket/test.ipynb') @patch( - "papermill.iorw.GCSFileSystem", - side_effect=mock_gcs_fs_wrapper(ValueError("not-a-retry"), 1), + 'papermill.iorw.GCSFileSystem', + side_effect=mock_gcs_fs_wrapper(ValueError('not-a-retry'), 1), ) def test_gcs_unretryable(self, mock_gcs_filesystem): with self.assertRaises(ValueError): - self.gcs_handler.write("no_a_rate_limit", "gs://bucket/test.ipynb") + self.gcs_handler.write('no_a_rate_limit', 'gs://bucket/test.ipynb') diff --git a/papermill/tests/test_hdfs.py b/papermill/tests/test_hdfs.py index 0577e1f5..7b5a9326 100644 --- a/papermill/tests/test_hdfs.py +++ b/papermill/tests/test_hdfs.py @@ -8,7 +8,7 @@ class MockHadoopFileSystem(MagicMock): def get_file_info(self, path): - return [MockFileInfo("test1.ipynb"), MockFileInfo("test2.ipynb")] + return [MockFileInfo('test1.ipynb'), MockFileInfo('test2.ipynb')] def open_input_stream(self, path): return MockHadoopFile() @@ -19,7 +19,7 @@ def open_output_stream(self, path): class MockHadoopFile: def __init__(self): - self._content = b"Content of notebook" + self._content = b'Content of notebook' def __enter__(self, *args): return self @@ -40,8 +40,8 @@ def __init__(self, path): self.path = path -@pytest.mark.skip(reason="No valid dep package for python 3.12 yet") -@patch("papermill.iorw.HadoopFileSystem", side_effect=MockHadoopFileSystem()) +@pytest.mark.skip(reason='No valid dep package for python 3.12 yet') +@patch('papermill.iorw.HadoopFileSystem', side_effect=MockHadoopFileSystem()) class HDFSTest(unittest.TestCase): def setUp(self): self.hdfs_handler = HDFSHandler() @@ -49,8 +49,8 @@ def setUp(self): def test_hdfs_listdir(self, mock_hdfs_filesystem): client = self.hdfs_handler._get_client() self.assertEqual( - self.hdfs_handler.listdir("hdfs:///Projects/"), - ["test1.ipynb", "test2.ipynb"], + self.hdfs_handler.listdir('hdfs:///Projects/'), + ['test1.ipynb', 'test2.ipynb'], ) # Check if client is the same after calling self.assertIs(client, self.hdfs_handler._get_client()) @@ -58,14 +58,14 @@ def test_hdfs_listdir(self, mock_hdfs_filesystem): def test_hdfs_read(self, mock_hdfs_filesystem): client = self.hdfs_handler._get_client() self.assertEqual( - self.hdfs_handler.read("hdfs:///Projects/test1.ipynb"), - b"Content of notebook", + self.hdfs_handler.read('hdfs:///Projects/test1.ipynb'), + b'Content of notebook', ) self.assertIs(client, self.hdfs_handler._get_client()) def test_hdfs_write(self, mock_hdfs_filesystem): client = self.hdfs_handler._get_client() self.assertEqual( - self.hdfs_handler.write("hdfs:///Projects/test1.ipynb", b"New content"), 1 + self.hdfs_handler.write('hdfs:///Projects/test1.ipynb', b'New content'), 1 ) self.assertIs(client, self.hdfs_handler._get_client()) diff --git a/papermill/tests/test_inspect.py b/papermill/tests/test_inspect.py index bab1df65..6d787e2d 100644 --- a/papermill/tests/test_inspect.py +++ b/papermill/tests/test_inspect.py @@ -3,11 +3,9 @@ import pytest from click import Context - from papermill.inspection import display_notebook_help, inspect_notebook - -NOTEBOOKS_PATH = Path(__file__).parent / "notebooks" +NOTEBOOKS_PATH = Path(__file__).parent / 'notebooks' def _get_fullpath(name): @@ -17,55 +15,55 @@ def _get_fullpath(name): @pytest.fixture def click_context(): mock = MagicMock(spec=Context, command=MagicMock()) - mock.command.get_usage.return_value = "Dummy usage" + mock.command.get_usage.return_value = 'Dummy usage' return mock @pytest.mark.parametrize( - "name, expected", + 'name, expected', [ - (_get_fullpath("no_parameters.ipynb"), {}), + (_get_fullpath('no_parameters.ipynb'), {}), ( - _get_fullpath("simple_execute.ipynb"), + _get_fullpath('simple_execute.ipynb'), { - "msg": { - "name": "msg", - "inferred_type_name": "None", - "default": "None", - "help": "", + 'msg': { + 'name': 'msg', + 'inferred_type_name': 'None', + 'default': 'None', + 'help': '', } }, ), ( - _get_fullpath("complex_parameters.ipynb"), + _get_fullpath('complex_parameters.ipynb'), { - "msg": { - "name": "msg", - "inferred_type_name": "None", - "default": "None", - "help": "", + 'msg': { + 'name': 'msg', + 'inferred_type_name': 'None', + 'default': 'None', + 'help': '', }, - "a": { - "name": "a", - "inferred_type_name": "float", - "default": "2.25", - "help": "Variable a", + 'a': { + 'name': 'a', + 'inferred_type_name': 'float', + 'default': '2.25', + 'help': 'Variable a', }, - "b": { - "name": "b", - "inferred_type_name": "List[str]", - "default": "['Hello','World']", - "help": "Nice list", + 'b': { + 'name': 'b', + 'inferred_type_name': 'List[str]', + 'default': "['Hello','World']", + 'help': 'Nice list', }, - "c": { - "name": "c", - "inferred_type_name": "NoneType", - "default": "None", - "help": "", + 'c': { + 'name': 'c', + 'inferred_type_name': 'NoneType', + 'default': 'None', + 'help': '', }, }, ), - (_get_fullpath("notimplemented_translator.ipynb"), {}), + (_get_fullpath('notimplemented_translator.ipynb'), {}), ], ) def test_inspect_notebook(name, expected): @@ -74,50 +72,50 @@ def test_inspect_notebook(name, expected): def test_str_path(): expected = { - "msg": { - "name": "msg", - "inferred_type_name": "None", - "default": "None", - "help": "", + 'msg': { + 'name': 'msg', + 'inferred_type_name': 'None', + 'default': 'None', + 'help': '', } } - assert inspect_notebook(str(_get_fullpath("simple_execute.ipynb"))) == expected + assert inspect_notebook(str(_get_fullpath('simple_execute.ipynb'))) == expected @pytest.mark.parametrize( - "name, expected", + 'name, expected', [ ( - _get_fullpath("no_parameters.ipynb"), + _get_fullpath('no_parameters.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", "\n No cell tagged 'parameters'", ], ), ( - _get_fullpath("simple_execute.ipynb"), + _get_fullpath('simple_execute.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", - " msg: Unknown type (default None)", + ' msg: Unknown type (default None)', ], ), ( - _get_fullpath("complex_parameters.ipynb"), + _get_fullpath('complex_parameters.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", - " msg: Unknown type (default None)", - " a: float (default 2.25) Variable a", + ' msg: Unknown type (default None)', + ' a: float (default 2.25) Variable a', " b: List[str] (default ['Hello','World'])\n Nice list", - " c: NoneType (default None) ", + ' c: NoneType (default None) ', ], ), ( - _get_fullpath("notimplemented_translator.ipynb"), + _get_fullpath('notimplemented_translator.ipynb'), [ - "Dummy usage", + 'Dummy usage', "\nParameters inferred for notebook '{name}':", "\n Can't infer anything about this notebook's parameters. It may not have any parameter defined.", # noqa ], @@ -125,7 +123,7 @@ def test_str_path(): ], ) def test_display_notebook_help(click_context, name, expected): - with patch("papermill.inspection.click.echo") as echo: + with patch('papermill.inspection.click.echo') as echo: display_notebook_help(click_context, str(name), None) assert echo.call_count == len(expected) diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py index 39ad12b0..d1f765b9 100644 --- a/papermill/tests/test_iorw.py +++ b/papermill/tests/test_iorw.py @@ -1,31 +1,31 @@ +import io import json -import unittest import os -import io +import unittest +from tempfile import TemporaryDirectory +from unittest.mock import Mock, patch + import nbformat import pytest - from requests.exceptions import ConnectionError -from tempfile import TemporaryDirectory -from unittest.mock import Mock, patch from .. import iorw +from ..exceptions import PapermillException from ..iorw import ( + ADLHandler, HttpHandler, LocalHandler, NoIOHandler, - ADLHandler, NotebookNodeHandler, - StreamHandler, PapermillIO, - read_yaml_file, - papermill_io, + StreamHandler, local_file_io_cwd, + papermill_io, + read_yaml_file, ) -from ..exceptions import PapermillException from . import get_notebook_path -FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures") +FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures') class TestPapermillIO(unittest.TestCase): @@ -38,16 +38,16 @@ def __init__(self, ver): self.ver = ver def read(self, path): - return f"contents from {path} for version {self.ver}" + return f'contents from {path} for version {self.ver}' def listdir(self, path): - return ["fake", "contents"] + return ['fake', 'contents'] def write(self, buf, path): - return f"wrote {buf}" + return f'wrote {buf}' def pretty_path(self, path): - return f"{path}/pretty/{self.ver}" + return f'{path}/pretty/{self.ver}' class FakeByteHandler: def __init__(self, ver): @@ -59,13 +59,13 @@ def read(self, path): return f.read() def listdir(self, path): - return ["fake", "contents"] + return ['fake', 'contents'] def write(self, buf, path): - return f"wrote {buf}" + return f'wrote {buf}' def pretty_path(self, path): - return f"{path}/pretty/{self.ver}" + return f'{path}/pretty/{self.ver}' def setUp(self): self.papermill_io = PapermillIO() @@ -73,8 +73,8 @@ def setUp(self): self.fake1 = self.FakeHandler(1) self.fake2 = self.FakeHandler(2) self.fake_byte1 = self.FakeByteHandler(1) - self.papermill_io.register("fake", self.fake1) - self.papermill_io_bytes.register("notebooks", self.fake_byte1) + self.papermill_io.register('fake', self.fake1) + self.papermill_io_bytes.register('notebooks', self.fake_byte1) self.old_papermill_io = iorw.papermill_io iorw.papermill_io = self.papermill_io @@ -83,21 +83,21 @@ def tearDown(self): iorw.papermill_io = self.old_papermill_io def test_get_handler(self): - self.assertEqual(self.papermill_io.get_handler("fake"), self.fake1) + self.assertEqual(self.papermill_io.get_handler('fake'), self.fake1) def test_get_local_handler(self): with self.assertRaises(PapermillException): - self.papermill_io.get_handler("dne") + self.papermill_io.get_handler('dne') - self.papermill_io.register("local", self.fake2) - self.assertEqual(self.papermill_io.get_handler("dne"), self.fake2) + self.papermill_io.register('local', self.fake2) + self.assertEqual(self.papermill_io.get_handler('dne'), self.fake2) def test_get_no_io_handler(self): self.assertIsInstance(self.papermill_io.get_handler(None), NoIOHandler) def test_get_notebook_node_handler(self): test_nb = nbformat.read( - get_notebook_path("test_notebooknode_io.ipynb"), as_version=4 + get_notebook_path('test_notebooknode_io.ipynb'), as_version=4 ) self.assertIsInstance( self.papermill_io.get_handler(test_nb), NotebookNodeHandler @@ -105,94 +105,94 @@ def test_get_notebook_node_handler(self): def test_entrypoint_register(self): fake_entrypoint = Mock(load=Mock()) - fake_entrypoint.name = "fake-from-entry-point://" + fake_entrypoint.name = 'fake-from-entry-point://' with patch( - "entrypoints.get_group_all", return_value=[fake_entrypoint] + 'entrypoints.get_group_all', return_value=[fake_entrypoint] ) as mock_get_group_all: self.papermill_io.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.io") - fake_ = self.papermill_io.get_handler("fake-from-entry-point://") + mock_get_group_all.assert_called_once_with('papermill.io') + fake_ = self.papermill_io.get_handler('fake-from-entry-point://') assert fake_ == fake_entrypoint.load.return_value def test_register_ordering(self): # Should match fake1 with fake2 path - self.assertEqual(self.papermill_io.get_handler("fake2/path"), self.fake1) + self.assertEqual(self.papermill_io.get_handler('fake2/path'), self.fake1) self.papermill_io.reset() - self.papermill_io.register("fake", self.fake1) - self.papermill_io.register("fake2", self.fake2) + self.papermill_io.register('fake', self.fake1) + self.papermill_io.register('fake2', self.fake2) # Should match fake1 with fake1 path, and NOT fake2 path/match - self.assertEqual(self.papermill_io.get_handler("fake/path"), self.fake1) + self.assertEqual(self.papermill_io.get_handler('fake/path'), self.fake1) # Should match fake2 with fake2 path - self.assertEqual(self.papermill_io.get_handler("fake2/path"), self.fake2) + self.assertEqual(self.papermill_io.get_handler('fake2/path'), self.fake2) def test_read(self): self.assertEqual( - self.papermill_io.read("fake/path"), "contents from fake/path for version 1" + self.papermill_io.read('fake/path'), 'contents from fake/path for version 1' ) def test_read_bytes(self): self.assertIsNotNone( self.papermill_io_bytes.read( - "notebooks/gcs/gcs_in/gcs-simple_notebook.ipynb" + 'notebooks/gcs/gcs_in/gcs-simple_notebook.ipynb' ) ) def test_read_with_no_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.read("fake/path") + self.papermill_io.read('fake/path') def test_read_with_invalid_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.read("fake/path/fakeinputpath.ipynb1") + self.papermill_io.read('fake/path/fakeinputpath.ipynb1') def test_read_with_valid_file_extension(self): with pytest.warns(None) as warns: - self.papermill_io.read("fake/path/fakeinputpath.ipynb") + self.papermill_io.read('fake/path/fakeinputpath.ipynb') self.assertEqual(len(warns), 0) def test_read_yaml_with_no_file_extension(self): with pytest.warns(UserWarning): - read_yaml_file("fake/path") + read_yaml_file('fake/path') def test_read_yaml_with_invalid_file_extension(self): with pytest.warns(UserWarning): - read_yaml_file("fake/path/fakeinputpath.ipynb") + read_yaml_file('fake/path/fakeinputpath.ipynb') def test_read_stdin(self): - file_content = "Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ" - with patch("sys.stdin", io.StringIO(file_content)): - self.assertEqual(self.old_papermill_io.read("-"), file_content) + file_content = 'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ' + with patch('sys.stdin', io.StringIO(file_content)): + self.assertEqual(self.old_papermill_io.read('-'), file_content) def test_listdir(self): - self.assertEqual(self.papermill_io.listdir("fake/path"), ["fake", "contents"]) + self.assertEqual(self.papermill_io.listdir('fake/path'), ['fake', 'contents']) def test_write(self): - self.assertEqual(self.papermill_io.write("buffer", "fake/path"), "wrote buffer") + self.assertEqual(self.papermill_io.write('buffer', 'fake/path'), 'wrote buffer') def test_write_with_no_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.write("buffer", "fake/path") + self.papermill_io.write('buffer', 'fake/path') def test_write_with_path_of_none(self): - self.assertIsNone(self.papermill_io.write("buffer", None)) + self.assertIsNone(self.papermill_io.write('buffer', None)) def test_write_with_invalid_file_extension(self): with pytest.warns(UserWarning): - self.papermill_io.write("buffer", "fake/path/fakeoutputpath.ipynb1") + self.papermill_io.write('buffer', 'fake/path/fakeoutputpath.ipynb1') def test_write_stdout(self): - file_content = "Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ" + file_content = 'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ' out = io.BytesIO() - with patch("sys.stdout", out): - self.old_papermill_io.write(file_content, "-") - self.assertEqual(out.getvalue(), file_content.encode("utf-8")) + with patch('sys.stdout', out): + self.old_papermill_io.write(file_content, '-') + self.assertEqual(out.getvalue(), file_content.encode('utf-8')) def test_pretty_path(self): self.assertEqual( - self.papermill_io.pretty_path("fake/path"), "fake/path/pretty/1" + self.papermill_io.pretty_path('fake/path'), 'fake/path/pretty/1' ) @@ -203,35 +203,35 @@ class TestLocalHandler(unittest.TestCase): def test_read_utf8(self): self.assertEqual( - LocalHandler().read(os.path.join(FIXTURE_PATH, "rock.txt")).strip(), "✄" + LocalHandler().read(os.path.join(FIXTURE_PATH, 'rock.txt')).strip(), '✄' ) def test_write_utf8(self): with TemporaryDirectory() as temp_dir: - path = os.path.join(temp_dir, "paper.txt") - LocalHandler().write("✄", path) - with open(path, encoding="utf-8") as f: - self.assertEqual(f.read().strip(), "✄") + path = os.path.join(temp_dir, 'paper.txt') + LocalHandler().write('✄', path) + with open(path, encoding='utf-8') as f: + self.assertEqual(f.read().strip(), '✄') def test_write_no_directory_exists(self): with self.assertRaises(FileNotFoundError): - LocalHandler().write("buffer", "fake/path/fakenb.ipynb") + LocalHandler().write('buffer', 'fake/path/fakenb.ipynb') def test_write_local_directory(self): - with patch.object(io, "open"): + with patch.object(io, 'open'): # Shouldn't raise with missing directory - LocalHandler().write("buffer", "local.ipynb") + LocalHandler().write('buffer', 'local.ipynb') def test_write_passed_cwd(self): with TemporaryDirectory() as temp_dir: handler = LocalHandler() handler.cwd(temp_dir) - handler.write("✄", "paper.txt") + handler.write('✄', 'paper.txt') - path = os.path.join(temp_dir, "paper.txt") - with open(path, encoding="utf-8") as f: - self.assertEqual(f.read().strip(), "✄") + path = os.path.join(temp_dir, 'paper.txt') + with open(path, encoding='utf-8') as f: + self.assertEqual(f.read().strip(), '✄') def test_local_file_io_cwd(self): with TemporaryDirectory() as temp_dir: @@ -241,16 +241,16 @@ def test_local_file_io_cwd(self): try: local_handler = LocalHandler() papermill_io.reset() - papermill_io.register("local", local_handler) + papermill_io.register('local', local_handler) with local_file_io_cwd(temp_dir): - local_handler.write("✄", "paper.txt") - self.assertEqual(local_handler.read("paper.txt"), "✄") + local_handler.write('✄', 'paper.txt') + self.assertEqual(local_handler.read('paper.txt'), '✄') # Double check it used the tmpdir - path = os.path.join(temp_dir, "paper.txt") - with open(path, encoding="utf-8") as f: - self.assertEqual(f.read().strip(), "✄") + path = os.path.join(temp_dir, 'paper.txt') + with open(path, encoding='utf-8') as f: + self.assertEqual(f.read().strip(), '✄') finally: papermill_io.handlers = handlers @@ -263,7 +263,7 @@ def test_invalid_string(self): # a string from which we can't extract a notebook is assumed to # be a file and an IOError will be raised with self.assertRaises(IOError): - LocalHandler().read("a random string") + LocalHandler().read('a random string') class TestNoIOHandler(unittest.TestCase): @@ -276,10 +276,10 @@ def test_raises_on_listdir(self): NoIOHandler().listdir(None) def test_write_returns_none(self): - self.assertIsNone(NoIOHandler().write("buf", None)) + self.assertIsNone(NoIOHandler().write('buf', None)) def test_pretty_path(self): - expect = "Notebook will not be saved" + expect = 'Notebook will not be saved' self.assertEqual(NoIOHandler().pretty_path(None), expect) @@ -291,20 +291,20 @@ class TestADLHandler(unittest.TestCase): def setUp(self): self.handler = ADLHandler() self.handler._client = Mock( - read=Mock(return_value=["foo", "bar", "baz"]), - listdir=Mock(return_value=["foo", "bar", "baz"]), + read=Mock(return_value=['foo', 'bar', 'baz']), + listdir=Mock(return_value=['foo', 'bar', 'baz']), write=Mock(), ) def test_read(self): - self.assertEqual(self.handler.read("some_path"), "foo\nbar\nbaz") + self.assertEqual(self.handler.read('some_path'), 'foo\nbar\nbaz') def test_listdir(self): - self.assertEqual(self.handler.listdir("some_path"), ["foo", "bar", "baz"]) + self.assertEqual(self.handler.listdir('some_path'), ['foo', 'bar', 'baz']) def test_write(self): - self.handler.write("foo", "bar") - self.handler._client.write.assert_called_once_with("foo", "bar") + self.handler.write('foo', 'bar') + self.handler._client.write.assert_called_once_with('foo', 'bar') class TestHttpHandler(unittest.TestCase): @@ -318,23 +318,23 @@ def test_listdir(self): `listdir` function is not supported. """ with self.assertRaises(PapermillException) as e: - HttpHandler.listdir("http://example.com") + HttpHandler.listdir('http://example.com') - self.assertEqual(f"{e.exception}", "listdir is not supported by HttpHandler") + self.assertEqual(f'{e.exception}', 'listdir is not supported by HttpHandler') def test_read(self): """ Tests that the `read` function performs a request to the giving path and returns the response. """ - path = "http://example.com" - text = "request test response" + path = 'http://example.com' + text = 'request test response' - with patch("papermill.iorw.requests.get") as mock_get: + with patch('papermill.iorw.requests.get') as mock_get: mock_get.return_value = Mock(text=text) self.assertEqual(HttpHandler.read(path), text) mock_get.assert_called_once_with( - path, headers={"Accept": "application/json"} + path, headers={'Accept': 'application/json'} ) def test_write(self): @@ -342,10 +342,10 @@ def test_write(self): Tests that the `write` function performs a put request to the given path. """ - path = "http://example.com" + path = 'http://example.com' buf = '{"papermill": true}' - with patch("papermill.iorw.requests.put") as mock_put: + with patch('papermill.iorw.requests.put') as mock_put: HttpHandler.write(buf, path) mock_put.assert_called_once_with(path, json=json.loads(buf)) @@ -353,7 +353,7 @@ def test_write_failure(self): """ Tests that the `write` function raises on failure to put the buffer. """ - path = "http://localhost:9999" + path = 'http://localhost:9999' buf = '{"papermill": true}' with self.assertRaises(ConnectionError): @@ -361,35 +361,35 @@ def test_write_failure(self): class TestStreamHandler(unittest.TestCase): - @patch("sys.stdin", io.StringIO("mock stream")) + @patch('sys.stdin', io.StringIO('mock stream')) def test_read_from_stdin(self): - result = StreamHandler().read("foo") - self.assertEqual(result, "mock stream") + result = StreamHandler().read('foo') + self.assertEqual(result, 'mock stream') def test_raises_on_listdir(self): with self.assertRaises(PapermillException): StreamHandler().listdir(None) - @patch("sys.stdout") + @patch('sys.stdout') def test_write_to_stdout_buffer(self, mock_stdout): mock_stdout.buffer = io.BytesIO() - StreamHandler().write("mock stream", "foo") - self.assertEqual(mock_stdout.buffer.getbuffer(), b"mock stream") + StreamHandler().write('mock stream', 'foo') + self.assertEqual(mock_stdout.buffer.getbuffer(), b'mock stream') - @patch("sys.stdout", new_callable=io.BytesIO) + @patch('sys.stdout', new_callable=io.BytesIO) def test_write_to_stdout(self, mock_stdout): - StreamHandler().write("mock stream", "foo") - self.assertEqual(mock_stdout.getbuffer(), b"mock stream") + StreamHandler().write('mock stream', 'foo') + self.assertEqual(mock_stdout.getbuffer(), b'mock stream') def test_pretty_path_returns_input_path(self): '''Should return the input str, which often is the default registered schema "-"''' - self.assertEqual(StreamHandler().pretty_path("foo"), "foo") + self.assertEqual(StreamHandler().pretty_path('foo'), 'foo') class TestNotebookNodeHandler(unittest.TestCase): def test_read_notebook_node(self): input_nb = nbformat.read( - get_notebook_path("test_notebooknode_io.ipynb"), as_version=4 + get_notebook_path('test_notebooknode_io.ipynb'), as_version=4 ) result = NotebookNodeHandler().read(input_nb) expect = ( @@ -403,12 +403,12 @@ def test_read_notebook_node(self): def test_raises_on_listdir(self): with self.assertRaises(PapermillException): - NotebookNodeHandler().listdir("foo") + NotebookNodeHandler().listdir('foo') def test_raises_on_write(self): with self.assertRaises(PapermillException): - NotebookNodeHandler().write("foo", "bar") + NotebookNodeHandler().write('foo', 'bar') def test_pretty_path(self): - expect = "NotebookNode object" - self.assertEqual(NotebookNodeHandler().pretty_path("foo"), expect) + expect = 'NotebookNode object' + self.assertEqual(NotebookNodeHandler().pretty_path('foo'), expect) diff --git a/papermill/tests/test_parameterize.py b/papermill/tests/test_parameterize.py index 4e2df4f4..a3e9dcff 100644 --- a/papermill/tests/test_parameterize.py +++ b/papermill/tests/test_parameterize.py @@ -1,14 +1,14 @@ import unittest +from datetime import datetime -from ..iorw import load_notebook_node from ..exceptions import PapermillMissingParameterException +from ..iorw import load_notebook_node from ..parameterize import ( + add_builtin_parameters, parameterize_notebook, parameterize_path, - add_builtin_parameters, ) from . import get_notebook_path -from datetime import datetime class TestNotebookParametrizing(unittest.TestCase): @@ -17,189 +17,189 @@ def count_nb_injected_parameter_cells(self, nb): [ c for c in nb.cells - if "injected-parameters" in c.get("metadata", {}).get("tags", []) + if 'injected-parameters' in c.get('metadata', {}).get('tags', []) ] ) def test_no_tag_copying(self): # Test that injected cell does not copy other tags - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb.cells[0]["metadata"]["tags"].append("some tag") + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb.cells[0]['metadata']['tags'].append('some tag') - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) cell_zero = test_nb.cells[0] - self.assertTrue("some tag" in cell_zero.get("metadata").get("tags")) - self.assertTrue("parameters" in cell_zero.get("metadata").get("tags")) + self.assertTrue('some tag' in cell_zero.get('metadata').get('tags')) + self.assertTrue('parameters' in cell_zero.get('metadata').get('tags')) cell_one = test_nb.cells[1] - self.assertTrue("some tag" not in cell_one.get("metadata").get("tags")) - self.assertTrue("injected-parameters" in cell_one.get("metadata").get("tags")) + self.assertTrue('some tag' not in cell_one.get('metadata').get('tags')) + self.assertTrue('injected-parameters' in cell_one.get('metadata').get('tags')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_injected_parameters_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) cell_zero = test_nb.cells[0] - self.assertTrue("parameters" in cell_zero.get("metadata").get("tags")) + self.assertTrue('parameters' in cell_zero.get('metadata').get('tags')) self.assertTrue( - "injected-parameters" not in cell_zero.get("metadata").get("tags") + 'injected-parameters' not in cell_zero.get('metadata').get('tags') ) cell_one = test_nb.cells[1] - self.assertTrue("injected-parameters" in cell_one.get("metadata").get("tags")) + self.assertTrue('injected-parameters' in cell_one.get('metadata').get('tags')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_repeated_run_injected_parameters_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 0) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) - parameterize_notebook(test_nb, {"msg": "Hello"}) + parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_no_parameter_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb.cells[0]["metadata"]["tags"] = [] + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb.cells[0]['metadata']['tags'] = [] - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) cell_zero = test_nb.cells[0] - self.assertTrue("injected-parameters" in cell_zero.get("metadata").get("tags")) - self.assertTrue("parameters" not in cell_zero.get("metadata").get("tags")) + self.assertTrue('injected-parameters' in cell_zero.get('metadata').get('tags')) + self.assertTrue('parameters' not in cell_zero.get('metadata').get('tags')) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_repeated_run_no_parameters_tag(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) - test_nb.cells[0]["metadata"]["tags"] = [] + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) + test_nb.cells[0]['metadata']['tags'] = [] self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 0) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) - test_nb = parameterize_notebook(test_nb, {"msg": "Hello"}) + test_nb = parameterize_notebook(test_nb, {'msg': 'Hello'}) self.assertEqual(self.count_nb_injected_parameter_cells(test_nb), 1) def test_custom_comment(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) test_nb = parameterize_notebook( - test_nb, {"msg": "Hello"}, comment="This is a custom comment" + test_nb, {'msg': 'Hello'}, comment='This is a custom comment' ) cell_one = test_nb.cells[1] - first_line = cell_one["source"].split("\n")[0] - self.assertEqual(first_line, "# This is a custom comment") + first_line = cell_one['source'].split('\n')[0] + self.assertEqual(first_line, '# This is a custom comment') class TestBuiltinParameters(unittest.TestCase): def test_add_builtin_parameters_keeps_provided_parameters(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertEqual(with_builtin_parameters["foo"], "bar") + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertEqual(with_builtin_parameters['foo'], 'bar') def test_add_builtin_parameters_adds_dict_of_builtins(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("pm", with_builtin_parameters) - self.assertIsInstance(with_builtin_parameters["pm"], type({})) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('pm', with_builtin_parameters) + self.assertIsInstance(with_builtin_parameters['pm'], type({})) def test_add_builtin_parameters_allows_to_override_builtin(self): - with_builtin_parameters = add_builtin_parameters({"pm": "foo"}) - self.assertEqual(with_builtin_parameters["pm"], "foo") + with_builtin_parameters = add_builtin_parameters({'pm': 'foo'}) + self.assertEqual(with_builtin_parameters['pm'], 'foo') def test_builtin_parameters_include_run_uuid(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("run_uuid", with_builtin_parameters["pm"]) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('run_uuid', with_builtin_parameters['pm']) def test_builtin_parameters_include_current_datetime_local(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("current_datetime_local", with_builtin_parameters["pm"]) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('current_datetime_local', with_builtin_parameters['pm']) self.assertIsInstance( - with_builtin_parameters["pm"]["current_datetime_local"], datetime + with_builtin_parameters['pm']['current_datetime_local'], datetime ) def test_builtin_parameters_include_current_datetime_utc(self): - with_builtin_parameters = add_builtin_parameters({"foo": "bar"}) - self.assertIn("current_datetime_utc", with_builtin_parameters["pm"]) + with_builtin_parameters = add_builtin_parameters({'foo': 'bar'}) + self.assertIn('current_datetime_utc', with_builtin_parameters['pm']) self.assertIsInstance( - with_builtin_parameters["pm"]["current_datetime_utc"], datetime + with_builtin_parameters['pm']['current_datetime_utc'], datetime ) class TestPathParameterizing(unittest.TestCase): def test_plain_text_path_with_empty_parameters_object(self): - self.assertEqual(parameterize_path("foo/bar", {}), "foo/bar") + self.assertEqual(parameterize_path('foo/bar', {}), 'foo/bar') def test_plain_text_path_with_none_parameters(self): - self.assertEqual(parameterize_path("foo/bar", None), "foo/bar") + self.assertEqual(parameterize_path('foo/bar', None), 'foo/bar') def test_plain_text_path_with_unused_parameters(self): - self.assertEqual(parameterize_path("foo/bar", {"baz": "quux"}), "foo/bar") + self.assertEqual(parameterize_path('foo/bar', {'baz': 'quux'}), 'foo/bar') def test_path_with_single_parameter(self): self.assertEqual( - parameterize_path("foo/bar/{baz}", {"baz": "quux"}), "foo/bar/quux" + parameterize_path('foo/bar/{baz}', {'baz': 'quux'}), 'foo/bar/quux' ) def test_path_with_boolean_parameter(self): self.assertEqual( - parameterize_path("foo/bar/{baz}", {"baz": False}), "foo/bar/False" + parameterize_path('foo/bar/{baz}', {'baz': False}), 'foo/bar/False' ) def test_path_with_dict_parameter(self): self.assertEqual( - parameterize_path("foo/{bar[baz]}/", {"bar": {"baz": "quux"}}), "foo/quux/" + parameterize_path('foo/{bar[baz]}/', {'bar': {'baz': 'quux'}}), 'foo/quux/' ) def test_path_with_list_parameter(self): self.assertEqual( - parameterize_path("foo/{bar[0]}/", {"bar": [1, 2, 3]}), "foo/1/" + parameterize_path('foo/{bar[0]}/', {'bar': [1, 2, 3]}), 'foo/1/' ) self.assertEqual( - parameterize_path("foo/{bar[2]}/", {"bar": [1, 2, 3]}), "foo/3/" + parameterize_path('foo/{bar[2]}/', {'bar': [1, 2, 3]}), 'foo/3/' ) def test_path_with_none_parameter(self): self.assertEqual( - parameterize_path("foo/bar/{baz}", {"baz": None}), "foo/bar/None" + parameterize_path('foo/bar/{baz}', {'baz': None}), 'foo/bar/None' ) def test_path_with_numeric_parameter(self): - self.assertEqual(parameterize_path("foo/bar/{baz}", {"baz": 42}), "foo/bar/42") + self.assertEqual(parameterize_path('foo/bar/{baz}', {'baz': 42}), 'foo/bar/42') def test_path_with_numeric_format_string(self): self.assertEqual( - parameterize_path("foo/bar/{baz:03d}", {"baz": 42}), "foo/bar/042" + parameterize_path('foo/bar/{baz:03d}', {'baz': 42}), 'foo/bar/042' ) def test_path_with_float_format_string(self): self.assertEqual( - parameterize_path("foo/bar/{baz:.03f}", {"baz": 0.3}), "foo/bar/0.300" + parameterize_path('foo/bar/{baz:.03f}', {'baz': 0.3}), 'foo/bar/0.300' ) def test_path_with_multiple_parameter(self): self.assertEqual( - parameterize_path("{foo}/{baz}", {"foo": "bar", "baz": "quux"}), "bar/quux" + parameterize_path('{foo}/{baz}', {'foo': 'bar', 'baz': 'quux'}), 'bar/quux' ) def test_parameterized_path_with_undefined_parameter(self): with self.assertRaises(PapermillMissingParameterException) as context: - parameterize_path("{foo}", {}) + parameterize_path('{foo}', {}) self.assertEqual(str(context.exception), "Missing parameter 'foo'") def test_parameterized_path_with_none_parameters(self): with self.assertRaises(PapermillMissingParameterException) as context: - parameterize_path("{foo}", None) + parameterize_path('{foo}', None) self.assertEqual(str(context.exception), "Missing parameter 'foo'") def test_path_of_none_returns_none(self): - self.assertIsNone(parameterize_path(path=None, parameters={"foo": "bar"})) + self.assertIsNone(parameterize_path(path=None, parameters={'foo': 'bar'})) self.assertIsNone(parameterize_path(path=None, parameters=None)) def test_path_of_notebook_node_returns_input(self): - test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb")) + test_nb = load_notebook_node(get_notebook_path('simple_execute.ipynb')) result_nb = parameterize_path(test_nb, parameters=None) self.assertIs(result_nb, test_nb) diff --git a/papermill/tests/test_s3.py b/papermill/tests/test_s3.py index 156b4a7a..de86f5b6 100644 --- a/papermill/tests/test_s3.py +++ b/papermill/tests/test_s3.py @@ -1,52 +1,52 @@ # The following tests are purposely limited to the exposed interface by iorw.py import os.path -import pytest + import boto3 import moto - +import pytest from moto import mock_s3 -from ..s3 import Bucket, Prefix, Key, S3 +from ..s3 import S3, Bucket, Key, Prefix @pytest.fixture def bucket_no_service(): """Returns a bucket instance with no services""" - return Bucket("my_test_bucket") + return Bucket('my_test_bucket') @pytest.fixture def bucket_with_service(): """Returns a bucket instance with a service""" - return Bucket("my_sqs_bucket", ["sqs"]) + return Bucket('my_sqs_bucket', ['sqs']) @pytest.fixture def bucket_sqs(): """Returns a bucket instance with a sqs service""" - return Bucket("my_sqs_bucket", ["sqs"]) + return Bucket('my_sqs_bucket', ['sqs']) @pytest.fixture def bucket_ec2(): """Returns a bucket instance with a ec2 service""" - return Bucket("my_sqs_bucket", ["ec2"]) + return Bucket('my_sqs_bucket', ['ec2']) @pytest.fixture def bucket_multiservice(): """Returns a bucket instance with a ec2 service""" - return Bucket("my_sqs_bucket", ["ec2", "sqs"]) + return Bucket('my_sqs_bucket', ['ec2', 'sqs']) def test_bucket_init(): - assert Bucket("my_test_bucket") - assert Bucket("my_sqs_bucket", "sqs") + assert Bucket('my_test_bucket') + assert Bucket('my_sqs_bucket', 'sqs') def test_bucket_defaults(): - name = "a bucket" + name = 'a bucket' b1 = Bucket(name) b2 = Bucket(name, None) @@ -86,19 +86,19 @@ def test_prefix_init(): Prefix(service=None) with pytest.raises(TypeError): - Prefix("my_test_prefix") + Prefix('my_test_prefix') - b1 = Bucket("my_test_bucket") - p1 = Prefix(b1, "sqs_test", service="sqs") - assert Prefix(b1, "test_bucket") - assert Prefix(b1, "test_bucket", service=None) - assert Prefix(b1, "test_bucket", None) + b1 = Bucket('my_test_bucket') + p1 = Prefix(b1, 'sqs_test', service='sqs') + assert Prefix(b1, 'test_bucket') + assert Prefix(b1, 'test_bucket', service=None) + assert Prefix(b1, 'test_bucket', None) assert p1.bucket.service == p1.service def test_prefix_defaults(): - bucket = Bucket("my data pool") - name = "bigdata bucket" + bucket = Bucket('my data pool') + name = 'bigdata bucket' p1 = Prefix(bucket, name) p2 = Prefix(bucket, name, None) @@ -107,13 +107,13 @@ def test_prefix_defaults(): def test_prefix_str(bucket_sqs): - p1 = Prefix(bucket_sqs, "sqs_prefix_test", "sqs") - assert str(p1) == "s3://" + str(bucket_sqs) + "/sqs_prefix_test" + p1 = Prefix(bucket_sqs, 'sqs_prefix_test', 'sqs') + assert str(p1) == 's3://' + str(bucket_sqs) + '/sqs_prefix_test' def test_prefix_repr(bucket_sqs): - p1 = Prefix(bucket_sqs, "sqs_prefix_test", "sqs") - assert repr(p1) == "s3://" + str(bucket_sqs) + "/sqs_prefix_test" + p1 = Prefix(bucket_sqs, 'sqs_prefix_test', 'sqs') + assert repr(p1) == 's3://' + str(bucket_sqs) + '/sqs_prefix_test' def test_key_init(): @@ -121,13 +121,13 @@ def test_key_init(): def test_key_repr(): - k = Key("foo", "bar") - assert repr(k) == "s3://foo/bar" + k = Key('foo', 'bar') + assert repr(k) == 's3://foo/bar' def test_key_defaults(): - bucket = Bucket("my data pool") - name = "bigdata bucket" + bucket = Bucket('my data pool') + name = 'bigdata bucket' k1 = Key(bucket, name) k2 = Key(bucket, name, None, None, None, None, None) @@ -148,36 +148,36 @@ def test_s3_defaults(): local_dir = os.path.dirname(os.path.abspath(__file__)) -test_bucket_name = "test-pm-bucket" -test_string = "Hello" -test_file_path = "notebooks/s3/s3_in/s3-simple_notebook.ipynb" -test_empty_file_path = "notebooks/s3/s3_in/s3-empty.ipynb" +test_bucket_name = 'test-pm-bucket' +test_string = 'Hello' +test_file_path = 'notebooks/s3/s3_in/s3-simple_notebook.ipynb' +test_empty_file_path = 'notebooks/s3/s3_in/s3-empty.ipynb' with open(os.path.join(local_dir, test_file_path)) as f: test_nb_content = f.read() -no_empty_lines = lambda s: "\n".join([l for l in s.split("\n") if len(l) > 0]) +no_empty_lines = lambda s: '\n'.join([l for l in s.split('\n') if len(l) > 0]) test_clean_nb_content = no_empty_lines(test_nb_content) -read_from_gen = lambda g: "\n".join(g) +read_from_gen = lambda g: '\n'.join(g) -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def s3_client(): mock_s3 = moto.mock_s3() mock_s3.start() - client = boto3.client("s3") + client = boto3.client('s3') client.create_bucket( Bucket=test_bucket_name, - CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}, ) client.put_object(Bucket=test_bucket_name, Key=test_file_path, Body=test_nb_content) - client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body="") + client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body='') yield S3() try: client.delete_object(Bucket=test_bucket_name, Key=test_file_path) - client.delete_object(Bucket=test_bucket_name, Key=test_file_path + ".txt") + client.delete_object(Bucket=test_bucket_name, Key=test_file_path + '.txt') client.delete_object(Bucket=test_bucket_name, Key=test_empty_file_path) except Exception: pass @@ -185,19 +185,19 @@ def s3_client(): def test_s3_read(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_file_path}" + s3_path = f's3://{test_bucket_name}/{test_file_path}' data = read_from_gen(s3_client.read(s3_path)) assert data == test_clean_nb_content def test_s3_read_empty(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_empty_file_path}" + s3_path = f's3://{test_bucket_name}/{test_empty_file_path}' data = read_from_gen(s3_client.read(s3_path)) - assert data == "" + assert data == '' def test_s3_write(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_file_path}.txt" + s3_path = f's3://{test_bucket_name}/{test_file_path}.txt' s3_client.cp_string(test_string, s3_path) data = read_from_gen(s3_client.read(s3_path)) @@ -205,7 +205,7 @@ def test_s3_write(s3_client): def test_s3_overwrite(s3_client): - s3_path = f"s3://{test_bucket_name}/{test_file_path}" + s3_path = f's3://{test_bucket_name}/{test_file_path}' s3_client.cp_string(test_string, s3_path) data = read_from_gen(s3_client.read(s3_path)) @@ -214,8 +214,8 @@ def test_s3_overwrite(s3_client): def test_s3_listdir(s3_client): dir_name = os.path.dirname(test_file_path) - s3_dir = f"s3://{test_bucket_name}/{dir_name}" - s3_path = f"s3://{test_bucket_name}/{test_file_path}" + s3_dir = f's3://{test_bucket_name}/{dir_name}' + s3_path = f's3://{test_bucket_name}/{test_file_path}' dir_listings = s3_client.listdir(s3_dir) assert len(dir_listings) == 2 assert s3_path in dir_listings diff --git a/papermill/tests/test_translators.py b/papermill/tests/test_translators.py index 906784f6..c3fb4b62 100644 --- a/papermill/tests/test_translators.py +++ b/papermill/tests/test_translators.py @@ -1,8 +1,7 @@ -import pytest - -from unittest.mock import Mock from collections import OrderedDict +from unittest.mock import Mock +import pytest from nbformat.v4 import new_code_cell from .. import translators @@ -11,29 +10,29 @@ @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, '{"foo": "bar"}'), - ({"foo": '"bar"'}, '{"foo": "\\"bar\\""}'), - ({"foo": ["bar"]}, '{"foo": ["bar"]}'), - ({"foo": {"bar": "baz"}}, '{"foo": {"bar": "baz"}}'), - ({"foo": {"bar": '"baz"'}}, '{"foo": {"bar": "\\"baz\\""}}'), - (["foo"], '["foo"]'), - (["foo", '"bar"'], '["foo", "\\"bar\\""]'), - ([{"foo": "bar"}], '[{"foo": "bar"}]'), - ([{"foo": '"bar"'}], '[{"foo": "\\"bar\\""}]'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (float("nan"), "float('nan')"), - (float("-inf"), "float('-inf')"), - (float("inf"), "float('inf')"), - (True, "True"), - (False, "False"), - (None, "None"), + ({'foo': 'bar'}, '{"foo": "bar"}'), + ({'foo': '"bar"'}, '{"foo": "\\"bar\\""}'), + ({'foo': ['bar']}, '{"foo": ["bar"]}'), + ({'foo': {'bar': 'baz'}}, '{"foo": {"bar": "baz"}}'), + ({'foo': {'bar': '"baz"'}}, '{"foo": {"bar": "\\"baz\\""}}'), + (['foo'], '["foo"]'), + (['foo', '"bar"'], '["foo", "\\"bar\\""]'), + ([{'foo': 'bar'}], '[{"foo": "bar"}]'), + ([{'foo': '"bar"'}], '[{"foo": "\\"bar\\""}]'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (float('nan'), "float('nan')"), + (float('-inf'), "float('-inf')"), + (float('inf'), "float('inf')"), + (True, 'True'), + (False, 'False'), + (None, 'None'), ], ) def test_translate_type_python(test_input, expected): @@ -41,16 +40,16 @@ def test_translate_type_python(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'), - ({"foo": True}, "# Parameters\nfoo = True\n"), - ({"foo": 5}, "# Parameters\nfoo = 5\n"), - ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = ["bar", "baz"]\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = {"bar": "baz"}\n'), + ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'), + ({'foo': True}, '# Parameters\nfoo = True\n'), + ({'foo': 5}, '# Parameters\nfoo = 5\n'), + ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = ["bar", "baz"]\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = {"bar": "baz"}\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\nfoo = "bar"\nbaz = ["buz"]\n', ), ], @@ -60,39 +59,39 @@ def test_translate_codify_python(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_python(test_input, expected): assert translators.PythonTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("a = 2", [Parameter("a", "None", "2", "")]), - ("a: int = 2", [Parameter("a", "int", "2", "")]), - ("a = 2 # type:int", [Parameter("a", "int", "2", "")]), + ('a = 2', [Parameter('a', 'None', '2', '')]), + ('a: int = 2', [Parameter('a', 'int', '2', '')]), + ('a = 2 # type:int', [Parameter('a', 'int', '2', '')]), ( - "a = False # Nice variable a", - [Parameter("a", "None", "False", "Nice variable a")], + 'a = False # Nice variable a', + [Parameter('a', 'None', 'False', 'Nice variable a')], ), ( - "a: float = 2.258 # type: int Nice variable a", - [Parameter("a", "float", "2.258", "Nice variable a")], + 'a: float = 2.258 # type: int Nice variable a', + [Parameter('a', 'float', '2.258', 'Nice variable a')], ), ( "a = 'this is a string' # type: int Nice variable a", - [Parameter("a", "int", "'this is a string'", "Nice variable a")], + [Parameter('a', 'int', "'this is a string'", 'Nice variable a')], ), ( "a: List[str] = ['this', 'is', 'a', 'string', 'list'] # Nice variable a", [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this', 'is', 'a', 'string', 'list']", - "Nice variable a", + 'Nice variable a', ) ], ), @@ -100,10 +99,10 @@ def test_translate_comment_python(test_input, expected): "a: List[str] = [\n 'this', # First\n 'is',\n 'a',\n 'string',\n 'list' # Last\n] # Nice variable a", # noqa [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this','is','a','string','list']", - "Nice variable a", + 'Nice variable a', ) ], ), @@ -111,10 +110,10 @@ def test_translate_comment_python(test_input, expected): "a: List[str] = [\n 'this',\n 'is',\n 'a',\n 'string',\n 'list'\n] # Nice variable a", # noqa [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this','is','a','string','list']", - "Nice variable a", + 'Nice variable a', ) ], ), @@ -132,12 +131,12 @@ def test_translate_comment_python(test_input, expected): """, [ Parameter( - "a", - "List[str]", + 'a', + 'List[str]', "['this','is','a','string','list']", - "Nice variable a", + 'Nice variable a', ), - Parameter("b", "float", "-2.3432", "My b variable"), + Parameter('b', 'float', '-2.3432', 'My b variable'), ], ), ], @@ -148,26 +147,26 @@ def test_inspect_python(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'list("foo" = "bar")'), - ({"foo": '"bar"'}, 'list("foo" = "\\"bar\\"")'), - ({"foo": ["bar"]}, 'list("foo" = list("bar"))'), - ({"foo": {"bar": "baz"}}, 'list("foo" = list("bar" = "baz"))'), - ({"foo": {"bar": '"baz"'}}, 'list("foo" = list("bar" = "\\"baz\\""))'), - (["foo"], 'list("foo")'), - (["foo", '"bar"'], 'list("foo", "\\"bar\\"")'), - ([{"foo": "bar"}], 'list(list("foo" = "bar"))'), - ([{"foo": '"bar"'}], 'list(list("foo" = "\\"bar\\""))'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "TRUE"), - (False, "FALSE"), - (None, "NULL"), + ({'foo': 'bar'}, 'list("foo" = "bar")'), + ({'foo': '"bar"'}, 'list("foo" = "\\"bar\\"")'), + ({'foo': ['bar']}, 'list("foo" = list("bar"))'), + ({'foo': {'bar': 'baz'}}, 'list("foo" = list("bar" = "baz"))'), + ({'foo': {'bar': '"baz"'}}, 'list("foo" = list("bar" = "\\"baz\\""))'), + (['foo'], 'list("foo")'), + (['foo', '"bar"'], 'list("foo", "\\"bar\\"")'), + ([{'foo': 'bar'}], 'list(list("foo" = "bar"))'), + ([{'foo': '"bar"'}], 'list(list("foo" = "\\"bar\\""))'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'TRUE'), + (False, 'FALSE'), + (None, 'NULL'), ], ) def test_translate_type_r(test_input, expected): @@ -175,28 +174,28 @@ def test_translate_type_r(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_r(test_input, expected): assert translators.RTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'), - ({"foo": True}, "# Parameters\nfoo = TRUE\n"), - ({"foo": 5}, "# Parameters\nfoo = 5\n"), - ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = list("bar", "baz")\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = list("bar" = "baz")\n'), + ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'), + ({'foo': True}, '# Parameters\nfoo = TRUE\n'), + ({'foo': 5}, '# Parameters\nfoo = 5\n'), + ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = list("bar", "baz")\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = list("bar" = "baz")\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\nfoo = "bar"\nbaz = list("buz")\n', ), # Underscores remove - ({"___foo": 5}, "# Parameters\nfoo = 5\n"), + ({'___foo': 5}, '# Parameters\nfoo = 5\n'), ], ) def test_translate_codify_r(parameters, expected): @@ -204,28 +203,28 @@ def test_translate_codify_r(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'Map("foo" -> "bar")'), - ({"foo": '"bar"'}, 'Map("foo" -> "\\"bar\\"")'), - ({"foo": ["bar"]}, 'Map("foo" -> Seq("bar"))'), - ({"foo": {"bar": "baz"}}, 'Map("foo" -> Map("bar" -> "baz"))'), - ({"foo": {"bar": '"baz"'}}, 'Map("foo" -> Map("bar" -> "\\"baz\\""))'), - (["foo"], 'Seq("foo")'), - (["foo", '"bar"'], 'Seq("foo", "\\"bar\\"")'), - ([{"foo": "bar"}], 'Seq(Map("foo" -> "bar"))'), - ([{"foo": '"bar"'}], 'Seq(Map("foo" -> "\\"bar\\""))'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (2147483648, "2147483648L"), - (-2147483649, "-2147483649L"), - (True, "true"), - (False, "false"), - (None, "None"), + ({'foo': 'bar'}, 'Map("foo" -> "bar")'), + ({'foo': '"bar"'}, 'Map("foo" -> "\\"bar\\"")'), + ({'foo': ['bar']}, 'Map("foo" -> Seq("bar"))'), + ({'foo': {'bar': 'baz'}}, 'Map("foo" -> Map("bar" -> "baz"))'), + ({'foo': {'bar': '"baz"'}}, 'Map("foo" -> Map("bar" -> "\\"baz\\""))'), + (['foo'], 'Seq("foo")'), + (['foo', '"bar"'], 'Seq("foo", "\\"bar\\"")'), + ([{'foo': 'bar'}], 'Seq(Map("foo" -> "bar"))'), + ([{'foo': '"bar"'}], 'Seq(Map("foo" -> "\\"bar\\""))'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (2147483648, '2147483648L'), + (-2147483649, '-2147483649L'), + (True, 'true'), + (False, 'false'), + (None, 'None'), ], ) def test_translate_type_scala(test_input, expected): @@ -233,19 +232,19 @@ def test_translate_type_scala(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "//"), ("foo", "// foo"), ("['best effort']", "// ['best effort']")], + 'test_input,expected', + [('', '//'), ('foo', '// foo'), ("['best effort']", "// ['best effort']")], ) def test_translate_comment_scala(test_input, expected): assert translators.ScalaTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "input_name,input_value,expected", + 'input_name,input_value,expected', [ - ("foo", '""', 'val foo = ""'), - ("foo", '"bar"', 'val foo = "bar"'), - ("foo", 'Map("foo" -> "bar")', 'val foo = Map("foo" -> "bar")'), + ('foo', '""', 'val foo = ""'), + ('foo', '"bar"', 'val foo = "bar"'), + ('foo', 'Map("foo" -> "bar")', 'val foo = Map("foo" -> "bar")'), ], ) def test_translate_assign_scala(input_name, input_value, expected): @@ -253,16 +252,16 @@ def test_translate_assign_scala(input_name, input_value, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '// Parameters\nval foo = "bar"\n'), - ({"foo": True}, "// Parameters\nval foo = true\n"), - ({"foo": 5}, "// Parameters\nval foo = 5\n"), - ({"foo": 1.1}, "// Parameters\nval foo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '// Parameters\nval foo = Seq("bar", "baz")\n'), - ({"foo": {"bar": "baz"}}, '// Parameters\nval foo = Map("bar" -> "baz")\n'), + ({'foo': 'bar'}, '// Parameters\nval foo = "bar"\n'), + ({'foo': True}, '// Parameters\nval foo = true\n'), + ({'foo': 5}, '// Parameters\nval foo = 5\n'), + ({'foo': 1.1}, '// Parameters\nval foo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '// Parameters\nval foo = Seq("bar", "baz")\n'), + ({'foo': {'bar': 'baz'}}, '// Parameters\nval foo = Map("bar" -> "baz")\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '// Parameters\nval foo = "bar"\nval baz = Seq("buz")\n', ), ], @@ -273,26 +272,26 @@ def test_translate_codify_scala(parameters, expected): # C# section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'new Dictionary{ { "foo" , "bar" } }'), - ({"foo": '"bar"'}, 'new Dictionary{ { "foo" , "\\"bar\\"" } }'), - (["foo"], 'new [] { "foo" }'), - (["foo", '"bar"'], 'new [] { "foo", "\\"bar\\"" }'), + ({'foo': 'bar'}, 'new Dictionary{ { "foo" , "bar" } }'), + ({'foo': '"bar"'}, 'new Dictionary{ { "foo" , "\\"bar\\"" } }'), + (['foo'], 'new [] { "foo" }'), + (['foo', '"bar"'], 'new [] { "foo", "\\"bar\\"" }'), ( - [{"foo": "bar"}], + [{'foo': 'bar'}], 'new [] { new Dictionary{ { "foo" , "bar" } } }', ), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (2147483648, "2147483648L"), - (-2147483649, "-2147483649L"), - (True, "true"), - (False, "false"), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (2147483648, '2147483648L'), + (-2147483649, '-2147483649L'), + (True, 'true'), + (False, 'false'), ], ) def test_translate_type_csharp(test_input, expected): @@ -300,34 +299,34 @@ def test_translate_type_csharp(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "//"), ("foo", "// foo"), ("['best effort']", "// ['best effort']")], + 'test_input,expected', + [('', '//'), ('foo', '// foo'), ("['best effort']", "// ['best effort']")], ) def test_translate_comment_csharp(test_input, expected): assert translators.CSharpTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "input_name,input_value,expected", - [("foo", '""', 'var foo = "";'), ("foo", '"bar"', 'var foo = "bar";')], + 'input_name,input_value,expected', + [('foo', '""', 'var foo = "";'), ('foo', '"bar"', 'var foo = "bar";')], ) def test_translate_assign_csharp(input_name, input_value, expected): assert translators.CSharpTranslator.assign(input_name, input_value) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '// Parameters\nvar foo = "bar";\n'), - ({"foo": True}, "// Parameters\nvar foo = true;\n"), - ({"foo": 5}, "// Parameters\nvar foo = 5;\n"), - ({"foo": 1.1}, "// Parameters\nvar foo = 1.1;\n"), + ({'foo': 'bar'}, '// Parameters\nvar foo = "bar";\n'), + ({'foo': True}, '// Parameters\nvar foo = true;\n'), + ({'foo': 5}, '// Parameters\nvar foo = 5;\n'), + ({'foo': 1.1}, '// Parameters\nvar foo = 1.1;\n'), ( - {"foo": ["bar", "baz"]}, + {'foo': ['bar', 'baz']}, '// Parameters\nvar foo = new [] { "bar", "baz" };\n', ), ( - {"foo": {"bar": "baz"}}, + {'foo': {'bar': 'baz'}}, '// Parameters\nvar foo = new Dictionary{ { "bar" , "baz" } };\n', ), ], @@ -338,29 +337,29 @@ def test_translate_codify_csharp(parameters, expected): # Powershell section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{`"foo`": `"bar`"}"'), - ({"foo": "bar"}, '@{"foo" = "bar"}'), - ({"foo": '"bar"'}, '@{"foo" = "`"bar`""}'), - ({"foo": ["bar"]}, '@{"foo" = @("bar")}'), - ({"foo": {"bar": "baz"}}, '@{"foo" = @{"bar" = "baz"}}'), - ({"foo": {"bar": '"baz"'}}, '@{"foo" = @{"bar" = "`"baz`""}}'), - (["foo"], '@("foo")'), - (["foo", '"bar"'], '@("foo", "`"bar`"")'), - ([{"foo": "bar"}], '@(@{"foo" = "bar"})'), - ([{"foo": '"bar"'}], '@(@{"foo" = "`"bar`""})'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (float("nan"), "[double]::NaN"), - (float("-inf"), "[double]::NegativeInfinity"), - (float("inf"), "[double]::PositiveInfinity"), - (True, "$True"), - (False, "$False"), - (None, "$Null"), + ({'foo': 'bar'}, '@{"foo" = "bar"}'), + ({'foo': '"bar"'}, '@{"foo" = "`"bar`""}'), + ({'foo': ['bar']}, '@{"foo" = @("bar")}'), + ({'foo': {'bar': 'baz'}}, '@{"foo" = @{"bar" = "baz"}}'), + ({'foo': {'bar': '"baz"'}}, '@{"foo" = @{"bar" = "`"baz`""}}'), + (['foo'], '@("foo")'), + (['foo', '"bar"'], '@("foo", "`"bar`"")'), + ([{'foo': 'bar'}], '@(@{"foo" = "bar"})'), + ([{'foo': '"bar"'}], '@(@{"foo" = "`"bar`""})'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (float('nan'), '[double]::NaN'), + (float('-inf'), '[double]::NegativeInfinity'), + (float('inf'), '[double]::PositiveInfinity'), + (True, '$True'), + (False, '$False'), + (None, '$Null'), ], ) def test_translate_type_powershell(test_input, expected): @@ -368,16 +367,16 @@ def test_translate_type_powershell(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\n$foo = "bar"\n'), - ({"foo": True}, "# Parameters\n$foo = $True\n"), - ({"foo": 5}, "# Parameters\n$foo = 5\n"), - ({"foo": 1.1}, "# Parameters\n$foo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\n$foo = @("bar", "baz")\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\n$foo = @{"bar" = "baz"}\n'), + ({'foo': 'bar'}, '# Parameters\n$foo = "bar"\n'), + ({'foo': True}, '# Parameters\n$foo = $True\n'), + ({'foo': 5}, '# Parameters\n$foo = 5\n'), + ({'foo': 1.1}, '# Parameters\n$foo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\n$foo = @("bar", "baz")\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\n$foo = @{"bar" = "baz"}\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\n$foo = "bar"\n$baz = @("buz")\n', ), ], @@ -387,16 +386,16 @@ def test_translate_codify_powershell(parameters, expected): @pytest.mark.parametrize( - "input_name,input_value,expected", - [("foo", '""', '$foo = ""'), ("foo", '"bar"', '$foo = "bar"')], + 'input_name,input_value,expected', + [('foo', '""', '$foo = ""'), ('foo', '"bar"', '$foo = "bar"')], ) def test_translate_assign_powershell(input_name, input_value, expected): assert translators.PowershellTranslator.assign(input_name, input_value) == expected @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_powershell(test_input, expected): assert translators.PowershellTranslator.comment(test_input) == expected @@ -404,23 +403,23 @@ def test_translate_comment_powershell(test_input, expected): # F# section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, '[ ("foo", "bar" :> IComparable) ] |> Map.ofList'), - ({"foo": '"bar"'}, '[ ("foo", "\\"bar\\"" :> IComparable) ] |> Map.ofList'), - (["foo"], '[ "foo" ]'), - (["foo", '"bar"'], '[ "foo"; "\\"bar\\"" ]'), - ([{"foo": "bar"}], '[ [ ("foo", "bar" :> IComparable) ] |> Map.ofList ]'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (2147483648, "2147483648L"), - (-2147483649, "-2147483649L"), - (True, "true"), - (False, "false"), + ({'foo': 'bar'}, '[ ("foo", "bar" :> IComparable) ] |> Map.ofList'), + ({'foo': '"bar"'}, '[ ("foo", "\\"bar\\"" :> IComparable) ] |> Map.ofList'), + (['foo'], '[ "foo" ]'), + (['foo', '"bar"'], '[ "foo"; "\\"bar\\"" ]'), + ([{'foo': 'bar'}], '[ [ ("foo", "bar" :> IComparable) ] |> Map.ofList ]'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (2147483648, '2147483648L'), + (-2147483649, '-2147483649L'), + (True, 'true'), + (False, 'false'), ], ) def test_translate_type_fsharp(test_input, expected): @@ -428,10 +427,10 @@ def test_translate_type_fsharp(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("", "(* *)"), - ("foo", "(* foo *)"), + ('', '(* *)'), + ('foo', '(* foo *)'), ("['best effort']", "(* ['best effort'] *)"), ], ) @@ -440,23 +439,23 @@ def test_translate_comment_fsharp(test_input, expected): @pytest.mark.parametrize( - "input_name,input_value,expected", - [("foo", '""', 'let foo = ""'), ("foo", '"bar"', 'let foo = "bar"')], + 'input_name,input_value,expected', + [('foo', '""', 'let foo = ""'), ('foo', '"bar"', 'let foo = "bar"')], ) def test_translate_assign_fsharp(input_name, input_value, expected): assert translators.FSharpTranslator.assign(input_name, input_value) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '(* Parameters *)\nlet foo = "bar"\n'), - ({"foo": True}, "(* Parameters *)\nlet foo = true\n"), - ({"foo": 5}, "(* Parameters *)\nlet foo = 5\n"), - ({"foo": 1.1}, "(* Parameters *)\nlet foo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '(* Parameters *)\nlet foo = [ "bar"; "baz" ]\n'), + ({'foo': 'bar'}, '(* Parameters *)\nlet foo = "bar"\n'), + ({'foo': True}, '(* Parameters *)\nlet foo = true\n'), + ({'foo': 5}, '(* Parameters *)\nlet foo = 5\n'), + ({'foo': 1.1}, '(* Parameters *)\nlet foo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '(* Parameters *)\nlet foo = [ "bar"; "baz" ]\n'), ( - {"foo": {"bar": "baz"}}, + {'foo': {'bar': 'baz'}}, '(* Parameters *)\nlet foo = [ ("bar", "baz" :> IComparable) ] |> Map.ofList\n', ), ], @@ -466,26 +465,26 @@ def test_translate_codify_fsharp(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{\\"foo\\": \\"bar\\"}"'), - ({"foo": "bar"}, 'Dict("foo" => "bar")'), - ({"foo": '"bar"'}, 'Dict("foo" => "\\"bar\\"")'), - ({"foo": ["bar"]}, 'Dict("foo" => ["bar"])'), - ({"foo": {"bar": "baz"}}, 'Dict("foo" => Dict("bar" => "baz"))'), - ({"foo": {"bar": '"baz"'}}, 'Dict("foo" => Dict("bar" => "\\"baz\\""))'), - (["foo"], '["foo"]'), - (["foo", '"bar"'], '["foo", "\\"bar\\""]'), - ([{"foo": "bar"}], '[Dict("foo" => "bar")]'), - ([{"foo": '"bar"'}], '[Dict("foo" => "\\"bar\\"")]'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "true"), - (False, "false"), - (None, "nothing"), + ({'foo': 'bar'}, 'Dict("foo" => "bar")'), + ({'foo': '"bar"'}, 'Dict("foo" => "\\"bar\\"")'), + ({'foo': ['bar']}, 'Dict("foo" => ["bar"])'), + ({'foo': {'bar': 'baz'}}, 'Dict("foo" => Dict("bar" => "baz"))'), + ({'foo': {'bar': '"baz"'}}, 'Dict("foo" => Dict("bar" => "\\"baz\\""))'), + (['foo'], '["foo"]'), + (['foo', '"bar"'], '["foo", "\\"bar\\""]'), + ([{'foo': 'bar'}], '[Dict("foo" => "bar")]'), + ([{'foo': '"bar"'}], '[Dict("foo" => "\\"bar\\"")]'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'true'), + (False, 'false'), + (None, 'nothing'), ], ) def test_translate_type_julia(test_input, expected): @@ -493,16 +492,16 @@ def test_translate_type_julia(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '# Parameters\nfoo = "bar"\n'), - ({"foo": True}, "# Parameters\nfoo = true\n"), - ({"foo": 5}, "# Parameters\nfoo = 5\n"), - ({"foo": 1.1}, "# Parameters\nfoo = 1.1\n"), - ({"foo": ["bar", "baz"]}, '# Parameters\nfoo = ["bar", "baz"]\n'), - ({"foo": {"bar": "baz"}}, '# Parameters\nfoo = Dict("bar" => "baz")\n'), + ({'foo': 'bar'}, '# Parameters\nfoo = "bar"\n'), + ({'foo': True}, '# Parameters\nfoo = true\n'), + ({'foo': 5}, '# Parameters\nfoo = 5\n'), + ({'foo': 1.1}, '# Parameters\nfoo = 1.1\n'), + ({'foo': ['bar', 'baz']}, '# Parameters\nfoo = ["bar", "baz"]\n'), + ({'foo': {'bar': 'baz'}}, '# Parameters\nfoo = Dict("bar" => "baz")\n'), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '# Parameters\nfoo = "bar"\nbaz = ["buz"]\n', ), ], @@ -512,44 +511,44 @@ def test_translate_codify_julia(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ('["best effort"]', '# ["best effort"]')], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ('["best effort"]', '# ["best effort"]')], ) def test_translate_comment_julia(test_input, expected): assert translators.JuliaTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", '"foo"'), + ('foo', '"foo"'), ('{"foo": "bar"}', '"{""foo"": ""bar""}"'), - ({1: "foo"}, "containers.Map({'1'}, {\"foo\"})"), - ({1.0: "foo"}, "containers.Map({'1.0'}, {\"foo\"})"), - ({None: "foo"}, "containers.Map({'None'}, {\"foo\"})"), - ({True: "foo"}, "containers.Map({'True'}, {\"foo\"})"), - ({"foo": "bar"}, "containers.Map({'foo'}, {\"bar\"})"), - ({"foo": '"bar"'}, 'containers.Map({\'foo\'}, {"""bar"""})'), - ({"foo": ["bar"]}, "containers.Map({'foo'}, {{\"bar\"}})"), + ({1: 'foo'}, 'containers.Map({\'1\'}, {"foo"})'), + ({1.0: 'foo'}, 'containers.Map({\'1.0\'}, {"foo"})'), + ({None: 'foo'}, 'containers.Map({\'None\'}, {"foo"})'), + ({True: 'foo'}, 'containers.Map({\'True\'}, {"foo"})'), + ({'foo': 'bar'}, 'containers.Map({\'foo\'}, {"bar"})'), + ({'foo': '"bar"'}, 'containers.Map({\'foo\'}, {"""bar"""})'), + ({'foo': ['bar']}, 'containers.Map({\'foo\'}, {{"bar"}})'), ( - {"foo": {"bar": "baz"}}, + {'foo': {'bar': 'baz'}}, "containers.Map({'foo'}, {containers.Map({'bar'}, {\"baz\"})})", ), ( - {"foo": {"bar": '"baz"'}}, + {'foo': {'bar': '"baz"'}}, 'containers.Map({\'foo\'}, {containers.Map({\'bar\'}, {"""baz"""})})', ), - (["foo"], '{"foo"}'), - (["foo", '"bar"'], '{"foo", """bar"""}'), - ([{"foo": "bar"}], "{containers.Map({'foo'}, {\"bar\"})}"), - ([{"foo": '"bar"'}], '{containers.Map({\'foo\'}, {"""bar"""})}'), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "true"), - (False, "false"), - (None, "NaN"), + (['foo'], '{"foo"}'), + (['foo', '"bar"'], '{"foo", """bar"""}'), + ([{'foo': 'bar'}], '{containers.Map({\'foo\'}, {"bar"})}'), + ([{'foo': '"bar"'}], '{containers.Map({\'foo\'}, {"""bar"""})}'), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'true'), + (False, 'false'), + (None, 'NaN'), ], ) def test_translate_type_matlab(test_input, expected): @@ -557,19 +556,19 @@ def test_translate_type_matlab(test_input, expected): @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, '% Parameters\nfoo = "bar";\n'), - ({"foo": True}, "% Parameters\nfoo = true;\n"), - ({"foo": 5}, "% Parameters\nfoo = 5;\n"), - ({"foo": 1.1}, "% Parameters\nfoo = 1.1;\n"), - ({"foo": ["bar", "baz"]}, '% Parameters\nfoo = {"bar", "baz"};\n'), + ({'foo': 'bar'}, '% Parameters\nfoo = "bar";\n'), + ({'foo': True}, '% Parameters\nfoo = true;\n'), + ({'foo': 5}, '% Parameters\nfoo = 5;\n'), + ({'foo': 1.1}, '% Parameters\nfoo = 1.1;\n'), + ({'foo': ['bar', 'baz']}, '% Parameters\nfoo = {"bar", "baz"};\n'), ( - {"foo": {"bar": "baz"}}, - "% Parameters\nfoo = containers.Map({'bar'}, {\"baz\"});\n", + {'foo': {'bar': 'baz'}}, + '% Parameters\nfoo = containers.Map({\'bar\'}, {"baz"});\n', ), ( - OrderedDict([["foo", "bar"], ["baz", ["buz"]]]), + OrderedDict([['foo', 'bar'], ['baz', ['buz']]]), '% Parameters\nfoo = "bar";\nbaz = {"buz"};\n', ), ], @@ -579,8 +578,8 @@ def test_translate_codify_matlab(parameters, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "%"), ("foo", "% foo"), ("['best effort']", "% ['best effort']")], + 'test_input,expected', + [('', '%'), ('foo', '% foo'), ("['best effort']", "% ['best effort']")], ) def test_translate_comment_matlab(test_input, expected): assert translators.MatlabTranslator.comment(test_input) == expected @@ -590,14 +589,14 @@ def test_find_translator_with_exact_kernel_name(): my_new_kernel_translator = Mock() my_new_language_translator = Mock() translators.papermill_translators.register( - "my_new_kernel", my_new_kernel_translator + 'my_new_kernel', my_new_kernel_translator ) translators.papermill_translators.register( - "my_new_language", my_new_language_translator + 'my_new_language', my_new_language_translator ) assert ( translators.papermill_translators.find_translator( - "my_new_kernel", "my_new_language" + 'my_new_kernel', 'my_new_language' ) is my_new_kernel_translator ) @@ -606,11 +605,11 @@ def test_find_translator_with_exact_kernel_name(): def test_find_translator_with_exact_language(): my_new_language_translator = Mock() translators.papermill_translators.register( - "my_new_language", my_new_language_translator + 'my_new_language', my_new_language_translator ) assert ( translators.papermill_translators.find_translator( - "unregistered_kernel", "my_new_language" + 'unregistered_kernel', 'my_new_language' ) is my_new_language_translator ) @@ -619,14 +618,14 @@ def test_find_translator_with_exact_language(): def test_find_translator_with_no_such_kernel_or_language(): with pytest.raises(PapermillException): translators.papermill_translators.find_translator( - "unregistered_kernel", "unregistered_language" + 'unregistered_kernel', 'unregistered_language' ) def test_translate_uses_str_representation_of_unknown_types(): class FooClass: def __str__(self): - return "foo" + return 'foo' obj = FooClass() assert translators.Translator.translate(obj) == '"foo"' @@ -637,7 +636,7 @@ class MyNewTranslator(translators.Translator): pass with pytest.raises(NotImplementedError): - MyNewTranslator.translate_dict({"foo": "bar"}) + MyNewTranslator.translate_dict({'foo': 'bar'}) def test_translator_must_implement_translate_list(): @@ -645,7 +644,7 @@ class MyNewTranslator(translators.Translator): pass with pytest.raises(NotImplementedError): - MyNewTranslator.translate_list(["foo", "bar"]) + MyNewTranslator.translate_list(['foo', 'bar']) def test_translator_must_implement_comment(): @@ -653,24 +652,24 @@ class MyNewTranslator(translators.Translator): pass with pytest.raises(NotImplementedError): - MyNewTranslator.comment("foo") + MyNewTranslator.comment('foo') # Bash/sh section @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ - ("foo", "foo"), - ("foo space", "'foo space'"), + ('foo', 'foo'), + ('foo space', "'foo space'"), ("foo's apostrophe", "'foo'\"'\"'s apostrophe'"), - ("shell ( is ) ", "'shell ( is ) '"), - (12345, "12345"), - (-54321, "-54321"), - (1.2345, "1.2345"), - (-5432.1, "-5432.1"), - (True, "true"), - (False, "false"), - (None, ""), + ('shell ( is ) ', "'shell ( is ) '"), + (12345, '12345'), + (-54321, '-54321'), + (1.2345, '1.2345'), + (-5432.1, '-5432.1'), + (True, 'true'), + (False, 'false'), + (None, ''), ], ) def test_translate_type_sh(test_input, expected): @@ -678,23 +677,23 @@ def test_translate_type_sh(test_input, expected): @pytest.mark.parametrize( - "test_input,expected", - [("", "#"), ("foo", "# foo"), ("['best effort']", "# ['best effort']")], + 'test_input,expected', + [('', '#'), ('foo', '# foo'), ("['best effort']", "# ['best effort']")], ) def test_translate_comment_sh(test_input, expected): assert translators.BashTranslator.comment(test_input) == expected @pytest.mark.parametrize( - "parameters,expected", + 'parameters,expected', [ - ({"foo": "bar"}, "# Parameters\nfoo=bar\n"), - ({"foo": "shell ( is ) "}, "# Parameters\nfoo='shell ( is ) '\n"), - ({"foo": True}, "# Parameters\nfoo=true\n"), - ({"foo": 5}, "# Parameters\nfoo=5\n"), - ({"foo": 1.1}, "# Parameters\nfoo=1.1\n"), + ({'foo': 'bar'}, '# Parameters\nfoo=bar\n'), + ({'foo': 'shell ( is ) '}, "# Parameters\nfoo='shell ( is ) '\n"), + ({'foo': True}, '# Parameters\nfoo=true\n'), + ({'foo': 5}, '# Parameters\nfoo=5\n'), + ({'foo': 1.1}, '# Parameters\nfoo=1.1\n'), ( - OrderedDict([["foo", "bar"], ["baz", "$dumb(shell)"]]), + OrderedDict([['foo', 'bar'], ['baz', '$dumb(shell)']]), "# Parameters\nfoo=bar\nbaz='$dumb(shell)'\n", ), ], diff --git a/papermill/tests/test_utils.py b/papermill/tests/test_utils.py index 519fa383..eed256a5 100644 --- a/papermill/tests/test_utils.py +++ b/papermill/tests/test_utils.py @@ -1,39 +1,38 @@ -import pytest import warnings - -from unittest.mock import Mock, call -from tempfile import TemporaryDirectory from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import Mock, call -from nbformat.v4 import new_notebook, new_code_cell +import pytest +from nbformat.v4 import new_code_cell, new_notebook +from ..exceptions import PapermillParameterOverwriteWarning from ..utils import ( any_tagged_cell, - retry, chdir, merge_kwargs, remove_args, + retry, ) -from ..exceptions import PapermillParameterOverwriteWarning def test_no_tagged_cell(): nb = new_notebook( - cells=[new_code_cell("a = 2", metadata={"tags": []})], + cells=[new_code_cell('a = 2', metadata={'tags': []})], ) - assert not any_tagged_cell(nb, "parameters") + assert not any_tagged_cell(nb, 'parameters') def test_tagged_cell(): nb = new_notebook( - cells=[new_code_cell("a = 2", metadata={"tags": ["parameters"]})], + cells=[new_code_cell('a = 2', metadata={'tags': ['parameters']})], ) - assert any_tagged_cell(nb, "parameters") + assert any_tagged_cell(nb, 'parameters') def test_merge_kwargs(): with warnings.catch_warnings(record=True) as wrn: - assert merge_kwargs({"a": 1, "b": 2}, a=3) == {"a": 3, "b": 2} + assert merge_kwargs({'a': 1, 'b': 2}, a=3) == {'a': 3, 'b': 2} assert len(wrn) == 1 assert issubclass(wrn[0].category, PapermillParameterOverwriteWarning) assert ( @@ -43,17 +42,17 @@ def test_merge_kwargs(): def test_remove_args(): - assert remove_args(["a"], a=1, b=2, c=3) == {"c": 3, "b": 2} + assert remove_args(['a'], a=1, b=2, c=3) == {'c': 3, 'b': 2} def test_retry(): m = Mock( - side_effect=RuntimeError(), __name__="m", __module__="test_s3", __doc__="m" + side_effect=RuntimeError(), __name__='m', __module__='test_s3', __doc__='m' ) wrapped_m = retry(3)(m) with pytest.raises(RuntimeError): - wrapped_m("foo") - m.assert_has_calls([call("foo"), call("foo"), call("foo")]) + wrapped_m('foo') + m.assert_has_calls([call('foo'), call('foo'), call('foo')]) def test_chdir(): diff --git a/papermill/translators.py b/papermill/translators.py index ace316bf..58c2357d 100644 --- a/papermill/translators.py +++ b/papermill/translators.py @@ -6,7 +6,6 @@ from .exceptions import PapermillException from .models import Parameter - logger = logging.getLogger(__name__) @@ -29,9 +28,7 @@ def find_translator(self, kernel_name, language): elif language in self._translators: return self._translators[language] raise PapermillException( - "No parameter translator functions specified for kernel '{}' or language '{}'".format( - kernel_name, language - ) + f"No parameter translator functions specified for kernel '{kernel_name}' or language '{language}'" ) @@ -39,15 +36,15 @@ class Translator: @classmethod def translate_raw_str(cls, val): """Reusable by most interpreters""" - return f"{val}" + return f'{val}' @classmethod def translate_escaped_str(cls, str_val): """Reusable by most interpreters""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") - str_val = str_val.replace('"', r"\"") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') + str_val = str_val.replace('"', r'\"') return f'"{str_val}"' @classmethod @@ -73,15 +70,15 @@ def translate_float(cls, val): @classmethod def translate_bool(cls, val): """Default behavior for translation""" - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_dict(cls, val): - raise NotImplementedError(f"dict type translation not implemented for {cls}") + raise NotImplementedError(f'dict type translation not implemented for {cls}') @classmethod def translate_list(cls, val): - raise NotImplementedError(f"list type translation not implemented for {cls}") + raise NotImplementedError(f'list type translation not implemented for {cls}') @classmethod def translate(cls, val): @@ -106,17 +103,17 @@ def translate(cls, val): @classmethod def comment(cls, cmt_str): - raise NotImplementedError(f"comment translation not implemented for {cls}") + raise NotImplementedError(f'comment translation not implemented for {cls}') @classmethod def assign(cls, name, str_val): - return f"{name} = {str_val}" + return f'{name} = {str_val}' @classmethod - def codify(cls, parameters, comment="Parameters"): - content = f"{cls.comment(comment)}\n" + def codify(cls, parameters, comment='Parameters'): + content = f'{cls.comment(comment)}\n' for name, val in parameters.items(): - content += f"{cls.assign(name, cls.translate(val))}\n" + content += f'{cls.assign(name, cls.translate(val))}\n' return content @classmethod @@ -140,7 +137,7 @@ def inspect(cls, parameters_cell): List[Parameter] A list of all parameters """ - raise NotImplementedError(f"parameters introspection not implemented for {cls}") + raise NotImplementedError(f'parameters introspection not implemented for {cls}') class PythonTranslator(Translator): @@ -166,22 +163,22 @@ def translate_bool(cls, val): @classmethod def translate_dict(cls, val): - escaped = ", ".join( - [f"{cls.translate_str(k)}: {cls.translate(v)}" for k, v in val.items()] + escaped = ', '.join( + [f'{cls.translate_str(k)}: {cls.translate(v)}' for k, v in val.items()] ) - return f"{{{escaped}}}" + return f'{{{escaped}}}' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"[{escaped}]" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'[{escaped}]' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod - def codify(cls, parameters, comment="Parameters"): + def codify(cls, parameters, comment='Parameters'): content = super().codify(parameters, comment) try: # Put content through the Black Python code formatter @@ -192,7 +189,7 @@ def codify(cls, parameters, comment="Parameters"): except ImportError: logger.debug("Black is not installed, parameters won't be formatted") except AttributeError as aerr: - logger.warning(f"Black encountered an error, skipping formatting ({aerr})") + logger.warning(f'Black encountered an error, skipping formatting ({aerr})') return content @classmethod @@ -213,7 +210,7 @@ def inspect(cls, parameters_cell): A list of all parameters """ params = [] - src = parameters_cell["source"] + src = parameters_cell['source'] def flatten_accumulator(accumulator): """Flatten a multilines variable definition. @@ -225,10 +222,10 @@ def flatten_accumulator(accumulator): Returns: Flatten definition """ - flat_string = "" + flat_string = '' for line in accumulator[:-1]: - if "#" in line: - comment_pos = line.index("#") + if '#' in line: + comment_pos = line.index('#') flat_string += line[:comment_pos].strip() else: flat_string += line.strip() @@ -244,10 +241,10 @@ def flatten_accumulator(accumulator): grouped_variable = [] accumulator = [] for iline, line in enumerate(src.splitlines()): - if len(line.strip()) == 0 or line.strip().startswith("#"): + if len(line.strip()) == 0 or line.strip().startswith('#'): continue # Skip blank and comment - nequal = line.count("=") + nequal = line.count('=') if nequal > 0: grouped_variable.append(flatten_accumulator(accumulator)) accumulator = [] @@ -265,16 +262,16 @@ def flatten_accumulator(accumulator): match = re.match(cls.PARAMETER_PATTERN, definition) if match is not None: attr = match.groupdict() - if attr["target"] is None: # Fail to get variable name + if attr['target'] is None: # Fail to get variable name continue - type_name = str(attr["annotation"] or attr["type_comment"] or None) + type_name = str(attr['annotation'] or attr['type_comment'] or None) params.append( Parameter( - name=attr["target"].strip(), + name=attr['target'].strip(), inferred_type_name=type_name.strip(), - default=str(attr["value"]).strip(), - help=str(attr["help"] or "").strip(), + default=str(attr['value']).strip(), + help=str(attr['help'] or '').strip(), ) ) @@ -284,85 +281,85 @@ def flatten_accumulator(accumulator): class RTranslator(Translator): @classmethod def translate_none(cls, val): - return "NULL" + return 'NULL' @classmethod def translate_bool(cls, val): - return "TRUE" if val else "FALSE" + return 'TRUE' if val else 'FALSE' @classmethod def translate_dict(cls, val): - escaped = ", ".join( - [f"{cls.translate_str(k)} = {cls.translate(v)}" for k, v in val.items()] + escaped = ', '.join( + [f'{cls.translate_str(k)} = {cls.translate(v)}' for k, v in val.items()] ) - return f"list({escaped})" + return f'list({escaped})' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"list({escaped})" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'list({escaped})' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): # Leading '_' aren't legal R variable names -- so we drop them when injecting - while name.startswith("_"): + while name.startswith('_'): name = name[1:] - return f"{name} = {str_val}" + return f'{name} = {str_val}' class ScalaTranslator(Translator): @classmethod def translate_int(cls, val): strval = cls.translate_raw_str(val) - return strval + "L" if (val > 2147483647 or val < -2147483648) else strval + return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval @classmethod def translate_dict(cls, val): """Translate dicts to scala Maps""" - escaped = ", ".join( - [f"{cls.translate_str(k)} -> {cls.translate(v)}" for k, v in val.items()] + escaped = ', '.join( + [f'{cls.translate_str(k)} -> {cls.translate(v)}' for k, v in val.items()] ) - return f"Map({escaped})" + return f'Map({escaped})' @classmethod def translate_list(cls, val): """Translate list to scala Seq""" - escaped = ", ".join([cls.translate(v) for v in val]) - return f"Seq({escaped})" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'Seq({escaped})' @classmethod def comment(cls, cmt_str): - return f"// {cmt_str}".strip() + return f'// {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"val {name} = {str_val}" + return f'val {name} = {str_val}' class JuliaTranslator(Translator): @classmethod def translate_none(cls, val): - return "nothing" + return 'nothing' @classmethod def translate_dict(cls, val): - escaped = ", ".join( - [f"{cls.translate_str(k)} => {cls.translate(v)}" for k, v in val.items()] + escaped = ', '.join( + [f'{cls.translate_str(k)} => {cls.translate(v)}' for k, v in val.items()] ) - return f"Dict({escaped})" + return f'Dict({escaped})' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"[{escaped}]" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'[{escaped}]' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() class MatlabTranslator(Translator): @@ -370,8 +367,8 @@ class MatlabTranslator(Translator): def translate_escaped_str(cls, str_val): """Translate a string to an escaped Matlab string""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') str_val = str_val.replace('"', '""') return f'"{str_val}"' @@ -379,35 +376,35 @@ def translate_escaped_str(cls, str_val): def __translate_char_array(str_val): """Translates a string to a Matlab char array""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') str_val = str_val.replace("'", "''") return f"'{str_val}'" @classmethod def translate_none(cls, val): - return "NaN" + return 'NaN' @classmethod def translate_dict(cls, val): - keys = ", ".join([f"{cls.__translate_char_array(k)}" for k, v in val.items()]) - vals = ", ".join([f"{cls.translate(v)}" for k, v in val.items()]) - return f"containers.Map({{{keys}}}, {{{vals}}})" + keys = ', '.join([f'{cls.__translate_char_array(k)}' for k, v in val.items()]) + vals = ', '.join([f'{cls.translate(v)}' for k, v in val.items()]) + return f'containers.Map({{{keys}}}, {{{vals}}})' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"{{{escaped}}}" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'{{{escaped}}}' @classmethod def comment(cls, cmt_str): - return f"% {cmt_str}".strip() + return f'% {cmt_str}'.strip() @classmethod - def codify(cls, parameters, comment="Parameters"): - content = f"{cls.comment(comment)}\n" + def codify(cls, parameters, comment='Parameters'): + content = f'{cls.comment(comment)}\n' for name, val in parameters.items(): - content += f"{cls.assign(name, cls.translate(val))};\n" + content += f'{cls.assign(name, cls.translate(val))};\n' return content @@ -415,80 +412,80 @@ class CSharpTranslator(Translator): @classmethod def translate_none(cls, val): # Can't figure out how to do this as nullable - raise NotImplementedError("Option type not implemented for C#.") + raise NotImplementedError('Option type not implemented for C#.') @classmethod def translate_bool(cls, val): - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_int(cls, val): strval = cls.translate_raw_str(val) - return strval + "L" if (val > 2147483647 or val < -2147483648) else strval + return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval @classmethod def translate_dict(cls, val): """Translate dicts to nontyped dictionary""" - kvps = ", ".join( + kvps = ', '.join( [ - f"{{ {cls.translate_str(k)} , {cls.translate(v)} }}" + f'{{ {cls.translate_str(k)} , {cls.translate(v)} }}' for k, v in val.items() ] ) - return f"new Dictionary{{ {kvps} }}" + return f'new Dictionary{{ {kvps} }}' @classmethod def translate_list(cls, val): """Translate list to array""" - escaped = ", ".join([cls.translate(v) for v in val]) - return f"new [] {{ {escaped} }}" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'new [] {{ {escaped} }}' @classmethod def comment(cls, cmt_str): - return f"// {cmt_str}".strip() + return f'// {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"var {name} = {str_val};" + return f'var {name} = {str_val};' class FSharpTranslator(Translator): @classmethod def translate_none(cls, val): - return "None" + return 'None' @classmethod def translate_bool(cls, val): - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_int(cls, val): strval = cls.translate_raw_str(val) - return strval + "L" if (val > 2147483647 or val < -2147483648) else strval + return strval + 'L' if (val > 2147483647 or val < -2147483648) else strval @classmethod def translate_dict(cls, val): - tuples = "; ".join( + tuples = '; '.join( [ - f"({cls.translate_str(k)}, {cls.translate(v)} :> IComparable)" + f'({cls.translate_str(k)}, {cls.translate(v)} :> IComparable)' for k, v in val.items() ] ) - return f"[ {tuples} ] |> Map.ofList" + return f'[ {tuples} ] |> Map.ofList' @classmethod def translate_list(cls, val): - escaped = "; ".join([cls.translate(v) for v in val]) - return f"[ {escaped} ]" + escaped = '; '.join([cls.translate(v) for v in val]) + return f'[ {escaped} ]' @classmethod def comment(cls, cmt_str): - return f"(* {cmt_str} *)".strip() + return f'(* {cmt_str} *)'.strip() @classmethod def assign(cls, name, str_val): - return f"let {name} = {str_val}" + return f'let {name} = {str_val}' class PowershellTranslator(Translator): @@ -496,8 +493,8 @@ class PowershellTranslator(Translator): def translate_escaped_str(cls, str_val): """Translate a string to an escaped Matlab string""" if isinstance(str_val, str): - str_val = str_val.encode("unicode_escape") - str_val = str_val.decode("utf-8") + str_val = str_val.encode('unicode_escape') + str_val = str_val.decode('utf-8') str_val = str_val.replace('"', '`"') return f'"{str_val}"' @@ -506,49 +503,49 @@ def translate_float(cls, val): if math.isfinite(val): return cls.translate_raw_str(val) elif math.isnan(val): - return "[double]::NaN" + return '[double]::NaN' elif val < 0: - return "[double]::NegativeInfinity" + return '[double]::NegativeInfinity' else: - return "[double]::PositiveInfinity" + return '[double]::PositiveInfinity' @classmethod def translate_none(cls, val): - return "$Null" + return '$Null' @classmethod def translate_bool(cls, val): - return "$True" if val else "$False" + return '$True' if val else '$False' @classmethod def translate_dict(cls, val): - kvps = "\n ".join( - [f"{cls.translate_str(k)} = {cls.translate(v)}" for k, v in val.items()] + kvps = '\n '.join( + [f'{cls.translate_str(k)} = {cls.translate(v)}' for k, v in val.items()] ) - return f"@{{{kvps}}}" + return f'@{{{kvps}}}' @classmethod def translate_list(cls, val): - escaped = ", ".join([cls.translate(v) for v in val]) - return f"@({escaped})" + escaped = ', '.join([cls.translate(v) for v in val]) + return f'@({escaped})' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"${name} = {str_val}" + return f'${name} = {str_val}' class BashTranslator(Translator): @classmethod def translate_none(cls, val): - return "" + return '' @classmethod def translate_bool(cls, val): - return "true" if val else "false" + return 'true' if val else 'false' @classmethod def translate_escaped_str(cls, str_val): @@ -556,35 +553,35 @@ def translate_escaped_str(cls, str_val): @classmethod def translate_list(cls, val): - escaped = " ".join([cls.translate(v) for v in val]) - return f"({escaped})" + escaped = ' '.join([cls.translate(v) for v in val]) + return f'({escaped})' @classmethod def comment(cls, cmt_str): - return f"# {cmt_str}".strip() + return f'# {cmt_str}'.strip() @classmethod def assign(cls, name, str_val): - return f"{name}={str_val}" + return f'{name}={str_val}' # Instantiate a PapermillIO instance and register Handlers. papermill_translators = PapermillTranslators() -papermill_translators.register("python", PythonTranslator) -papermill_translators.register("R", RTranslator) -papermill_translators.register("scala", ScalaTranslator) -papermill_translators.register("julia", JuliaTranslator) -papermill_translators.register("matlab", MatlabTranslator) -papermill_translators.register(".net-csharp", CSharpTranslator) -papermill_translators.register(".net-fsharp", FSharpTranslator) -papermill_translators.register(".net-powershell", PowershellTranslator) -papermill_translators.register("pysparkkernel", PythonTranslator) -papermill_translators.register("sparkkernel", ScalaTranslator) -papermill_translators.register("sparkrkernel", RTranslator) -papermill_translators.register("bash", BashTranslator) - - -def translate_parameters(kernel_name, language, parameters, comment="Parameters"): +papermill_translators.register('python', PythonTranslator) +papermill_translators.register('R', RTranslator) +papermill_translators.register('scala', ScalaTranslator) +papermill_translators.register('julia', JuliaTranslator) +papermill_translators.register('matlab', MatlabTranslator) +papermill_translators.register('.net-csharp', CSharpTranslator) +papermill_translators.register('.net-fsharp', FSharpTranslator) +papermill_translators.register('.net-powershell', PowershellTranslator) +papermill_translators.register('pysparkkernel', PythonTranslator) +papermill_translators.register('sparkkernel', ScalaTranslator) +papermill_translators.register('sparkrkernel', RTranslator) +papermill_translators.register('bash', BashTranslator) + + +def translate_parameters(kernel_name, language, parameters, comment='Parameters'): return papermill_translators.find_translator(kernel_name, language).codify( parameters, comment ) diff --git a/papermill/utils.py b/papermill/utils.py index 532a5a43..a9e6e877 100644 --- a/papermill/utils.py +++ b/papermill/utils.py @@ -1,13 +1,12 @@ -import os import logging +import os import warnings - from contextlib import contextmanager from functools import wraps from .exceptions import PapermillParameterOverwriteWarning -logger = logging.getLogger("papermill.utils") +logger = logging.getLogger('papermill.utils') def any_tagged_cell(nb, tag): @@ -48,9 +47,9 @@ def nb_kernel_name(nb, name=None): ValueError If no kernel name is found or provided """ - name = name or nb.metadata.get("kernelspec", {}).get("name") + name = name or nb.metadata.get('kernelspec', {}).get('name') if not name: - raise ValueError("No kernel name found in notebook and no override provided.") + raise ValueError('No kernel name found in notebook and no override provided.') return name @@ -74,12 +73,12 @@ def nb_language(nb, language=None): ValueError If no notebook language is found or provided """ - language = language or nb.metadata.get("language_info", {}).get("name") + language = language or nb.metadata.get('language_info', {}).get('name') if not language: # v3 language path for old notebooks that didn't convert cleanly - language = language or nb.metadata.get("kernelspec", {}).get("language") + language = language or nb.metadata.get('kernelspec', {}).get('language') if not language: - raise ValueError("No language found in notebook and no override provided.") + raise ValueError('No language found in notebook and no override provided.') return language @@ -129,7 +128,7 @@ def merge_kwargs(caller_args, **callee_args): conflicts = set(caller_args) & set(callee_args) if conflicts: args = format( - "; ".join([f"{key}={value}" for key, value in callee_args.items()]) + '; '.join([f'{key}={value}' for key, value in callee_args.items()]) ) msg = f"Callee will overwrite caller's argument(s): {args}" warnings.warn(msg, PapermillParameterOverwriteWarning) @@ -167,7 +166,7 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - logger.debug(f"Retrying after: {e}") + logger.debug(f'Retrying after: {e}') exception = e else: raise exception diff --git a/papermill/version.py b/papermill/version.py index 824cbf24..3d98bc1d 100644 --- a/papermill/version.py +++ b/papermill/version.py @@ -1 +1 @@ -version = "2.5.0" +version = '2.5.0'