Skip to content

Commit

Permalink
Optionally use pyarrow types in to_geodataframe
Browse files Browse the repository at this point in the history
This updates to_geodataframe to optionally use pyarrow types, rather
than NumPy. These types let us faithfully represent the actual nested
types, rather than casting everything to `object`.
  • Loading branch information
TomAugspurger committed Mar 17, 2024
1 parent 3901d33 commit 73fdfac
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 158 deletions.
141 changes: 109 additions & 32 deletions stac_geoparquet/stac_geoparquet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""
Generate geoparquet from a sequence of STAC items.
"""

from __future__ import annotations
import collections

from typing import Sequence, Any
from typing import Sequence, Any, Literal
import warnings

import pystac
import geopandas
import pandas as pd
import pyarrow as pa
import numpy as np
import shapely.geometry

Expand All @@ -16,7 +20,7 @@
from stac_geoparquet.utils import fix_empty_multipolygon

STAC_ITEM_TYPES = ["application/json", "application/geo+json"]

DTYPE_BACKEND = Literal["numpy_nullable", "pyarrow"]
SELF_LINK_COLUMN = "self_link"


Expand All @@ -31,7 +35,9 @@ def _fix_array(v):


def to_geodataframe(
items: Sequence[dict[str, Any]], add_self_link: bool = False
items: Sequence[dict[str, Any]],
add_self_link: bool = False,
dtype_backend: DTYPE_BACKEND | None = None,
) -> geopandas.GeoDataFrame:
"""
Convert a sequence of STAC items to a :class:`geopandas.GeoDataFrame`.
Expand All @@ -42,19 +48,68 @@ def to_geodataframe(
Parameters
----------
items: A sequence of STAC items.
add_self_link: Add the absolute link (if available) to the source STAC Item as a separate column named "self_link"
add_self_link: bool, default False
Add the absolute link (if available) to the source STAC Item
as a separate column named "self_link"
dtype_backend: {'pyarrow', 'numpy_nullable'}, optional
The dtype backend to use for storing arrays.
By default, this will use 'numpy_nullable' and emit a
FutureWarning that the default will change to 'pyarrow' in
the next release.
Set to 'numpy_nullable' to silence the warning and accept the
old behavior.
Set to 'pyarrow' to silence the warning and accept the new behavior.
There are some differences in the output as well: with
``dtype_backend="pyarrow"``, struct-like fields will explicitly
contain null values for fields that appear in only some of the
records. For example, given an ``assets`` like::
{
"a": {
"href": "a.tif",
},
"b": {
"href": "b.tif",
"title": "B",
}
}
The ``assets`` field of the output for the first row with
``dtype_backend="numpy_nullable"`` will be a Python dictionary with
just ``{"href": "a.tif"}``.
With ``dtype_backend="pyarrow"``, this will be a pyarrow struct
with fields ``{"href": "a.tif", "title": None}``. pyarrow will
infer that the struct field ``asset.title`` is nullable.
Returns
-------
The converted GeoDataFrame.
"""
items2 = []
items2 = collections.defaultdict(list)

for item in items:
item2 = {k: v for k, v in item.items() if k != "properties"}
keys = set(item) - {"properties", "geometry"}

for k in keys:
items2[k].append(item[k])

item_geometry = item["geometry"]
if item_geometry:
item_geometry = fix_empty_multipolygon(item_geometry)

items2["geometry"].append(item_geometry)

for k, v in item["properties"].items():
if k in item2:
raise ValueError("k", k)
item2[k] = v
if k in item:
msg = f"Key '{k}' appears in both 'properties' and the top level."
raise ValueError(msg)
items2[k].append(v)

if add_self_link:
self_href = None
for link in item["links"]:
Expand All @@ -65,23 +120,11 @@ def to_geodataframe(
):
self_href = link["href"]
break
item2[SELF_LINK_COLUMN] = self_href
items2.append(item2)

# Filter out missing geoms in MultiPolygons
# https://github.com/shapely/shapely/issues/1407
# geometry = [shapely.geometry.shape(x["geometry"]) for x in items2]

geometry = []
for item2 in items2:
item_geometry = item2["geometry"]
if item_geometry:
item_geometry = fix_empty_multipolygon(item_geometry) # type: ignore
geometry.append(item_geometry)
items2[SELF_LINK_COLUMN].append(self_href)

gdf = geopandas.GeoDataFrame(items2, geometry=geometry, crs="WGS84")

for column in [
# TODO: Ideally we wouldn't have to hard-code this list.
# Could we get it from the JSON schema.
DATETIME_COLUMNS = {
"datetime", # common metadata
"start_datetime",
"end_datetime",
Expand All @@ -90,9 +133,42 @@ def to_geodataframe(
"expires", # timestamps extension
"published",
"unpublished",
]:
if column in gdf.columns:
gdf[column] = pd.to_datetime(gdf[column], format="ISO8601")
}

items2["geometry"] = geopandas.array.from_shapely(items2["geometry"])

if dtype_backend is None:
msg = (
"The default argument for 'dtype_backend' will change from "
"'numpy_nullable' to 'pyarrow'. To keep the previous default "
"specify ``dtype_backend='numpy_nullable'``. To accept the future "
behavior specify ``dtype_backend='pyarrow'``."
)
warnings.warn(FutureWarning(msg))
dtype_backend = "numpy_nullable"

if dtype_backend == "pyarrow":
for k, v in items2.items():
if k in DATETIME_COLUMNS:
items2[k] = pd.arrays.ArrowExtensionArray(
pa.array(pd.to_datetime(v, format="ISO8601"))
)

elif k != "geometry":
items2[k] = pd.arrays.ArrowExtensionArray(pa.array(v))

elif dtype_backend == "numpy_nullable":
for k, v in items2.items():
if k in DATETIME_COLUMNS:
items2[k] = pd.to_datetime(v, format="ISO8601")

if k in {"type", "stac_version", "id", "collection", SELF_LINK_COLUMN}:
items2[k] = pd.array(v, dtype="string")
else:
msg = f"Invalid 'dtype_backend={dtype_backend}'."
raise TypeError(msg)

gdf = geopandas.GeoDataFrame(items2, geometry="geometry", crs="WGS84")

columns = [
"type",
Expand All @@ -111,10 +187,6 @@ def to_geodataframe(
columns.remove(col)

gdf = pd.concat([gdf[columns], gdf.drop(columns=columns)], axis="columns")
for k in ["type", "stac_version", "id", "collection", SELF_LINK_COLUMN]:
if k in gdf:
gdf[k] = gdf[k].astype("string")

return gdf


Expand Down Expand Up @@ -175,6 +247,11 @@ def to_item_collection(df: geopandas.GeoDataFrame) -> pystac.ItemCollection:
include=["datetime64[ns, UTC]", "datetime64[ns]"]
).columns
for k in datelike:
# %f isn't implemented in pyarrow
# https://github.com/apache/arrow/issues/20146
if isinstance(df2[k].dtype, pd.ArrowDtype):
df2[k] = df2[k].astype("datetime64[ns, utc]")

df2[k] = (
df2[k].dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ").fillna("").replace({"": None})
)
Expand Down
44 changes: 36 additions & 8 deletions stac_geoparquet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,27 @@


@functools.singledispatch
def assert_equal(result: Any, expected: Any) -> bool:
def assert_equal(result: Any, expected: Any, ignore_none: bool = False) -> bool:
raise TypeError(f"Invalid type {type(result)}")


@assert_equal.register(pystac.ItemCollection)
def assert_equal_ic(
result: pystac.ItemCollection, expected: pystac.ItemCollection
result: pystac.ItemCollection,
expected: pystac.ItemCollection,
ignore_none: bool = False,
) -> None:
assert type(result) == type(expected)
assert len(result) == len(expected)
assert result.extra_fields == expected.extra_fields
for a, b in zip(result.items, expected.items):
assert_equal(a, b)
assert_equal(a, b, ignore_none=ignore_none)


@assert_equal.register(pystac.Item)
def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
def assert_equal_item(
result: pystac.Item, expected: pystac.Item, ignore_none: bool = False
) -> None:
assert type(result) == type(expected)
assert result.id == expected.id
assert shapely.geometry.shape(result.geometry) == shapely.geometry.shape(
Expand All @@ -41,20 +45,44 @@ def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
expected_links = sorted(expected.links, key=lambda x: x.href)
assert len(result_links) == len(expected_links)
for a, b in zip(result_links, expected_links):
assert_equal(a, b)
assert_equal(a, b, ignore_none=ignore_none)

assert set(result.assets) == set(expected.assets)
for k in result.assets:
assert_equal(result.assets[k], expected.assets[k])
assert_equal(result.assets[k], expected.assets[k], ignore_none=ignore_none)


@assert_equal.register(pystac.Link)
@assert_equal.register(pystac.Asset)
def assert_link_equal(
result: pystac.Link | pystac.Asset, expected: pystac.Link | pystac.Asset
result: pystac.Link | pystac.Asset,
expected: pystac.Link | pystac.Asset,
ignore_none: bool = False,
) -> None:
assert type(result) == type(expected)
assert result.to_dict() == expected.to_dict()
resultd = result.to_dict()
expectedd = expected.to_dict()

left = {}

if ignore_none:
for k, v in resultd.items():
if v is None and k not in expectedd:
pass
elif isinstance(v, list) and k in expectedd:
out = []
for val in v:
if isinstance(val, dict):
out.append({k: v2 for k, v2 in val.items() if v2 is not None})
else:
out.append(val)
left[k] = out
else:
left[k] = v
else:
left = resultd

assert left == expectedd


def fix_empty_multipolygon(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pgstac_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_naip_item():
expected.remove_links(rel=pystac.RelType.SELF)
result.remove_links(rel=pystac.RelType.SELF)

assert_equal(result, expected)
assert_equal(result, expected, ignore_none=True)


def test_sentinel2_l2a():
Expand All @@ -139,7 +139,7 @@ def test_sentinel2_l2a():
result.remove_links(rel=pystac.RelType.SELF)

expected.remove_links(rel=pystac.RelType.LICENSE)
assert_equal(result, expected)
assert_equal(result, expected, ignore_none=True)


def test_generate_endpoints():
Expand Down
Loading

0 comments on commit 73fdfac

Please sign in to comment.