From 28f243e27feb519dd25055871b39cb91f2c74ee5 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 15 Jun 2023 11:38:15 -0500 Subject: [PATCH] fix: protect ak.to_parquet against memory explosion when args are swapped. (#2523) * Protect ak.to_parquet against memory explosion when args are swapped. * Also let 'destination' in ak.to_parquet be a pathlib.Path. * feat: support any path-like * refactor: remove unused path helper * refactor: use `fsdecode` in more places * fix: `PathLike` does not include str or bytes --------- Co-authored-by: Angus Hollands --- src/awkward/_regularize.py | 24 --------------------- src/awkward/operations/ak_from_avro_file.py | 9 ++++---- src/awkward/operations/ak_from_json.py | 2 +- src/awkward/operations/ak_to_json.py | 8 +++---- src/awkward/operations/ak_to_parquet.py | 12 +++++++++-- 5 files changed, 19 insertions(+), 36 deletions(-) diff --git a/src/awkward/_regularize.py b/src/awkward/_regularize.py index 0b7f09f2dc..a1a472859a 100644 --- a/src/awkward/_regularize.py +++ b/src/awkward/_regularize.py @@ -50,30 +50,6 @@ def is_non_string_like_sequence(obj) -> bool: return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence) -def regularize_path(path): - """ - Converts pathlib Paths into plain string paths (for all versions of Python). - """ - is_path = False - - if isinstance(path, getattr(os, "PathLike", ())): - is_path = True - path = os.fspath(path) - - elif hasattr(path, "__fspath__"): - is_path = True - path = path.__fspath__() - - elif path.__class__.__module__ == "pathlib": - import pathlib - - if isinstance(path, pathlib.Path): - is_path = True - path = str(path) - - return is_path, path - - def regularize_axis(axis: SupportsInt | None) -> AxisMaybeNone: if axis is None: return None diff --git a/src/awkward/operations/ak_from_avro_file.py b/src/awkward/operations/ak_from_avro_file.py index 7c07cde565..a15b1f54d2 100644 --- a/src/awkward/operations/ak_from_avro_file.py +++ b/src/awkward/operations/ak_from_avro_file.py @@ -1,8 +1,7 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE __all__ = ("from_avro_file",) -# from awkward._typing import Type -import pathlib +from os import PathLike, fsdecode import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata @@ -15,7 +14,7 @@ def from_avro_file( ): """ Args: - file (string or file-like object): Avro file to be read as Awkward Array. + file (path-like or file-like object): Avro file to be read as Awkward Array. limit_entries (int): The number of rows of the Avro file to be read into the Awkward Array. debug_forth (bool): If True, prints the generated Forth code for debugging. highlevel (bool): If True, return an #ak.Array; otherwise, return @@ -40,8 +39,8 @@ def from_avro_file( "debug_forth": debug_forth, }, ): - if isinstance(file, pathlib.Path): - file = str(file) + if isinstance(file, (str, bytes, PathLike)): + file = fsdecode(file) if isinstance(file, str): with open(file, "rb") as opened_file: diff --git a/src/awkward/operations/ak_from_json.py b/src/awkward/operations/ak_from_json.py index 1b54a83c94..755146a453 100644 --- a/src/awkward/operations/ak_from_json.py +++ b/src/awkward/operations/ak_from_json.py @@ -399,7 +399,7 @@ def __exit__(self, exception_type, exception_value, exception_traceback): def _get_reader(source): - if not isinstance(source, pathlib.Path) and isinstance(source, str): + if isinstance(source, str): source = source.encode("utf8", errors="surrogateescape") if isinstance(source, bytes): diff --git a/src/awkward/operations/ak_to_json.py b/src/awkward/operations/ak_to_json.py index 300df2ce61..e069056922 100644 --- a/src/awkward/operations/ak_to_json.py +++ b/src/awkward/operations/ak_to_json.py @@ -1,8 +1,8 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE __all__ = ("to_json",) import json -import pathlib from numbers import Number +from os import PathLike, fsdecode from urllib.parse import urlparse from awkward_cpp.lib import _ext @@ -33,7 +33,7 @@ def to_json( """ Args: array: Array-like data (anything #ak.to_layout recognizes). - file (None, str/pathlib.Path, or file-like object): If None, this function returns + file (None, path-like, or file-like object): If None, this function returns JSON-encoded bytes. Otherwise, this function has no return value. If a string/pathlib.Path, this function opens a file with that name, writes JSON data, and closes the file. If that path has a URI protocol (like @@ -205,8 +205,8 @@ def _impl( ) if file is not None: - if isinstance(file, (str, pathlib.Path)): - parsed_url = urlparse(file) + if isinstance(file, (str, bytes, PathLike)): + parsed_url = urlparse(fsdecode(file)) if parsed_url.scheme == "" or parsed_url.netloc == "": def opener(): diff --git a/src/awkward/operations/ak_to_parquet.py b/src/awkward/operations/ak_to_parquet.py index b4fe934821..07371ea82b 100644 --- a/src/awkward/operations/ak_to_parquet.py +++ b/src/awkward/operations/ak_to_parquet.py @@ -2,6 +2,7 @@ __all__ = ("to_parquet",) from collections.abc import Mapping, Sequence +from os import fsdecode import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata @@ -39,8 +40,8 @@ def to_parquet( """ Args: array: Array-like data (anything #ak.to_layout recognizes). - destination (str): Name of the output file, file path, or remote URL passed to - [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs) + destination (path-like): Name of the output file, file path, or + remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs) for remote writing. list_to32 (bool): If True, convert Awkward lists into 32-bit Arrow lists if they're small enough, even if it means an extra conversion. Otherwise, @@ -291,6 +292,13 @@ def parquet_columns(specifier, only=None): if parquet_extra_options is None: parquet_extra_options = {} + try: + destination = fsdecode(destination) + except TypeError: + raise TypeError( + f"'destination' argument of 'ak.to_parquet' must be a path-like, not {type(destination).__name__} ('array' argument is first; 'destination' second)" + ) from None + fs, destination = fsspec.core.url_to_fs(destination, **(storage_options or {})) metalist = [] with pyarrow_parquet.ParquetWriter(