STAC Interoperability with Arrow #37

Merged 18 commits on Apr 17, 2024
30 changes: 12 additions & 18 deletions pyproject.toml
@@ -4,19 +4,20 @@ build-backend = "hatchling.build"
 
 [project]
 name = "stac_geoparquet"
-authors = [{name = "Tom Augspurger", email = "[email protected]"}]
+authors = [{ name = "Tom Augspurger", email = "[email protected]" }]
 readme = "README.md"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
 classifiers = ["License :: OSI Approved :: MIT License"]
 dynamic = ["version", "description"]
 requires-python = ">=3.8"
 dependencies = [
-    "pystac",
+    "ciso8601",
     "geopandas",
+    "packaging",
     "pandas",
     "pyarrow",
+    "pystac",
     "shapely",
-    "packaging",
 ]
 
 [tool.hatch.version]
@@ -33,13 +34,7 @@ pgstac = [
     "tqdm",
     "python-dateutil",
 ]
-pc = [
-    "adlfs",
-    "pypgstac",
-    "psycopg[binary,pool]",
-    "tqdm",
-    "azure-data-tables",
-]
+pc = ["adlfs", "pypgstac", "psycopg[binary,pool]", "tqdm", "azure-data-tables"]
 test = [
     "pytest",
     "requests",
@@ -57,24 +52,23 @@ pc-geoparquet = "stac_geoparquet.cli:main"
 
 [tool.pytest.ini_options]
 minversion = "6.0"
-filterwarnings = [
-    "ignore:.*distutils Version.*:DeprecationWarning",
-]
+filterwarnings = ["ignore:.*distutils Version.*:DeprecationWarning"]
 
 [tool.mypy]
 
 python_version = "3.10"
 
 [[tool.mypy.overrides]]
 module = [
-    "shapely.*",
+    "ciso8601.*",
+    "fsspec.*",
     "geopandas.*",
     "pandas.*",
-    "fsspec.*",
-    "tqdm.*",
-    "pypgstac.*",
     "pyarrow.*",
+    "pypgstac.*",
     "rich.*",
+    "shapely.*",
+    "tqdm.*",
 ]
 
 ignore_missing_imports = true
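
The `[project.optional-dependencies]` tables above define pip extras. As a minimal illustration (the editable-install flag is incidental, not part of this PR), a local checkout could pull in one of these groups with:

    pip install -e '.[pgstac]'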
139 changes: 139 additions & 0 deletions stac_geoparquet/from_arrow.py
@@ -0,0 +1,139 @@
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""

import json
from typing import Iterable, List

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import shapely


def stac_table_to_ndjson(table: pa.Table, dest: str):
"""Write a STAC Table to a newline-delimited JSON file."""
with open(dest, "w") as f:
for item_dict in stac_table_to_items(table):
json.dump(item_dict, f, separators=(",", ":"))
f.write("\n")


def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
    """Convert a STAC Table to a generator of STAC Item `dict`s."""
    table = _undo_stac_table_transformations(table)

    # Convert the WKB geometry column to GeoJSON, then assign the GeoJSON geometry
    # while converting each row to a dictionary.
    for batch in table.to_batches():
        geoms = shapely.from_wkb(batch["geometry"])
        geojson_strings = shapely.to_geojson(geoms)

        # RecordBatch is missing a `drop()` method, so we select every column other
        # than geometry instead.
        keep_column_names = [name for name in batch.column_names if name != "geometry"]
        struct_batch = batch.select(keep_column_names).to_struct_array()

        for row_idx in range(len(struct_batch)):
            row_dict = struct_batch[row_idx].as_py()
            row_dict["geometry"] = json.loads(geojson_strings[row_idx])
            yield row_dict


def _undo_stac_table_transformations(table: pa.Table) -> pa.Table:
    """Undo the transformations done to convert STAC JSON into an Arrow Table.

    Note that this function does _not_ undo the GeoJSON -> WKB geometry
    transformation, as that is easier to do when converting each item in the
    table to a dict.
    """
    table = _convert_timestamp_columns_to_string(table)
    table = _lower_properties_from_top_level(table)
    table = _convert_bbox_to_array(table)
    return table


def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table:
    """Convert any datetime columns in the table to a string representation."""
    allowed_column_names = {
        "datetime",  # common metadata
        "start_datetime",
        "end_datetime",
        "created",
        "updated",
        "expires",  # timestamps extension
        "published",
        "unpublished",
    }
    for column_name in allowed_column_names:
        try:
            column = table[column_name]
        except KeyError:
            continue

        table = table.drop_columns(column_name).append_column(
            column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ")
        )

    return table


def _lower_properties_from_top_level(table: pa.Table) -> pa.Table:
    """Take properties columns from the top level and wrap them in a struct column."""
    stac_top_level_keys = {
        "stac_version",
        "stac_extensions",
        "type",
        "id",
        "bbox",
        "geometry",
        "collection",
        "links",
        "assets",
    }

    properties_column_names: List[str] = []
    properties_column_fields: List[pa.Field] = []
    for column_idx in range(table.num_columns):
        column_name = table.column_names[column_idx]
        if column_name in stac_top_level_keys:
            continue

        properties_column_names.append(column_name)
        properties_column_fields.append(table.schema.field(column_idx))

    properties_array_chunks = []
    for batch in table.select(properties_column_names).to_batches():
        struct_arr = pa.StructArray.from_arrays(
            batch.columns, fields=properties_column_fields
        )
        properties_array_chunks.append(struct_arr)

    return table.drop_columns(properties_column_names).append_column(
        "properties", pa.chunked_array(properties_array_chunks)
    )


def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
    """Convert the struct bbox column back to a list column for writing to JSON."""

    bbox_col_idx = table.schema.get_field_index("bbox")
    bbox_col = table.column(bbox_col_idx)

    new_chunks = []
    for chunk in bbox_col.chunks:
        assert pa.types.is_struct(chunk.type)
        # The struct fields are assumed to be ordered (xmin, ymin, xmax, ymax).
        xmin = chunk.field(0).to_numpy()
        ymin = chunk.field(1).to_numpy()
        xmax = chunk.field(2).to_numpy()
        ymax = chunk.field(3).to_numpy()
        coords = np.column_stack([xmin, ymin, xmax, ymax])

        list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)
        new_chunks.append(list_arr)

    return table.set_column(bbox_col_idx, "bbox", new_chunks)
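
For context, a minimal usage sketch of the new module. The file name `items.parquet` and the use of `pyarrow.parquet.read_table` are assumptions for illustration, not part of this PR; the sketch only assumes a table whose schema matches the transformations this module undoes (WKB geometry, struct bbox, timestamp columns, properties lifted to the top level).

# A minimal round-trip sketch (assumed usage; not part of this PR).
import pyarrow.parquet as pq

from stac_geoparquet.from_arrow import stac_table_to_items, stac_table_to_ndjson

# "items.parquet" is a hypothetical STAC GeoParquet file.
table = pq.read_table("items.parquet")

# Stream the rows back out as STAC Item dicts. "id" stays top-level, while
# "datetime" (if present in the source table) is lowered into "properties".
for item in stac_table_to_items(table):
    print(item["id"], item["properties"]["datetime"])

# Or write the whole table to newline-delimited JSON in one call.
stac_table_to_ndjson(table, "items.ndjson")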