Move arrow-based code into arrow module (#47)
* Move arrow-based code into arrow module

* Fix test imports

* Deprecation

---------

Co-authored-by: Tom Augspurger <[email protected]>
kylebarron and Tom Augspurger authored Apr 24, 2024
1 parent fd7c9a4 commit 5c0a682
Showing 9 changed files with 620 additions and 574 deletions.
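Concretely, the refactor changes where the Arrow-based helpers are imported from. Below is a minimal sketch of the new import surface, assuming only the module layout visible in the diffs that follow:

# After this commit, the Arrow-based helpers are grouped under the `arrow`
# submodule (names taken from stac_geoparquet/arrow/__init__.py below):
from stac_geoparquet.arrow import (
    parse_stac_items_to_arrow,   # STAC Item dicts -> Arrow
    parse_stac_ndjson_to_arrow,  # newline-delimited JSON -> Arrow
    stac_table_to_items,         # Arrow Table -> STAC Item dicts
    stac_table_to_ndjson,        # Arrow Table -> newline-delimited JSON
    to_parquet,                  # Arrow Table -> GeoParquet
)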
4 changes: 2 additions & 2 deletions stac_geoparquet/__init__.py
@@ -1,8 +1,8 @@
 """stac-geoparquet"""
 
-from .stac_geoparquet import to_geodataframe, to_dict, to_item_collection
+from . import arrow
 from ._version import __version__
 
+from .stac_geoparquet import to_dict, to_geodataframe, to_item_collection
 
 __all__ = [
     "__version__",
3 changes: 3 additions & 0 deletions stac_geoparquet/arrow/__init__.py
@@ -0,0 +1,3 @@
from ._from_arrow import stac_table_to_items, stac_table_to_ndjson
from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow
from ._to_parquet import to_parquet
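To make the new public surface concrete, here is a hedged end-to-end sketch; the sample item is hypothetical and the exact signatures are assumed from the names re-exported above:

import stac_geoparquet.arrow

# A hypothetical, minimal STAC Item to push through the pipeline.
items = [
    {
        "type": "Feature",
        "stac_version": "1.0.0",
        "id": "example-item",
        "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
        "bbox": [0.0, 0.0, 0.0, 0.0],
        "properties": {"datetime": "2024-04-24T00:00:00Z"},
        "links": [],
        "assets": {},
    }
]

# Assumed usage: parse Item dicts into an Arrow table, then write GeoParquet.
table = stac_geoparquet.arrow.parse_stac_items_to_arrow(items)
stac_geoparquet.arrow.to_parquet(table, "items.parquet")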
166 changes: 166 additions & 0 deletions stac_geoparquet/arrow/_from_arrow.py
@@ -0,0 +1,166 @@
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""

import os
import json
from typing import Iterable, List, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import shapely


def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
with open(dest, "w") as f:
for item_dict in stac_table_to_items(table):
json.dump(item_dict, f, separators=(",", ":"))
f.write("\n")


def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
"""Convert a STAC Table to a generator of STAC Item `dict`s"""
table = _undo_stac_table_transformations(table)

# Convert WKB geometry column to GeoJSON, and then assign the geojson geometry when
# converting each row to a dictionary.
for batch in table.to_batches():
geoms = shapely.from_wkb(batch["geometry"])
geojson_strings = shapely.to_geojson(geoms)

# RecordBatch is missing a `drop()` method, so we keep all columns other than
# geometry instead
keep_column_names = [name for name in batch.column_names if name != "geometry"]
struct_batch = batch.select(keep_column_names).to_struct_array()

for row_idx in range(len(struct_batch)):
row_dict = struct_batch[row_idx].as_py()
row_dict["geometry"] = json.loads(geojson_strings[row_idx])
yield row_dict


def _undo_stac_table_transformations(table: pa.Table) -> pa.Table:
"""Undo the transformations done to convert STAC Json into an Arrow Table
Note that this function does _not_ undo the GeoJSON -> WKB geometry transformation,
as that is easier to do when converting each item in the table to a dict.
"""
table = _convert_timestamp_columns_to_string(table)
table = _lower_properties_from_top_level(table)
table = _convert_bbox_to_array(table)
return table


def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table:
"""Convert any datetime columns in the table to a string representation"""
allowed_column_names = {
"datetime", # common metadata
"start_datetime",
"end_datetime",
"created",
"updated",
"expires", # timestamps extension
"published",
"unpublished",
}
for column_name in allowed_column_names:
try:
column = table[column_name]
except KeyError:
continue

table = table.drop(column_name).append_column(
column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ")
)

return table


def _lower_properties_from_top_level(table: pa.Table) -> pa.Table:
"""Take properties columns from the top level and wrap them in a struct column"""
stac_top_level_keys = {
"stac_version",
"stac_extensions",
"type",
"id",
"bbox",
"geometry",
"collection",
"links",
"assets",
}

properties_column_names: List[str] = []
properties_column_fields: List[pa.Field] = []
for column_idx in range(table.num_columns):
column_name = table.column_names[column_idx]
if column_name in stac_top_level_keys:
continue

properties_column_names.append(column_name)
properties_column_fields.append(table.schema.field(column_idx))

properties_array_chunks = []
for batch in table.select(properties_column_names).to_batches():
struct_arr = pa.StructArray.from_arrays(
batch.columns, fields=properties_column_fields
)
properties_array_chunks.append(struct_arr)

return table.drop_columns(properties_column_names).append_column(
"properties", pa.chunked_array(properties_array_chunks)
)


def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
"""Convert the struct bbox column back to a list column for writing to JSON"""

bbox_col_idx = table.schema.get_field_index("bbox")
bbox_col = table.column(bbox_col_idx)

new_chunks = []
for chunk in bbox_col.chunks:
assert pa.types.is_struct(chunk.type)

if bbox_col.type.num_fields == 4:
xmin = chunk.field("xmin").to_numpy()
ymin = chunk.field("ymin").to_numpy()
xmax = chunk.field("xmax").to_numpy()
ymax = chunk.field("ymax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)

elif bbox_col.type.num_fields == 6:
xmin = chunk.field("xmin").to_numpy()
ymin = chunk.field("ymin").to_numpy()
zmin = chunk.field("zmin").to_numpy()
xmax = chunk.field("xmax").to_numpy()
ymax = chunk.field("ymax").to_numpy()
zmax = chunk.field("zmax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6)

else:
raise ValueError("Expected 4 or 6 fields in bbox struct.")

new_chunks.append(list_arr)

return table.set_column(bbox_col_idx, "bbox", new_chunks)
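And the reverse path that this file implements, as a small sketch; assume `table` is a STAC Arrow table such as the one produced in the previous sketch:

import pyarrow as pa

from stac_geoparquet.arrow import stac_table_to_items, stac_table_to_ndjson


def dump_items(table: pa.Table) -> None:
    # stac_table_to_items lazily yields one dict per row, decoding the WKB
    # geometry column back to GeoJSON along the way.
    for item in stac_table_to_items(table):
        print(item["id"], item["geometry"]["type"])

    # stac_table_to_ndjson streams the same rows to a newline-delimited
    # JSON file.
    stac_table_to_ndjson(table, "items.ndjson")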