From 7df15b395e03af8f5fce8599c064218f521595b8 Mon Sep 17 00:00:00 2001
From: Kyle Barron
Date: Tue, 25 Jun 2024 09:29:11 -0400
Subject: [PATCH] Use Arrow stream interface for public API (#69)

* Use Arrow stream interface for public API
---
 stac_geoparquet/arrow/_api.py        | 82 +++++++++++++++++++++-------
 stac_geoparquet/arrow/_delta_lake.py | 13 ++---
 stac_geoparquet/arrow/_to_parquet.py | 52 +++++++++++-------
 stac_geoparquet/arrow/types.py       |  5 ++
 4 files changed, 105 insertions(+), 47 deletions(-)
 create mode 100644 stac_geoparquet/arrow/types.py

diff --git a/stac_geoparquet/arrow/_api.py b/stac_geoparquet/arrow/_api.py
index f563108..6e093f6 100644
--- a/stac_geoparquet/arrow/_api.py
+++ b/stac_geoparquet/arrow/_api.py
@@ -1,8 +1,9 @@
 from __future__ import annotations

+import itertools
 import os
 from pathlib import Path
-from typing import Any, Iterable, Iterator
+from typing import Any, Iterable

 import pyarrow as pa

@@ -10,6 +11,7 @@
 from stac_geoparquet.arrow._constants import DEFAULT_JSON_CHUNK_SIZE
 from stac_geoparquet.arrow._schema.models import InferredSchema
 from stac_geoparquet.arrow._util import batched_iter
+from stac_geoparquet.arrow.types import ArrowStreamExportable
 from stac_geoparquet.json_reader import read_json_chunked


@@ -18,7 +20,7 @@ def parse_stac_items_to_arrow(
     *,
     chunk_size: int = 8192,
     schema: pa.Schema | InferredSchema | None = None,
-) -> Iterable[pa.RecordBatch]:
+) -> pa.RecordBatchReader:
     """
     Parse a collection of STAC Items to an iterable of
     [`pyarrow.RecordBatch`][pyarrow.RecordBatch].
@@ -37,7 +39,7 @@
             inference. Defaults to None.

     Returns:
-        an iterable of pyarrow RecordBatches with the STAC-GeoParquet representation of items.
+        pyarrow RecordBatchReader with a stream of STAC Arrow RecordBatches.
     """
     if schema is not None:
         if isinstance(schema, InferredSchema):
@@ -45,15 +47,19 @@

         # If schema is provided, then for better memory usage we parse input STAC items
         # to Arrow batches in chunks.
-        for chunk in batched_iter(items, chunk_size):
-            yield stac_items_to_arrow(chunk, schema=schema)
+        batches = (
+            stac_items_to_arrow(batch, schema=schema)
+            for batch in batched_iter(items, chunk_size)
+        )
+        return pa.RecordBatchReader.from_batches(schema, batches)

     else:
         # If schema is _not_ provided, then we must convert to Arrow all at once, or
         # else it would be possible for a STAC item late in the collection (after the
         # first chunk) to have a different schema and not match the schema inferred for
         # the first chunk.
-        yield stac_items_to_arrow(items)
+        batch = stac_items_to_arrow(items)
+        return pa.RecordBatchReader.from_batches(batch.schema, [batch])


 def parse_stac_ndjson_to_arrow(
@@ -62,7 +68,7 @@
     chunk_size: int = DEFAULT_JSON_CHUNK_SIZE,
     schema: pa.Schema | None = None,
     limit: int | None = None,
-) -> Iterator[pa.RecordBatch]:
+) -> pa.RecordBatchReader:
     """
     Convert one or more newline-delimited JSON STAC files to a generator of Arrow
     RecordBatches.
@@ -81,8 +87,8 @@
     Keyword Args:
         limit: The maximum number of JSON Items to use for schema inference

-    Yields:
-        Arrow RecordBatch with a single chunk of Item data.
+    Returns:
+        pyarrow RecordBatchReader with a stream of STAC Arrow RecordBatches.
     """
     # If the schema was not provided, then we need to load all data into memory at once
     # to perform schema resolution.
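[Reviewer note, not part of the patch] A minimal usage sketch of the new return type of
parse_stac_items_to_arrow. The import comes straight from the private module edited above,
so the public re-export may differ, and the `items` argument is a placeholder supplied by
the caller.

    import pyarrow as pa

    from stac_geoparquet.arrow._api import parse_stac_items_to_arrow


    def items_to_table(items: list[dict]) -> pa.Table:
        # parse_stac_items_to_arrow now returns a pyarrow.RecordBatchReader rather
        # than a bare iterator of RecordBatches. The reader is a one-shot stream:
        # iterate it batch by batch, or materialize it all at once as done here.
        reader = parse_stac_items_to_arrow(items)
        return reader.read_all()
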
@@ -90,30 +96,68 @@ def parse_stac_ndjson_to_arrow(
         inferred_schema = InferredSchema()
         inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit)
         inferred_schema.manual_updates()
-        yield from parse_stac_ndjson_to_arrow(
+        return parse_stac_ndjson_to_arrow(
             path, chunk_size=chunk_size, schema=inferred_schema
         )
-        return

     if isinstance(schema, InferredSchema):
         schema = schema.inner

-    for batch in read_json_chunked(path, chunk_size=chunk_size):
-        yield stac_items_to_arrow(batch, schema=schema)
+    batches_iter = (
+        stac_items_to_arrow(batch, schema=schema)
+        for batch in read_json_chunked(path, chunk_size=chunk_size)
+    )
+    first_batch = next(batches_iter)
+    # Need to take this schema from the iterator; the existing `schema` is the schema of
+    # the JSON batch
+    resolved_schema = first_batch.schema
+    return pa.RecordBatchReader.from_batches(
+        resolved_schema, itertools.chain([first_batch], batches_iter)
+    )
+
+
+def stac_table_to_items(
+    table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
+) -> Iterable[dict]:
+    """Convert STAC Arrow to a generator of STAC Item `dict`s.
+
+    Args:
+        table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
+            RecordBatchReader, or any other Arrow stream object exposed through the
+            [Arrow PyCapsule
+            Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
+            A RecordBatchReader or stream object will not be materialized in memory.

+    Yields:
+        A STAC `dict` for each input row.
+    """
+    # Coerce to record batch reader to avoid materializing entire stream
+    reader = pa.RecordBatchReader.from_stream(table)

-def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
-    """Convert a STAC Table to a generator of STAC Item `dict`s"""
-    for batch in table.to_batches():
+    for batch in reader:
         clean_batch = StacArrowBatch(batch)
         yield from clean_batch.to_json_batch().iter_dicts()


 def stac_table_to_ndjson(
-    table: pa.Table, dest: str | Path | os.PathLike[bytes]
+    table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable,
+    dest: str | Path | os.PathLike[bytes],
 ) -> None:
-    """Write a STAC Table to a newline-delimited JSON file."""
-    for batch in table.to_batches():
+    """Write STAC Arrow to a newline-delimited JSON file.
+
+    Args:
+        table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
+            RecordBatchReader, or any other Arrow stream object exposed through the
+            [Arrow PyCapsule
+            Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
+            A RecordBatchReader or stream object will not be materialized in memory.
+        dest: The destination where newline-delimited JSON should be written.
+    """
+
+    # Coerce to record batch reader to avoid materializing entire stream
+    reader = pa.RecordBatchReader.from_stream(table)
+
+    for batch in reader:
         clean_batch = StacArrowBatch(batch)
         clean_batch.to_json_batch().to_ndjson(dest)

diff --git a/stac_geoparquet/arrow/_delta_lake.py b/stac_geoparquet/arrow/_delta_lake.py
index 35ee065..2a5a71a 100644
--- a/stac_geoparquet/arrow/_delta_lake.py
+++ b/stac_geoparquet/arrow/_delta_lake.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import itertools
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Iterable

@@ -47,14 +46,14 @@ def parse_stac_ndjson_to_delta_lake(
         schema_version: GeoParquet specification version; if not provided will default
             to latest supported version.
""" - batches_iter = parse_stac_ndjson_to_arrow( + record_batch_reader = parse_stac_ndjson_to_arrow( input_path, chunk_size=chunk_size, schema=schema, limit=limit ) - first_batch = next(batches_iter) - schema = first_batch.schema.with_metadata( + schema = record_batch_reader.schema.with_metadata( create_geoparquet_metadata( - pa.Table.from_batches([first_batch]), schema_version=schema_version + record_batch_reader.schema, schema_version=schema_version ) ) - combined_iter = itertools.chain([first_batch], batches_iter) - write_deltalake(table_or_uri, combined_iter, schema=schema, engine="rust", **kwargs) + write_deltalake( + table_or_uri, record_batch_reader, schema=schema, engine="rust", **kwargs + ) diff --git a/stac_geoparquet/arrow/_to_parquet.py b/stac_geoparquet/arrow/_to_parquet.py index b29f329..0cc6565 100644 --- a/stac_geoparquet/arrow/_to_parquet.py +++ b/stac_geoparquet/arrow/_to_parquet.py @@ -15,6 +15,7 @@ ) from stac_geoparquet.arrow._crs import WGS84_CRS_JSON from stac_geoparquet.arrow._schema.models import InferredSchema +from stac_geoparquet.arrow.types import ArrowStreamExportable def parse_stac_ndjson_to_parquet( @@ -43,26 +44,24 @@ def parse_stac_ndjson_to_parquet( limit: The maximum number of JSON records to convert. schema_version: GeoParquet specification version; if not provided will default to latest supported version. - """ - batches_iter = parse_stac_ndjson_to_arrow( + All other keyword args are passed on to + [`pyarrow.parquet.ParquetWriter`][pyarrow.parquet.ParquetWriter]. + """ + record_batch_reader = parse_stac_ndjson_to_arrow( input_path, chunk_size=chunk_size, schema=schema, limit=limit ) - first_batch = next(batches_iter) - schema = first_batch.schema.with_metadata( - create_geoparquet_metadata( - pa.Table.from_batches([first_batch]), schema_version=schema_version - ) + to_parquet( + record_batch_reader, + output_path=output_path, + schema_version=schema_version, + **kwargs, ) - with pq.ParquetWriter(output_path, schema, **kwargs) as writer: - writer.write_batch(first_batch) - for batch in batches_iter: - writer.write_batch(batch) def to_parquet( - table: pa.Table, - where: Any, + table: pa.Table | pa.RecordBatchReader | ArrowStreamExportable, + output_path: str | Path, *, schema_version: SUPPORTED_PARQUET_SCHEMA_VERSIONS = DEFAULT_PARQUET_SCHEMA_VERSION, **kwargs: Any, @@ -72,22 +71,33 @@ def to_parquet( This writes metadata compliant with either GeoParquet 1.0 or 1.1. Args: - table: The table to write to Parquet - where: The destination for saving. + table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow + RecordBatchReader, or any other Arrow stream object exposed through the + [Arrow PyCapsule + Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html). + A RecordBatchReader or stream object will not be materialized in memory. + output_path: The destination for saving. Keyword Args: schema_version: GeoParquet specification version; if not provided will default to latest supported version. + + All other keyword args are passed on to + [`pyarrow.parquet.ParquetWriter`][pyarrow.parquet.ParquetWriter]. 
""" - metadata = table.schema.metadata or {} - metadata.update(create_geoparquet_metadata(table, schema_version=schema_version)) - table = table.replace_schema_metadata(metadata) + # Coerce to record batch reader to avoid materializing entire stream + reader = pa.RecordBatchReader.from_stream(table) - pq.write_table(table, where, **kwargs) + schema = reader.schema.with_metadata( + create_geoparquet_metadata(reader.schema, schema_version=schema_version) + ) + with pq.ParquetWriter(output_path, schema, **kwargs) as writer: + for batch in reader: + writer.write_batch(batch) def create_geoparquet_metadata( - table: pa.Table, + schema: pa.Schema, *, schema_version: SUPPORTED_PARQUET_SCHEMA_VERSIONS, ) -> dict[bytes, bytes]: @@ -116,7 +126,7 @@ def create_geoparquet_metadata( "primary_column": "geometry", } - if "proj:geometry" in table.schema.names: + if "proj:geometry" in schema.names: # Note we don't include proj:bbox as a covering here for a couple different # reasons. For one, it's very common for the projected geometries to have a # different CRS in each row, so having statistics for proj:bbox wouldn't be diff --git a/stac_geoparquet/arrow/types.py b/stac_geoparquet/arrow/types.py new file mode 100644 index 0000000..d7e56b4 --- /dev/null +++ b/stac_geoparquet/arrow/types.py @@ -0,0 +1,5 @@ +from typing import Protocol + + +class ArrowStreamExportable(Protocol): + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ... # noqa