-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move arrow-based code into arrow module (#47)
* Move arrow-based code into arrow module * fix tests import * deprecation --------- Co-authored-by: Tom Augspurger <[email protected]>
- Loading branch information
1 parent
fd7c9a4
commit 5c0a682
Showing
9 changed files
with
620 additions
and
574 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._from_arrow import stac_table_to_items, stac_table_to_ndjson | ||
from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow | ||
from ._to_parquet import to_parquet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts.""" | ||
|
||
import os | ||
import json | ||
from typing import Iterable, List, Union | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import pyarrow.compute as pc | ||
import shapely | ||
|
||
|
||
def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
    """Serialize every STAC Item in ``table`` to ``dest`` as JSON Lines.

    Args:
        table: Arrow table of STAC Items; it is handed to
            ``stac_table_to_items`` to obtain one ``dict`` per item.
        dest: Path of the output file. It is opened in text write mode, so
            any existing content is replaced.
    """
    compact_separators = (",", ":")
    with open(dest, "w") as out_file:
        for stac_item in stac_table_to_items(table):
            # One compact JSON document per line, each terminated by a newline.
            json.dump(stac_item, out_file, separators=compact_separators)
            out_file.write("\n")
|
||
|
||
def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
    """Convert a STAC Table to a generator of STAC Item `dict`s.

    The table is first passed through ``_undo_stac_table_transformations``.
    Each record batch is then turned into row dictionaries, and the WKB
    ``geometry`` column is replaced by its parsed GeoJSON equivalent.

    Args:
        table: Arrow table of STAC Items with a WKB-encoded ``geometry`` column.

    Yields:
        One ``dict`` per row. ``geometry`` holds the parsed GeoJSON object, or
        ``None`` when no GeoJSON string was produced for that row (a missing
        geometry) instead of failing inside ``json.loads``.
    """
    table = _undo_stac_table_transformations(table)

    # Convert WKB geometry column to GeoJSON, and then assign the geojson geometry when
    # converting each row to a dictionary.
    for batch in table.to_batches():
        geoms = shapely.from_wkb(batch["geometry"])
        geojson_strings = shapely.to_geojson(geoms)

        # RecordBatch is missing a `drop()` method, so we keep all columns other than
        # geometry instead
        keep_column_names = [name for name in batch.column_names if name != "geometry"]
        struct_batch = batch.select(keep_column_names).to_struct_array()

        # Rows and GeoJSON strings are aligned by position within the batch.
        for row, geojson in zip(struct_batch, geojson_strings):
            row_dict = row.as_py()
            # A missing geometry has no GeoJSON string to parse; emit null.
            row_dict["geometry"] = None if geojson is None else json.loads(geojson)
            yield row_dict
|
||
|
||
def _undo_stac_table_transformations(table: pa.Table) -> pa.Table:
    """Reverse the transformations applied when STAC JSON became an Arrow Table.

    The steps run in this order: timestamp columns are converted to strings,
    property columns are wrapped into a ``properties`` struct, and the ``bbox``
    struct is converted to a list column.

    Note that the GeoJSON -> WKB geometry transformation is _not_ undone here,
    as that is easier to do when converting each item in the table to a dict.
    """
    undo_steps = (
        _convert_timestamp_columns_to_string,
        _lower_properties_from_top_level,
        _convert_bbox_to_array,
    )
    for undo_step in undo_steps:
        table = undo_step(table)
    return table
|
||
|
||
def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table:
    """Convert the known datetime columns in the table to a string representation.

    Only the column names listed below are considered; any that are absent
    from the table are skipped. Each converted column is replaced in place, so
    the table keeps its column order and the result is deterministic.
    (Dropping and re-appending while iterating a ``set`` would move the
    columns to the end in an order that depends on string hashing.)

    Args:
        table: Arrow table that may contain timestamp columns.

    Returns:
        A table in which every listed column that is present has been
        formatted with ``%Y-%m-%dT%H:%M:%SZ``.
    """
    # A tuple (not a set) so that iteration order is fixed.
    timestamp_column_names = (
        "datetime",  # common metadata
        "start_datetime",
        "end_datetime",
        "created",
        "updated",
        "expires",  # timestamps extension
        "published",
        "unpublished",
    )
    for column_name in timestamp_column_names:
        # get_field_index returns -1 when the name is absent (or not unique).
        column_idx = table.schema.get_field_index(column_name)
        if column_idx < 0:
            continue

        formatted = pc.strftime(
            table.column(column_idx), format="%Y-%m-%dT%H:%M:%SZ"
        )
        table = table.set_column(column_idx, column_name, formatted)

    return table
|
||
|
||
def _lower_properties_from_top_level(table: pa.Table) -> pa.Table:
    """Take properties columns from the top level and wrap them in a struct column.

    Every column whose name is not one of the STAC top-level keys listed below
    is treated as a property. Those columns are removed from the table and
    re-added as the fields of a single ``properties`` struct column, appended
    at the end.

    Args:
        table: Arrow table with property columns at the top level.

    Returns:
        The table with the property columns replaced by one ``properties``
        struct column.
    """
    stac_top_level_keys = {
        "stac_version",
        "stac_extensions",
        "type",
        "id",
        "bbox",
        "geometry",
        "collection",
        "links",
        "assets",
    }

    properties_column_fields: List[pa.Field] = [
        field for field in table.schema if field.name not in stac_top_level_keys
    ]
    properties_column_names: List[str] = [
        field.name for field in properties_column_fields
    ]

    properties_array_chunks = [
        pa.StructArray.from_arrays(batch.columns, fields=properties_column_fields)
        for batch in table.select(properties_column_names).to_batches()
    ]

    # Pass the struct type explicitly: with no chunks (a table that yields no
    # record batches) the chunked array's type cannot be inferred.
    properties_type = pa.struct(properties_column_fields)
    return table.drop_columns(properties_column_names).append_column(
        "properties",
        pa.chunked_array(properties_array_chunks, type=properties_type),
    )
|
||
|
||
def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
    """Turn the struct ``bbox`` column into a fixed-size list column for JSON output.

    A 4-field struct becomes ``[xmin, ymin, xmax, ymax]`` and a 6-field struct
    becomes ``[xmin, ymin, zmin, xmax, ymax, zmax]``.

    Raises:
        ValueError: If the bbox struct has neither 4 nor 6 fields.
    """
    # Output coordinate order, keyed by the number of fields in the struct.
    axis_names_by_width = {
        4: ("xmin", "ymin", "xmax", "ymax"),
        6: ("xmin", "ymin", "zmin", "xmax", "ymax", "zmax"),
    }

    bbox_col_idx = table.schema.get_field_index("bbox")
    bbox_col = table.column(bbox_col_idx)

    new_chunks = []
    for chunk in bbox_col.chunks:
        assert pa.types.is_struct(chunk.type)

        axis_names = axis_names_by_width.get(bbox_col.type.num_fields)
        if axis_names is None:
            raise ValueError("Expected 4 or 6 fields in bbox struct.")

        per_axis_values = [
            chunk.field(axis_name).to_numpy() for axis_name in axis_names
        ]
        # Stack one column per axis, then flatten row-major so that each row's
        # coordinates are contiguous in the flat values buffer.
        flat_coords = np.column_stack(per_axis_values).flatten("C")
        new_chunks.append(
            pa.FixedSizeListArray.from_arrays(flat_coords, len(axis_names))
        )

    return table.set_column(bbox_col_idx, "bbox", new_chunks)
Oops, something went wrong.