Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Round trip tests & various fixes #42

Merged
merged 12 commits into from
Apr 22, 2024
52 changes: 39 additions & 13 deletions stac_geoparquet/from_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,46 @@ def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
new_chunks = []
for chunk in bbox_col.chunks:
assert pa.types.is_struct(chunk.type)
xmin = chunk.field(0).to_numpy()
ymin = chunk.field(1).to_numpy()
xmax = chunk.field(2).to_numpy()
ymax = chunk.field(3).to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)
if bbox_col.type.num_fields == 4:
xmin = chunk.field(0).to_numpy()
ymin = chunk.field(1).to_numpy()
xmax = chunk.field(2).to_numpy()
ymax = chunk.field(3).to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)

elif bbox_col.type.num_fields == 6:
xmin = chunk.field(0).to_numpy()
ymin = chunk.field(1).to_numpy()
zmin = chunk.field(2).to_numpy()
xmax = chunk.field(3).to_numpy()
ymax = chunk.field(4).to_numpy()
zmax = chunk.field(5).to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6)

else:
raise ValueError("Expected 4 or 6 fields in bbox struct.")

new_chunks.append(list_arr)

return table.set_column(bbox_col_idx, "bbox", new_chunks)
149 changes: 113 additions & 36 deletions stac_geoparquet/to_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def parse_stac_items_to_arrow(
*,
chunk_size: int = 8192,
schema: Optional[pa.Schema] = None,
downcast: bool = True,
) -> pa.Table:
"""Parse a collection of STAC Items to a :class:`pyarrow.Table`.

Expand All @@ -41,6 +42,7 @@ def parse_stac_items_to_arrow(
schema: The schema of the input data. If provided, can improve memory use;
otherwise all items need to be parsed into a single array for schema
inference. Defaults to None.
downcast: if True, store bbox as float32 for memory and disk saving.

Returns:
a pyarrow Table with the STAC-GeoParquet representation of items.
Expand All @@ -53,22 +55,23 @@ def parse_stac_items_to_arrow(
for chunk in _chunks(items, chunk_size):
batches.append(_stac_items_to_arrow(chunk, schema=schema))

stac_table = pa.Table.from_batches(batches, schema=schema)
table = pa.Table.from_batches(batches, schema=schema)
else:
# If schema is _not_ provided, then we must convert to Arrow all at once, or
# else it would be possible for a STAC item late in the collection (after the
# first chunk) to have a different schema and not match the schema inferred for
# the first chunk.
stac_table = pa.Table.from_batches([_stac_items_to_arrow(items)])
table = pa.Table.from_batches([_stac_items_to_arrow(items)])

return _process_arrow_table(stac_table)
return _process_arrow_table(table, downcast=downcast)


def parse_stac_ndjson_to_arrow(
path: Union[str, Path],
*,
chunk_size: int = 8192,
schema: Optional[pa.Schema] = None,
downcast: bool = True,
) -> pa.Table:
# Define outside of if/else to make mypy happy
items: List[dict] = []
Expand Down Expand Up @@ -98,14 +101,14 @@ def parse_stac_ndjson_to_arrow(
if len(items) > 0:
batches.append(_stac_items_to_arrow(items, schema=schema))

stac_table = pa.Table.from_batches(batches, schema=schema)
return _process_arrow_table(stac_table)
table = pa.Table.from_batches(batches, schema=schema)
return _process_arrow_table(table, downcast=downcast)


def _process_arrow_table(table: pa.Table) -> pa.Table:
def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table:
    """Run the post-parse normalization pipeline over a STAC Arrow table.

    Steps: bring item properties up to top-level columns, convert datetime
    string columns to timestamp columns, and convert the bbox list column
    into a struct column (optionally downcast for smaller storage).

    Args:
        table: the raw table produced from parsed STAC items.
        downcast: forwarded to the bbox conversion step.

    Returns:
        the normalized pyarrow Table.
    """
    normalized = _bring_properties_to_top_level(table)
    normalized = _convert_timestamp_columns(normalized)
    normalized = _convert_bbox_to_struct(normalized, downcast=downcast)
    return normalized


Expand Down Expand Up @@ -192,11 +195,21 @@ def _convert_timestamp_columns(table: pa.Table) -> pa.Table:
except KeyError:
continue

field_index = table.schema.get_field_index(column_name)

if pa.types.is_timestamp(column.type):
continue

# STAC allows datetimes to be null. If all rows are null, the column type may be
# inferred as null. We cast this to a timestamp column.
elif pa.types.is_null(column.type):
table = table.set_column(
field_index, column_name, column.cast(pa.timestamp("us"))
)

elif pa.types.is_string(column.type):
table = table.drop(column_name).append_column(
column_name, _convert_timestamp_column(column)
table = table.set_column(
field_index, column_name, _convert_timestamp_column(column)
)
else:
raise ValueError(
Expand Down Expand Up @@ -224,7 +237,26 @@ def _convert_timestamp_column(column: pa.ChunkedArray) -> pa.ChunkedArray:
return pa.chunked_array(chunks)


def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Table:
def is_bbox_3d(bbox_col: pa.ChunkedArray) -> bool:
    """Infer whether the bounding box column represents 2d or 3d bounding boxes.

    Args:
        bbox_col: a chunked list-typed column (list, large_list, or
            fixed_size_list) holding per-row bbox coordinate arrays.

    Returns:
        True when every bbox has 6 coordinates (3d), False for 4 (2d).
        An empty column (no rows) is treated as 2d.

    Raises:
        ValueError: if the column mixes 2d and 3d boxes, or if a bbox has
            an unexpected number of coordinates.
    """
    offsets_set = set()
    for chunk in bbox_col.chunks:
        if pa.types.is_fixed_size_list(chunk.type):
            # FixedSizeListArray has no offsets buffer; the element count
            # is carried on the type itself.
            offsets_set.add(chunk.type.list_size)
        else:
            offsets = chunk.offsets.to_numpy()
            # Successive-offset differences are the per-row list lengths.
            offsets_set.update(np.unique(offsets[1:] - offsets[:-1]))

    if len(offsets_set) > 1:
        raise ValueError("Mixed 2d-3d bounding boxes not yet supported")

    if not offsets_set:
        # No rows at all: nothing to inspect; default to 2d.
        return False

    offset = next(iter(offsets_set))
    if offset == 6:
        return True
    elif offset == 4:
        return False
    else:
        raise ValueError(f"Unexpected bbox offset: {offset=}")


def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table:
"""Convert bbox column to a struct representation

Since the bbox in JSON is stored as an array, pyarrow automatically converts the
Expand All @@ -244,6 +276,7 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Tab
"""
bbox_col_idx = table.schema.get_field_index("bbox")
bbox_col = table.column(bbox_col_idx)
bbox_3d = is_bbox_3d(bbox_col)

new_chunks = []
for chunk in bbox_col.chunks:
Expand All @@ -252,36 +285,80 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Tab
or pa.types.is_large_list(chunk.type)
or pa.types.is_fixed_size_list(chunk.type)
)
coords = chunk.flatten().to_numpy().reshape(-1, 4)
xmin = coords[:, 0]
ymin = coords[:, 1]
xmax = coords[:, 2]
ymax = coords[:, 3]
if bbox_3d:
coords = chunk.flatten().to_numpy().reshape(-1, 6)
else:
coords = chunk.flatten().to_numpy().reshape(-1, 4)

if downcast:
coords = coords.astype(np.float32)

# Round min values down to the next float32 value
# Round max values up to the next float32 value
xmin = np.nextafter(xmin, -np.Infinity)
ymin = np.nextafter(ymin, -np.Infinity)
xmax = np.nextafter(xmax, np.Infinity)
ymax = np.nextafter(ymax, np.Infinity)

struct_arr = pa.StructArray.from_arrays(
[
xmin,
ymin,
xmax,
ymax,
],
names=[
"xmin",
"ymin",
"xmax",
"ymax",
],
)
if bbox_3d:
xmin = coords[:, 0]
ymin = coords[:, 1]
zmin = coords[:, 2]
xmax = coords[:, 3]
ymax = coords[:, 4]
zmax = coords[:, 5]

if downcast:
# Round min values down to the next float32 value
# Round max values up to the next float32 value
xmin = np.nextafter(xmin, -np.Infinity)
ymin = np.nextafter(ymin, -np.Infinity)
zmin = np.nextafter(zmin, -np.Infinity)
xmax = np.nextafter(xmax, np.Infinity)
ymax = np.nextafter(ymax, np.Infinity)
zmax = np.nextafter(zmax, np.Infinity)

struct_arr = pa.StructArray.from_arrays(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
],
names=[
"xmin",
"ymin",
"zmin",
"xmax",
"ymax",
"zmax",
],
)

else:
xmin = coords[:, 0]
ymin = coords[:, 1]
xmax = coords[:, 2]
ymax = coords[:, 3]

if downcast:
# Round min values down to the next float32 value
# Round max values up to the next float32 value
xmin = np.nextafter(xmin, -np.Infinity)
ymin = np.nextafter(ymin, -np.Infinity)
xmax = np.nextafter(xmax, np.Infinity)
ymax = np.nextafter(ymax, np.Infinity)

struct_arr = pa.StructArray.from_arrays(
[
xmin,
ymin,
xmax,
ymax,
],
names=[
"xmin",
"ymin",
"xmax",
"ymax",
],
)

new_chunks.append(struct_arr)

return table.set_column(bbox_col_idx, "bbox", new_chunks)
Loading
Loading