Skip to content

Commit

Permalink
fix: protect ak.to_parquet against memory explosion when args are swa…
Browse files Browse the repository at this point in the history
…pped. (#2523)

* Protect ak.to_parquet against memory explosion when args are swapped.

* Also let 'destination' in ak.to_parquet be a pathlib.Path.

* feat: support any path-like

* refactor: remove unused path helper

* refactor: use `fsdecode` in more places

* fix: `PathLike` does not include str or bytes

---------

Co-authored-by: Angus Hollands <[email protected]>
  • Loading branch information
jpivarski and agoose77 authored Jun 15, 2023
1 parent e627d4a commit 28f243e
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 36 deletions.
24 changes: 0 additions & 24 deletions src/awkward/_regularize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,6 @@ def is_non_string_like_sequence(obj) -> bool:
return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence)


def regularize_path(path):
"""
Converts pathlib Paths into plain string paths (for all versions of Python).
"""
is_path = False

if isinstance(path, getattr(os, "PathLike", ())):
is_path = True
path = os.fspath(path)

elif hasattr(path, "__fspath__"):
is_path = True
path = path.__fspath__()

elif path.__class__.__module__ == "pathlib":
import pathlib

if isinstance(path, pathlib.Path):
is_path = True
path = str(path)

return is_path, path


def regularize_axis(axis: SupportsInt | None) -> AxisMaybeNone:
if axis is None:
return None
Expand Down
9 changes: 4 additions & 5 deletions src/awkward/operations/ak_from_avro_file.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
__all__ = ("from_avro_file",)

# from awkward._typing import Type
import pathlib
from os import PathLike, fsdecode

import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
Expand All @@ -15,7 +14,7 @@ def from_avro_file(
):
"""
Args:
file (string or file-like object): Avro file to be read as Awkward Array.
file (path-like or file-like object): Avro file to be read as Awkward Array.
limit_entries (int): The number of rows of the Avro file to be read into the Awkward Array.
debug_forth (bool): If True, prints the generated Forth code for debugging.
highlevel (bool): If True, return an #ak.Array; otherwise, return
Expand All @@ -40,8 +39,8 @@ def from_avro_file(
"debug_forth": debug_forth,
},
):
if isinstance(file, pathlib.Path):
file = str(file)
if isinstance(file, (str, bytes, PathLike)):
file = fsdecode(file)

if isinstance(file, str):
with open(file, "rb") as opened_file:
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_from_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def __exit__(self, exception_type, exception_value, exception_traceback):


def _get_reader(source):
if not isinstance(source, pathlib.Path) and isinstance(source, str):
if isinstance(source, str):
source = source.encode("utf8", errors="surrogateescape")

if isinstance(source, bytes):
Expand Down
8 changes: 4 additions & 4 deletions src/awkward/operations/ak_to_json.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
__all__ = ("to_json",)
import json
import pathlib
from numbers import Number
from os import PathLike, fsdecode
from urllib.parse import urlparse

from awkward_cpp.lib import _ext
Expand Down Expand Up @@ -33,7 +33,7 @@ def to_json(
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
file (None, str/pathlib.Path, or file-like object): If None, this function returns
file (None, path-like, or file-like object): If None, this function returns
JSON-encoded bytes. Otherwise, this function has no return value.
If a string/pathlib.Path, this function opens a file with that name, writes JSON
data, and closes the file. If that path has a URI protocol (like
Expand Down Expand Up @@ -205,8 +205,8 @@ def _impl(
)

if file is not None:
if isinstance(file, (str, pathlib.Path)):
parsed_url = urlparse(file)
if isinstance(file, (str, bytes, PathLike)):
parsed_url = urlparse(fsdecode(file))
if parsed_url.scheme == "" or parsed_url.netloc == "":

def opener():
Expand Down
12 changes: 10 additions & 2 deletions src/awkward/operations/ak_to_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__all__ = ("to_parquet",)

from collections.abc import Mapping, Sequence
from os import fsdecode

import awkward as ak
from awkward._nplikes.numpylike import NumpyMetadata
Expand Down Expand Up @@ -39,8 +40,8 @@ def to_parquet(
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
destination (str): Name of the output file, file path, or remote URL passed to
[fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
destination (path-like): Name of the output file, file path, or
remote URL passed to [fsspec.core.url_to_fs](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.core.url_to_fs)
for remote writing.
list_to32 (bool): If True, convert Awkward lists into 32-bit Arrow lists
if they're small enough, even if it means an extra conversion. Otherwise,
Expand Down Expand Up @@ -291,6 +292,13 @@ def parquet_columns(specifier, only=None):
if parquet_extra_options is None:
parquet_extra_options = {}

try:
destination = fsdecode(destination)
except TypeError:
raise TypeError(
f"'destination' argument of 'ak.to_parquet' must be a path-like, not {type(destination).__name__} ('array' argument is first; 'destination' second)"
) from None

fs, destination = fsspec.core.url_to_fs(destination, **(storage_options or {}))
metalist = []
with pyarrow_parquet.ParquetWriter(
Expand Down

0 comments on commit 28f243e

Please sign in to comment.