From 5c0a682dbdc67f9f95d9beead1e8ddd368ee9049 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Wed, 24 Apr 2024 17:15:08 -0400 Subject: [PATCH] Move arrow-based code into arrow module (#47) * Move arrow-based code into arrow module * fix tests import * deprecation --------- Co-authored-by: Tom Augspurger --- stac_geoparquet/__init__.py | 4 +- stac_geoparquet/arrow/__init__.py | 3 + stac_geoparquet/arrow/_from_arrow.py | 166 ++++++++++++ stac_geoparquet/arrow/_to_arrow.py | 364 ++++++++++++++++++++++++++ stac_geoparquet/arrow/_to_parquet.py | 46 ++++ stac_geoparquet/from_arrow.py | 170 +------------ stac_geoparquet/to_arrow.py | 368 +-------------------------- stac_geoparquet/to_parquet.py | 50 +--- tests/test_arrow.py | 23 +- 9 files changed, 620 insertions(+), 574 deletions(-) create mode 100644 stac_geoparquet/arrow/__init__.py create mode 100644 stac_geoparquet/arrow/_from_arrow.py create mode 100644 stac_geoparquet/arrow/_to_arrow.py create mode 100644 stac_geoparquet/arrow/_to_parquet.py diff --git a/stac_geoparquet/__init__.py b/stac_geoparquet/__init__.py index cfa46a7..4da6147 100644 --- a/stac_geoparquet/__init__.py +++ b/stac_geoparquet/__init__.py @@ -1,8 +1,8 @@ """stac-geoparquet""" -from .stac_geoparquet import to_geodataframe, to_dict, to_item_collection +from . import arrow from ._version import __version__ - +from .stac_geoparquet import to_dict, to_geodataframe, to_item_collection __all__ = [ "__version__", diff --git a/stac_geoparquet/arrow/__init__.py b/stac_geoparquet/arrow/__init__.py new file mode 100644 index 0000000..ee781a3 --- /dev/null +++ b/stac_geoparquet/arrow/__init__.py @@ -0,0 +1,3 @@ +from ._from_arrow import stac_table_to_items, stac_table_to_ndjson +from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow +from ._to_parquet import to_parquet diff --git a/stac_geoparquet/arrow/_from_arrow.py b/stac_geoparquet/arrow/_from_arrow.py new file mode 100644 index 0000000..f940864 --- /dev/null +++ b/stac_geoparquet/arrow/_from_arrow.py @@ -0,0 +1,166 @@ +"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" + +import os +import json +from typing import Iterable, List, Union + +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import shapely + + +def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: + """Write a STAC Table to a newline-delimited JSON file.""" + with open(dest, "w") as f: + for item_dict in stac_table_to_items(table): + json.dump(item_dict, f, separators=(",", ":")) + f.write("\n") + + +def stac_table_to_items(table: pa.Table) -> Iterable[dict]: + """Convert a STAC Table to a generator of STAC Item `dict`s""" + table = _undo_stac_table_transformations(table) + + # Convert WKB geometry column to GeoJSON, and then assign the geojson geometry when + # converting each row to a dictionary. + for batch in table.to_batches(): + geoms = shapely.from_wkb(batch["geometry"]) + geojson_strings = shapely.to_geojson(geoms) + + # RecordBatch is missing a `drop()` method, so we keep all columns other than + # geometry instead + keep_column_names = [name for name in batch.column_names if name != "geometry"] + struct_batch = batch.select(keep_column_names).to_struct_array() + + for row_idx in range(len(struct_batch)): + row_dict = struct_batch[row_idx].as_py() + row_dict["geometry"] = json.loads(geojson_strings[row_idx]) + yield row_dict + + +def _undo_stac_table_transformations(table: pa.Table) -> pa.Table: + """Undo the transformations done to convert STAC Json into an Arrow Table + + Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation, + as that is easier to do when converting each item in the table to a dict. + """ + table = _convert_timestamp_columns_to_string(table) + table = _lower_properties_from_top_level(table) + table = _convert_bbox_to_array(table) + return table + + +def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table: + """Convert any datetime columns in the table to a string representation""" + allowed_column_names = { + "datetime", # common metadata + "start_datetime", + "end_datetime", + "created", + "updated", + "expires", # timestamps extension + "published", + "unpublished", + } + for column_name in allowed_column_names: + try: + column = table[column_name] + except KeyError: + continue + + table = table.drop(column_name).append_column( + column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ") + ) + + return table + + +def _lower_properties_from_top_level(table: pa.Table) -> pa.Table: + """Take properties columns from the top level and wrap them in a struct column""" + stac_top_level_keys = { + "stac_version", + "stac_extensions", + "type", + "id", + "bbox", + "geometry", + "collection", + "links", + "assets", + } + + properties_column_names: List[str] = [] + properties_column_fields: List[pa.Field] = [] + for column_idx in range(table.num_columns): + column_name = table.column_names[column_idx] + if column_name in stac_top_level_keys: + continue + + properties_column_names.append(column_name) + properties_column_fields.append(table.schema.field(column_idx)) + + properties_array_chunks = [] + for batch in table.select(properties_column_names).to_batches(): + struct_arr = pa.StructArray.from_arrays( + batch.columns, fields=properties_column_fields + ) + properties_array_chunks.append(struct_arr) + + return table.drop_columns(properties_column_names).append_column( + "properties", pa.chunked_array(properties_array_chunks) + ) + + +def _convert_bbox_to_array(table: pa.Table) -> pa.Table: + """Convert the struct bbox column back to a list column for writing to JSON""" + + bbox_col_idx = table.schema.get_field_index("bbox") + bbox_col = table.column(bbox_col_idx) + + new_chunks = [] + for chunk in bbox_col.chunks: + assert pa.types.is_struct(chunk.type) + + if bbox_col.type.num_fields == 4: + xmin = chunk.field("xmin").to_numpy() + ymin = chunk.field("ymin").to_numpy() + xmax = chunk.field("xmax").to_numpy() + ymax = chunk.field("ymax").to_numpy() + coords = np.column_stack( + [ + xmin, + ymin, + xmax, + ymax, + ] + ) + + list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4) + + elif bbox_col.type.num_fields == 6: + xmin = chunk.field("xmin").to_numpy() + ymin = chunk.field("ymin").to_numpy() + zmin = chunk.field("zmin").to_numpy() + xmax = chunk.field("xmax").to_numpy() + ymax = chunk.field("ymax").to_numpy() + zmax = chunk.field("zmax").to_numpy() + coords = np.column_stack( + [ + xmin, + ymin, + zmin, + xmax, + ymax, + zmax, + ] + ) + + list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6) + + else: + raise ValueError("Expected 4 or 6 fields in bbox struct.") + + new_chunks.append(list_arr) + + return table.set_column(bbox_col_idx, "bbox", new_chunks) diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py new file mode 100644 index 0000000..b5b1f06 --- /dev/null +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -0,0 +1,364 @@ +"""Convert STAC data into Arrow tables""" + +import json +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union, Generator + +import ciso8601 +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import shapely +import shapely.geometry + + +def _chunks( + lst: Sequence[Dict[str, Any]], n: int +) -> Generator[Sequence[Dict[str, Any]], None, None]: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def parse_stac_items_to_arrow( + items: Sequence[Dict[str, Any]], + *, + chunk_size: int = 8192, + schema: Optional[pa.Schema] = None, + downcast: bool = True, +) -> pa.Table: + """Parse a collection of STAC Items to a :class:`pyarrow.Table`. + + The objects under `properties` are moved up to the top-level of the + Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. + + Args: + items: the STAC Items to convert + chunk_size: The chunk size to use for Arrow record batches. This only takes + effect if `schema` is not None. When `schema` is None, the input will be + parsed into a single contiguous record batch. Defaults to 8192. + schema: The schema of the input data. If provided, can improve memory use; + otherwise all items need to be parsed into a single array for schema + inference. Defaults to None. + downcast: if True, store bbox as float32 for memory and disk saving. + + Returns: + a pyarrow Table with the STAC-GeoParquet representation of items. + """ + + if schema is not None: + # If schema is provided, then for better memory usage we parse input STAC items + # to Arrow batches in chunks. + batches = [] + for chunk in _chunks(items, chunk_size): + batches.append(_stac_items_to_arrow(chunk, schema=schema)) + + table = pa.Table.from_batches(batches, schema=schema) + else: + # If schema is _not_ provided, then we must convert to Arrow all at once, or + # else it would be possible for a STAC item late in the collection (after the + # first chunk) to have a different schema and not match the schema inferred for + # the first chunk. + table = pa.Table.from_batches([_stac_items_to_arrow(items)]) + + return _process_arrow_table(table, downcast=downcast) + + +def parse_stac_ndjson_to_arrow( + path: Union[str, Path], + *, + chunk_size: int = 8192, + schema: Optional[pa.Schema] = None, + downcast: bool = True, +) -> pa.Table: + # Define outside of if/else to make mypy happy + items: List[dict] = [] + + # If the schema was not provided, then we need to load all data into memory at once + # to perform schema resolution. + if schema is None: + with open(path) as f: + for line in f: + items.append(json.loads(line)) + + return parse_stac_items_to_arrow(items, chunk_size=chunk_size, schema=schema) + + # Otherwise, we can stream over the input, converting each batch of `chunk_size` + # into an Arrow RecordBatch at a time. This is much more memory efficient. + with open(path) as f: + batches: List[pa.RecordBatch] = [] + for line in f: + items.append(json.loads(line)) + + if len(items) >= chunk_size: + batches.append(_stac_items_to_arrow(items, schema=schema)) + items = [] + + # Don't forget the last chunk in case the total number of items is not a multiple of + # chunk_size. + if len(items) > 0: + batches.append(_stac_items_to_arrow(items, schema=schema)) + + table = pa.Table.from_batches(batches, schema=schema) + return _process_arrow_table(table, downcast=downcast) + + +def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table: + table = _bring_properties_to_top_level(table) + table = _convert_timestamp_columns(table) + table = _convert_bbox_to_struct(table, downcast=downcast) + return table + + +def _stac_items_to_arrow( + items: Sequence[Dict[str, Any]], *, schema: Optional[pa.Schema] = None +) -> pa.RecordBatch: + """Convert dicts representing STAC Items to Arrow + + This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple + geometry types. + + All items will be parsed into a single RecordBatch, meaning that each internal array + is fully contiguous in memory for the length of `items`. + + Args: + items: STAC Items to convert to Arrow + + Kwargs: + schema: An optional schema that describes the format of the data. Note that this + must represent the geometry column as binary type. + + Returns: + Arrow RecordBatch with items in Arrow + """ + # Preprocess GeoJSON to WKB in each STAC item + # Otherwise, pyarrow will try to parse coordinates into a native geometry type and + # if you have multiple geometry types pyarrow will error with + # `ArrowInvalid: cannot mix list and non-list, non-null values` + wkb_items = [] + for item in items: + wkb_item = deepcopy(item) + # Note: this mutates the existing items. Should we + wkb_item["geometry"] = shapely.to_wkb( + shapely.geometry.shape(wkb_item["geometry"]), flavor="iso" + ) + wkb_items.append(wkb_item) + + if schema is not None: + array = pa.array(wkb_items, type=pa.struct(schema)) + else: + array = pa.array(wkb_items) + return pa.RecordBatch.from_struct_array(array) + + +def _bring_properties_to_top_level(table: pa.Table) -> pa.Table: + """Bring all the fields inside of the nested "properties" struct to the top level""" + properties_field = table.schema.field("properties") + properties_column = table["properties"] + + for field_idx in range(properties_field.type.num_fields): + inner_prop_field = properties_field.type.field(field_idx) + table = table.append_column( + inner_prop_field, pc.struct_field(properties_column, field_idx) + ) + + table = table.drop("properties") + return table + + +def _convert_geometry_to_wkb(table: pa.Table) -> pa.Table: + """Convert the geometry column in the table to WKB""" + geoms = shapely.from_geojson( + [json.dumps(item) for item in table["geometry"].to_pylist()] + ) + wkb_geoms = shapely.to_wkb(geoms) + return table.drop("geometry").append_column("geometry", pa.array(wkb_geoms)) + + +def _convert_timestamp_columns(table: pa.Table) -> pa.Table: + """Convert all timestamp columns from a string to an Arrow Timestamp data type""" + allowed_column_names = { + "datetime", # common metadata + "start_datetime", + "end_datetime", + "created", + "updated", + "expires", # timestamps extension + "published", + "unpublished", + } + for column_name in allowed_column_names: + try: + column = table[column_name] + except KeyError: + continue + + field_index = table.schema.get_field_index(column_name) + + if pa.types.is_timestamp(column.type): + continue + + # STAC allows datetimes to be null. If all rows are null, the column type may be + # inferred as null. We cast this to a timestamp column. + elif pa.types.is_null(column.type): + table = table.set_column( + field_index, column_name, column.cast(pa.timestamp("us")) + ) + + elif pa.types.is_string(column.type): + table = table.set_column( + field_index, column_name, _convert_timestamp_column(column) + ) + else: + raise ValueError( + f"Inferred time column '{column_name}' was expected to be a string or" + f" timestamp data type but got {column.type}" + ) + + return table + + +def _convert_timestamp_column(column: pa.ChunkedArray) -> pa.ChunkedArray: + """Convert an individual timestamp column from string to a Timestamp type""" + chunks = [] + for chunk in column.chunks: + parsed_chunk: List[Optional[datetime]] = [] + for item in chunk: + if not item.is_valid: + parsed_chunk.append(None) + else: + parsed_chunk.append(ciso8601.parse_rfc3339(item.as_py())) + + pyarrow_chunk = pa.array(parsed_chunk) + chunks.append(pyarrow_chunk) + + return pa.chunked_array(chunks) + + +def _is_bbox_3d(bbox_col: pa.ChunkedArray) -> bool: + """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" + offsets_set = set() + for chunk in bbox_col.chunks: + offsets = chunk.offsets.to_numpy() + offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) + + if len(offsets_set) > 1: + raise ValueError("Mixed 2d-3d bounding boxes not yet supported") + + offset = list(offsets_set)[0] + if offset == 6: + return True + elif offset == 4: + return False + else: + raise ValueError(f"Unexpected bbox offset: {offset=}") + + +def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table: + """Convert bbox column to a struct representation + + Since the bbox in JSON is stored as an array, pyarrow automatically converts the + bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox + column as a StructArray, which allows for Parquet statistics to infer any spatial + partitioning in the dataset. + + Args: + table: _description_ + downcast: if True, will use float32 coordinates for the bounding boxes instead + of float64. Float rounding is applied to ensure the float32 bounding box + strictly contains the original float64 box. This is recommended when + possible to minimize file size. + + Returns: + New table + """ + bbox_col_idx = table.schema.get_field_index("bbox") + bbox_col = table.column(bbox_col_idx) + bbox_3d = _is_bbox_3d(bbox_col) + + new_chunks = [] + for chunk in bbox_col.chunks: + assert ( + pa.types.is_list(chunk.type) + or pa.types.is_large_list(chunk.type) + or pa.types.is_fixed_size_list(chunk.type) + ) + if bbox_3d: + coords = chunk.flatten().to_numpy().reshape(-1, 6) + else: + coords = chunk.flatten().to_numpy().reshape(-1, 4) + + if downcast: + coords = coords.astype(np.float32) + + if bbox_3d: + xmin = coords[:, 0] + ymin = coords[:, 1] + zmin = coords[:, 2] + xmax = coords[:, 3] + ymax = coords[:, 4] + zmax = coords[:, 5] + + if downcast: + # Round min values down to the next float32 value + # Round max values up to the next float32 value + xmin = np.nextafter(xmin, -np.Infinity) + ymin = np.nextafter(ymin, -np.Infinity) + zmin = np.nextafter(zmin, -np.Infinity) + xmax = np.nextafter(xmax, np.Infinity) + ymax = np.nextafter(ymax, np.Infinity) + zmax = np.nextafter(zmax, np.Infinity) + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + zmin, + xmax, + ymax, + zmax, + ], + names=[ + "xmin", + "ymin", + "zmin", + "xmax", + "ymax", + "zmax", + ], + ) + + else: + xmin = coords[:, 0] + ymin = coords[:, 1] + xmax = coords[:, 2] + ymax = coords[:, 3] + + if downcast: + # Round min values down to the next float32 value + # Round max values up to the next float32 value + xmin = np.nextafter(xmin, -np.Infinity) + ymin = np.nextafter(ymin, -np.Infinity) + xmax = np.nextafter(xmax, np.Infinity) + ymax = np.nextafter(ymax, np.Infinity) + + struct_arr = pa.StructArray.from_arrays( + [ + xmin, + ymin, + xmax, + ymax, + ], + names=[ + "xmin", + "ymin", + "xmax", + "ymax", + ], + ) + + new_chunks.append(struct_arr) + + return table.set_column(bbox_col_idx, "bbox", new_chunks) diff --git a/stac_geoparquet/arrow/_to_parquet.py b/stac_geoparquet/arrow/_to_parquet.py new file mode 100644 index 0000000..3641001 --- /dev/null +++ b/stac_geoparquet/arrow/_to_parquet.py @@ -0,0 +1,46 @@ +import json +from typing import Any + +import pyarrow as pa +import pyarrow.parquet as pq +from pyproj import CRS + +WGS84_CRS_JSON = CRS.from_epsg(4326).to_json_dict() + + +def to_parquet(table: pa.Table, where: Any, **kwargs: Any) -> None: + """Write an Arrow table with STAC data to GeoParquet + + This writes metadata compliant with GeoParquet 1.1. + + Args: + table: The table to write to Parquet + where: The destination for saving. + """ + # TODO: include bbox of geometries + column_meta = { + "encoding": "WKB", + # TODO: specify known geometry types + "geometry_types": [], + "crs": WGS84_CRS_JSON, + "edges": "planar", + "covering": { + "bbox": { + "xmin": ["bbox", "xmin"], + "ymin": ["bbox", "ymin"], + "xmax": ["bbox", "xmax"], + "ymax": ["bbox", "ymax"], + } + }, + } + geo_meta = { + "version": "1.1.0-dev", + "columns": {"geometry": column_meta}, + "primary_column": "geometry", + } + + metadata = table.schema.metadata or {} + metadata.update({b"geo": json.dumps(geo_meta).encode("utf-8")}) + table = table.replace_schema_metadata(metadata) + + pq.write_table(table, where, **kwargs) diff --git a/stac_geoparquet/from_arrow.py b/stac_geoparquet/from_arrow.py index f940864..dc19fca 100644 --- a/stac_geoparquet/from_arrow.py +++ b/stac_geoparquet/from_arrow.py @@ -1,166 +1,8 @@ -"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" +import warnings -import os -import json -from typing import Iterable, List, Union +warnings.warn( + "stac_geoparquet.from_arrow is deprecated. Please use stac_geoparquet.arrow instead.", + FutureWarning, +) -import numpy as np -import pyarrow as pa -import pyarrow.compute as pc -import shapely - - -def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None: - """Write a STAC Table to a newline-delimited JSON file.""" - with open(dest, "w") as f: - for item_dict in stac_table_to_items(table): - json.dump(item_dict, f, separators=(",", ":")) - f.write("\n") - - -def stac_table_to_items(table: pa.Table) -> Iterable[dict]: - """Convert a STAC Table to a generator of STAC Item `dict`s""" - table = _undo_stac_table_transformations(table) - - # Convert WKB geometry column to GeoJSON, and then assign the geojson geometry when - # converting each row to a dictionary. - for batch in table.to_batches(): - geoms = shapely.from_wkb(batch["geometry"]) - geojson_strings = shapely.to_geojson(geoms) - - # RecordBatch is missing a `drop()` method, so we keep all columns other than - # geometry instead - keep_column_names = [name for name in batch.column_names if name != "geometry"] - struct_batch = batch.select(keep_column_names).to_struct_array() - - for row_idx in range(len(struct_batch)): - row_dict = struct_batch[row_idx].as_py() - row_dict["geometry"] = json.loads(geojson_strings[row_idx]) - yield row_dict - - -def _undo_stac_table_transformations(table: pa.Table) -> pa.Table: - """Undo the transformations done to convert STAC Json into an Arrow Table - - Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation, - as that is easier to do when converting each item in the table to a dict. - """ - table = _convert_timestamp_columns_to_string(table) - table = _lower_properties_from_top_level(table) - table = _convert_bbox_to_array(table) - return table - - -def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table: - """Convert any datetime columns in the table to a string representation""" - allowed_column_names = { - "datetime", # common metadata - "start_datetime", - "end_datetime", - "created", - "updated", - "expires", # timestamps extension - "published", - "unpublished", - } - for column_name in allowed_column_names: - try: - column = table[column_name] - except KeyError: - continue - - table = table.drop(column_name).append_column( - column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ") - ) - - return table - - -def _lower_properties_from_top_level(table: pa.Table) -> pa.Table: - """Take properties columns from the top level and wrap them in a struct column""" - stac_top_level_keys = { - "stac_version", - "stac_extensions", - "type", - "id", - "bbox", - "geometry", - "collection", - "links", - "assets", - } - - properties_column_names: List[str] = [] - properties_column_fields: List[pa.Field] = [] - for column_idx in range(table.num_columns): - column_name = table.column_names[column_idx] - if column_name in stac_top_level_keys: - continue - - properties_column_names.append(column_name) - properties_column_fields.append(table.schema.field(column_idx)) - - properties_array_chunks = [] - for batch in table.select(properties_column_names).to_batches(): - struct_arr = pa.StructArray.from_arrays( - batch.columns, fields=properties_column_fields - ) - properties_array_chunks.append(struct_arr) - - return table.drop_columns(properties_column_names).append_column( - "properties", pa.chunked_array(properties_array_chunks) - ) - - -def _convert_bbox_to_array(table: pa.Table) -> pa.Table: - """Convert the struct bbox column back to a list column for writing to JSON""" - - bbox_col_idx = table.schema.get_field_index("bbox") - bbox_col = table.column(bbox_col_idx) - - new_chunks = [] - for chunk in bbox_col.chunks: - assert pa.types.is_struct(chunk.type) - - if bbox_col.type.num_fields == 4: - xmin = chunk.field("xmin").to_numpy() - ymin = chunk.field("ymin").to_numpy() - xmax = chunk.field("xmax").to_numpy() - ymax = chunk.field("ymax").to_numpy() - coords = np.column_stack( - [ - xmin, - ymin, - xmax, - ymax, - ] - ) - - list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4) - - elif bbox_col.type.num_fields == 6: - xmin = chunk.field("xmin").to_numpy() - ymin = chunk.field("ymin").to_numpy() - zmin = chunk.field("zmin").to_numpy() - xmax = chunk.field("xmax").to_numpy() - ymax = chunk.field("ymax").to_numpy() - zmax = chunk.field("zmax").to_numpy() - coords = np.column_stack( - [ - xmin, - ymin, - zmin, - xmax, - ymax, - zmax, - ] - ) - - list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6) - - else: - raise ValueError("Expected 4 or 6 fields in bbox struct.") - - new_chunks.append(list_arr) - - return table.set_column(bbox_col_idx, "bbox", new_chunks) +from stac_geoparquet.arrow._from_arrow import * # noqa diff --git a/stac_geoparquet/to_arrow.py b/stac_geoparquet/to_arrow.py index 1cc36e7..9b3f81d 100644 --- a/stac_geoparquet/to_arrow.py +++ b/stac_geoparquet/to_arrow.py @@ -1,364 +1,8 @@ -"""Convert STAC data into Arrow tables""" +import warnings -import json -from copy import deepcopy -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Union, Generator +warnings.warn( + "stac_geoparquet.to_arrow is deprecated. Please use stac_geoparquet.arrow instead.", + FutureWarning, +) -import ciso8601 -import numpy as np -import pyarrow as pa -import pyarrow.compute as pc -import shapely -import shapely.geometry - - -def _chunks( - lst: Sequence[Dict[str, Any]], n: int -) -> Generator[Sequence[Dict[str, Any]], None, None]: - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - -def parse_stac_items_to_arrow( - items: Sequence[Dict[str, Any]], - *, - chunk_size: int = 8192, - schema: Optional[pa.Schema] = None, - downcast: bool = True, -) -> pa.Table: - """Parse a collection of STAC Items to a :class:`pyarrow.Table`. - - The objects under `properties` are moved up to the top-level of the - Table, similar to :meth:`geopandas.GeoDataFrame.from_features`. - - Args: - items: the STAC Items to convert - chunk_size: The chunk size to use for Arrow record batches. This only takes - effect if `schema` is not None. When `schema` is None, the input will be - parsed into a single contiguous record batch. Defaults to 8192. - schema: The schema of the input data. If provided, can improve memory use; - otherwise all items need to be parsed into a single array for schema - inference. Defaults to None. - downcast: if True, store bbox as float32 for memory and disk saving. - - Returns: - a pyarrow Table with the STAC-GeoParquet representation of items. - """ - - if schema is not None: - # If schema is provided, then for better memory usage we parse input STAC items - # to Arrow batches in chunks. - batches = [] - for chunk in _chunks(items, chunk_size): - batches.append(_stac_items_to_arrow(chunk, schema=schema)) - - table = pa.Table.from_batches(batches, schema=schema) - else: - # If schema is _not_ provided, then we must convert to Arrow all at once, or - # else it would be possible for a STAC item late in the collection (after the - # first chunk) to have a different schema and not match the schema inferred for - # the first chunk. - table = pa.Table.from_batches([_stac_items_to_arrow(items)]) - - return _process_arrow_table(table, downcast=downcast) - - -def parse_stac_ndjson_to_arrow( - path: Union[str, Path], - *, - chunk_size: int = 8192, - schema: Optional[pa.Schema] = None, - downcast: bool = True, -) -> pa.Table: - # Define outside of if/else to make mypy happy - items: List[dict] = [] - - # If the schema was not provided, then we need to load all data into memory at once - # to perform schema resolution. - if schema is None: - with open(path) as f: - for line in f: - items.append(json.loads(line)) - - return parse_stac_items_to_arrow(items, chunk_size=chunk_size, schema=schema) - - # Otherwise, we can stream over the input, converting each batch of `chunk_size` - # into an Arrow RecordBatch at a time. This is much more memory efficient. - with open(path) as f: - batches: List[pa.RecordBatch] = [] - for line in f: - items.append(json.loads(line)) - - if len(items) >= chunk_size: - batches.append(_stac_items_to_arrow(items, schema=schema)) - items = [] - - # Don't forget the last chunk in case the total number of items is not a multiple of - # chunk_size. - if len(items) > 0: - batches.append(_stac_items_to_arrow(items, schema=schema)) - - table = pa.Table.from_batches(batches, schema=schema) - return _process_arrow_table(table, downcast=downcast) - - -def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table: - table = _bring_properties_to_top_level(table) - table = _convert_timestamp_columns(table) - table = _convert_bbox_to_struct(table, downcast=downcast) - return table - - -def _stac_items_to_arrow( - items: Sequence[Dict[str, Any]], *, schema: Optional[pa.Schema] = None -) -> pa.RecordBatch: - """Convert dicts representing STAC Items to Arrow - - This converts GeoJSON geometries to WKB before Arrow conversion to allow multiple - geometry types. - - All items will be parsed into a single RecordBatch, meaning that each internal array - is fully contiguous in memory for the length of `items`. - - Args: - items: STAC Items to convert to Arrow - - Kwargs: - schema: An optional schema that describes the format of the data. Note that this - must represent the geometry column as binary type. - - Returns: - Arrow RecordBatch with items in Arrow - """ - # Preprocess GeoJSON to WKB in each STAC item - # Otherwise, pyarrow will try to parse coordinates into a native geometry type and - # if you have multiple geometry types pyarrow will error with - # `ArrowInvalid: cannot mix list and non-list, non-null values` - wkb_items = [] - for item in items: - wkb_item = deepcopy(item) - # Note: this mutates the existing items. Should we - wkb_item["geometry"] = shapely.to_wkb( - shapely.geometry.shape(wkb_item["geometry"]), flavor="iso" - ) - wkb_items.append(wkb_item) - - if schema is not None: - array = pa.array(wkb_items, type=pa.struct(schema)) - else: - array = pa.array(wkb_items) - return pa.RecordBatch.from_struct_array(array) - - -def _bring_properties_to_top_level(table: pa.Table) -> pa.Table: - """Bring all the fields inside of the nested "properties" struct to the top level""" - properties_field = table.schema.field("properties") - properties_column = table["properties"] - - for field_idx in range(properties_field.type.num_fields): - inner_prop_field = properties_field.type.field(field_idx) - table = table.append_column( - inner_prop_field, pc.struct_field(properties_column, field_idx) - ) - - table = table.drop("properties") - return table - - -def _convert_geometry_to_wkb(table: pa.Table) -> pa.Table: - """Convert the geometry column in the table to WKB""" - geoms = shapely.from_geojson( - [json.dumps(item) for item in table["geometry"].to_pylist()] - ) - wkb_geoms = shapely.to_wkb(geoms) - return table.drop("geometry").append_column("geometry", pa.array(wkb_geoms)) - - -def _convert_timestamp_columns(table: pa.Table) -> pa.Table: - """Convert all timestamp columns from a string to an Arrow Timestamp data type""" - allowed_column_names = { - "datetime", # common metadata - "start_datetime", - "end_datetime", - "created", - "updated", - "expires", # timestamps extension - "published", - "unpublished", - } - for column_name in allowed_column_names: - try: - column = table[column_name] - except KeyError: - continue - - field_index = table.schema.get_field_index(column_name) - - if pa.types.is_timestamp(column.type): - continue - - # STAC allows datetimes to be null. If all rows are null, the column type may be - # inferred as null. We cast this to a timestamp column. - elif pa.types.is_null(column.type): - table = table.set_column( - field_index, column_name, column.cast(pa.timestamp("us")) - ) - - elif pa.types.is_string(column.type): - table = table.set_column( - field_index, column_name, _convert_timestamp_column(column) - ) - else: - raise ValueError( - f"Inferred time column '{column_name}' was expected to be a string or" - f" timestamp data type but got {column.type}" - ) - - return table - - -def _convert_timestamp_column(column: pa.ChunkedArray) -> pa.ChunkedArray: - """Convert an individual timestamp column from string to a Timestamp type""" - chunks = [] - for chunk in column.chunks: - parsed_chunk: List[Optional[datetime]] = [] - for item in chunk: - if not item.is_valid: - parsed_chunk.append(None) - else: - parsed_chunk.append(ciso8601.parse_rfc3339(item.as_py())) - - pyarrow_chunk = pa.array(parsed_chunk) - chunks.append(pyarrow_chunk) - - return pa.chunked_array(chunks) - - -def is_bbox_3d(bbox_col: pa.ChunkedArray) -> bool: - """Infer whether the bounding box column represents 2d or 3d bounding boxes.""" - offsets_set = set() - for chunk in bbox_col.chunks: - offsets = chunk.offsets.to_numpy() - offsets_set.update(np.unique(offsets[1:] - offsets[:-1])) - - if len(offsets_set) > 1: - raise ValueError("Mixed 2d-3d bounding boxes not yet supported") - - offset = list(offsets_set)[0] - if offset == 6: - return True - elif offset == 4: - return False - else: - raise ValueError(f"Unexpected bbox offset: {offset=}") - - -def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table: - """Convert bbox column to a struct representation - - Since the bbox in JSON is stored as an array, pyarrow automatically converts the - bbox column to a ListArray. But according to GeoParquet 1.1, we should save the bbox - column as a StructArray, which allows for Parquet statistics to infer any spatial - partitioning in the dataset. - - Args: - table: _description_ - downcast: if True, will use float32 coordinates for the bounding boxes instead - of float64. Float rounding is applied to ensure the float32 bounding box - strictly contains the original float64 box. This is recommended when - possible to minimize file size. - - Returns: - New table - """ - bbox_col_idx = table.schema.get_field_index("bbox") - bbox_col = table.column(bbox_col_idx) - bbox_3d = is_bbox_3d(bbox_col) - - new_chunks = [] - for chunk in bbox_col.chunks: - assert ( - pa.types.is_list(chunk.type) - or pa.types.is_large_list(chunk.type) - or pa.types.is_fixed_size_list(chunk.type) - ) - if bbox_3d: - coords = chunk.flatten().to_numpy().reshape(-1, 6) - else: - coords = chunk.flatten().to_numpy().reshape(-1, 4) - - if downcast: - coords = coords.astype(np.float32) - - if bbox_3d: - xmin = coords[:, 0] - ymin = coords[:, 1] - zmin = coords[:, 2] - xmax = coords[:, 3] - ymax = coords[:, 4] - zmax = coords[:, 5] - - if downcast: - # Round min values down to the next float32 value - # Round max values up to the next float32 value - xmin = np.nextafter(xmin, -np.Infinity) - ymin = np.nextafter(ymin, -np.Infinity) - zmin = np.nextafter(zmin, -np.Infinity) - xmax = np.nextafter(xmax, np.Infinity) - ymax = np.nextafter(ymax, np.Infinity) - zmax = np.nextafter(zmax, np.Infinity) - - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - zmin, - xmax, - ymax, - zmax, - ], - names=[ - "xmin", - "ymin", - "zmin", - "xmax", - "ymax", - "zmax", - ], - ) - - else: - xmin = coords[:, 0] - ymin = coords[:, 1] - xmax = coords[:, 2] - ymax = coords[:, 3] - - if downcast: - # Round min values down to the next float32 value - # Round max values up to the next float32 value - xmin = np.nextafter(xmin, -np.Infinity) - ymin = np.nextafter(ymin, -np.Infinity) - xmax = np.nextafter(xmax, np.Infinity) - ymax = np.nextafter(ymax, np.Infinity) - - struct_arr = pa.StructArray.from_arrays( - [ - xmin, - ymin, - xmax, - ymax, - ], - names=[ - "xmin", - "ymin", - "xmax", - "ymax", - ], - ) - - new_chunks.append(struct_arr) - - return table.set_column(bbox_col_idx, "bbox", new_chunks) +from stac_geoparquet.arrow._to_arrow import * # noqa diff --git a/stac_geoparquet/to_parquet.py b/stac_geoparquet/to_parquet.py index 3641001..6457152 100644 --- a/stac_geoparquet/to_parquet.py +++ b/stac_geoparquet/to_parquet.py @@ -1,46 +1,8 @@ -import json -from typing import Any +import warnings -import pyarrow as pa -import pyarrow.parquet as pq -from pyproj import CRS +warnings.warn( + "stac_geoparquet.to_parquet is deprecated. Please use stac_geoparquet.arrow instead.", + FutureWarning, +) -WGS84_CRS_JSON = CRS.from_epsg(4326).to_json_dict() - - -def to_parquet(table: pa.Table, where: Any, **kwargs: Any) -> None: - """Write an Arrow table with STAC data to GeoParquet - - This writes metadata compliant with GeoParquet 1.1. - - Args: - table: The table to write to Parquet - where: The destination for saving. - """ - # TODO: include bbox of geometries - column_meta = { - "encoding": "WKB", - # TODO: specify known geometry types - "geometry_types": [], - "crs": WGS84_CRS_JSON, - "edges": "planar", - "covering": { - "bbox": { - "xmin": ["bbox", "xmin"], - "ymin": ["bbox", "ymin"], - "xmax": ["bbox", "xmax"], - "ymax": ["bbox", "ymax"], - } - }, - } - geo_meta = { - "version": "1.1.0-dev", - "columns": {"geometry": column_meta}, - "primary_column": "geometry", - } - - metadata = table.schema.metadata or {} - metadata.update({b"geo": json.dumps(geo_meta).encode("utf-8")}) - table = table.replace_schema_metadata(metadata) - - pq.write_table(table, where, **kwargs) +from stac_geoparquet.arrow._to_parquet import * # noqa diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 8daac78..0cc1803 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -6,8 +6,7 @@ import pytest from ciso8601 import parse_rfc3339 -from stac_geoparquet.from_arrow import stac_table_to_items -from stac_geoparquet.to_arrow import parse_stac_items_to_arrow +from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_items HERE = Path(__file__).parent @@ -209,3 +208,23 @@ def test_round_trip(collection_id: str): for result, expected in zip(items_result, items): assert_json_value_equal(result, expected, precision=0) + + +def test_to_arrow_deprecated(): + with pytest.warns(FutureWarning): + import stac_geoparquet.to_arrow + stac_geoparquet.to_arrow.parse_stac_items_to_arrow + + +def test_to_parquet_deprecated(): + with pytest.warns(FutureWarning): + import stac_geoparquet.to_parquet + + stac_geoparquet.to_parquet.to_parquet + + +def test_from_arrow_deprecated(): + with pytest.warns(FutureWarning): + import stac_geoparquet.from_arrow + + stac_geoparquet.from_arrow.stac_table_to_items