Merge branch 'main' into kyle/write-parquet
kylebarron committed Apr 19, 2024
2 parents 759123a + 5ed701e · commit 03caee1
Showing 9 changed files with 47 additions and 41 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/continuous-integration.yml
@@ -28,4 +28,7 @@ jobs:
        run: pytest tests -v

      - name: Lint
-        run: pre-commit run --all-files
\ No newline at end of file
+        run: pre-commit run --all-files
+
+      - name: Type check
+        run: mypy .
16 changes: 0 additions & 16 deletions .pre-commit-config.yaml
@@ -18,19 +18,3 @@ repos:
    hooks:
      - id: flake8
        language_version: python3
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.9.0
-    hooks:
-      - id: mypy
-        # Override default --ignore-missing-imports
-        args: []
-        additional_dependencies:
-          # Type stubs
-          - types-PyYAML
-          - types-requests
-          - types-python-dateutil
-          # Typed libraries
-          - numpy
-          - pystac
-          - azure-data-tables
-          - pytest
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -41,6 +41,9 @@ test = [
"pre-commit",
"stac-geoparquet[pgstac]",
"stac-geoparquet[pc]",
"types-python-dateutil",
"types-requests",
"mypy",
]


@@ -72,3 +75,7 @@ module = [
]

ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "stac_geoparquet.*"
+disallow_untyped_defs = true
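
Together, the pyproject.toml changes move type checking out of pre-commit and into the project itself: the existing override keeps ignoring missing imports for the untyped third-party modules, while the new one turns on `disallow_untyped_defs` for `stac_geoparquet.*`. A minimal sketch of what that stricter setting means in practice (hypothetical functions, not from this repo):

    from typing import Any

    # Under disallow_untyped_defs = true, mypy rejects unannotated definitions:
    def fix(v):  # error: Function is missing a type annotation
        return v

    # Fully annotated definitions pass:
    def fix_typed(v: Any) -> Any:
        return v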
12 changes: 7 additions & 5 deletions stac_geoparquet/cli.py
@@ -2,12 +2,14 @@
import logging
import sys
import os
+from typing import List, Optional

from stac_geoparquet import pc_runner

logger = logging.getLogger("stac_geoparquet.pgstac_reader")


-def parse_args(args=None):
+def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output-protocol",
@@ -47,7 +49,7 @@ def parse_args(args=None):
    return parser.parse_args(args)


-def setup_logging():
+def setup_logging() -> None:
    import logging
    import warnings
    import rich.logging
@@ -88,10 +90,10 @@ def setup_logging():
}


-def main(args=None):
+def main(inp: Optional[List[str]] = None) -> int:
    import azure.data.tables

-    args = parse_args(args)
+    args = parse_args(inp)

    skip = set(SKIP)
    if args.extra_skip:
@@ -112,7 +114,7 @@ def main(args=None):
"credential": args.storage_options_credential,
}

def f(config):
def f(config: pc_runner.CollectionConfig) -> None:
config.export_collection(
args.connection_info,
args.output_protocol,
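
Renaming `main`'s parameter from `args` to `inp` is a typing fix as much as a readability one: the old code re-bound `args` from `Optional[List[str]]` to the parsed `argparse.Namespace`, and mypy rejects giving one name two incompatible types. A standalone sketch of the corrected pattern (hypothetical flag, not the real CLI):

    import argparse
    from typing import List, Optional

    def main(inp: Optional[List[str]] = None) -> int:
        # "inp" stays a list of raw CLI tokens; "args" is the parsed Namespace.
        parser = argparse.ArgumentParser()
        parser.add_argument("--output-protocol")
        args: argparse.Namespace = parser.parse_args(inp)
        print(args.output_protocol)
        return 0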
5 changes: 3 additions & 2 deletions stac_geoparquet/from_arrow.py
@@ -1,15 +1,16 @@
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""

import os
import json
from typing import Iterable, List
from typing import Iterable, List, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import shapely


def stac_table_to_ndjson(table: pa.Table, dest: str):
def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
with open(dest, "w") as f:
for item_dict in stac_table_to_items(table):
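
Widening `dest` to `Union[str, os.PathLike[str]]` needs no change in the function body, because `open()` already accepts path-like objects. A self-contained sketch of the same pattern (hypothetical helper, not the library function):

    import os
    from pathlib import Path
    from typing import List, Union

    def write_lines(dest: Union[str, os.PathLike[str]], lines: List[str]) -> None:
        # open() accepts both str and PathLike, so one annotation covers both.
        with open(dest, "w") as f:
            for line in lines:
                f.write(line + "\n")

    write_lines("items.ndjson", ["{}"])          # plain str
    write_lines(Path("items2.ndjson"), ["{}"])   # pathlib.Path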
16 changes: 11 additions & 5 deletions stac_geoparquet/pc_runner.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
+from typing import Any

import azure.data.tables
import requests
@@ -79,7 +80,7 @@
}


-def build_render_config(render_params, assets):
+def build_render_config(render_params: dict[str, Any], assets: dict[str, Any]) -> str:
    flat = []
    if assets:
        for asset in assets:
@@ -93,7 +94,9 @@ def build_render_config(render_params, assets):
    return urllib.parse.urlencode(flat)


-def generate_configs_from_storage_table(table_client: azure.data.tables.TableClient):
+def generate_configs_from_storage_table(
+    table_client: azure.data.tables.TableClient,
+) -> dict[str, CollectionConfig]:
    configs = {}
    for entity in table_client.list_entities():
        collection_id = entity["RowKey"]
@@ -109,7 +112,7 @@ def generate_configs_from_storage_table(table_client: azure.data.tables.TableCli
    return configs


-def generate_configs_from_api(url):
+def generate_configs_from_api(url: str) -> dict[str, CollectionConfig]:
    configs = {}
    r = requests.get(url)
    r.raise_for_status()
@@ -131,7 +134,7 @@ def generate_configs_from_api(url):

def merge_configs(
    table_configs: dict[str, CollectionConfig], api_configs: dict[str, CollectionConfig]
-):
+) -> dict[str, CollectionConfig]:
    # what a mess. Get partitioning config from the API, render from the table.
    configs = {}
    for k in table_configs.keys() | api_configs.keys():
@@ -142,9 +145,12 @@ def merge_configs(
        if api_config:
            config.partition_frequency = api_config.partition_frequency
        configs[k] = config
+    return configs


-def get_configs(table_client):
+def get_configs(
+    table_client: azure.data.tables.TableClient,
+) -> dict[str, CollectionConfig]:
    table_configs = generate_configs_from_storage_table(table_client)
    api_configs = generate_configs_from_api(
        "https://planetarycomputer.microsoft.com/api/stac/v1/collections"
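
Annotating `merge_configs` with `-> dict[str, CollectionConfig]` is what exposes the latent bug fixed in the same hunk: a function declared to return a dict but falling off the end fails mypy with "Missing return statement", hence the added `return configs`. A minimal reproduction with stand-in types:

    def merge(a: dict, b: dict) -> dict:
        merged = {**a, **b}
        # Omitting this return makes mypy report: error: Missing return statement
        return merged

    print(merge({"x": 1}, {"y": 2}))  # {'x': 1, 'y': 2}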
17 changes: 9 additions & 8 deletions stac_geoparquet/pgstac_reader.py
@@ -46,7 +46,7 @@ class CollectionConfig:
    should_inject_dynamic_properties: bool = True
    render_config: str | None = None

-    def __post_init__(self):
+    def __post_init__(self) -> None:
        self._collection: pystac.Collection | None = None

    @property
@@ -146,8 +146,8 @@ def export_partition(
        output_protocol: str,
        output_path: str,
        storage_options: dict[str, Any] | None = None,
-        rewrite=False,
-        skip_empty_partitions=False,
+        rewrite: bool = False,
+        skip_empty_partitions: bool = False,
    ) -> str | None:
        storage_options = storage_options or {}
        az_fs = fsspec.filesystem(output_protocol, **storage_options)
@@ -157,6 +157,7 @@

        db = pypgstac.db.PgstacDB(conninfo)
        with db:
+            assert db.connection is not None
            db.connection.execute("set statement_timeout = 300000;")
            # logger.debug("Reading base item")
            # TODO: proper escaping
@@ -169,7 +170,7 @@
            logger.debug("No records found for query %s.", query)
            return None

-        items = self.make_pgstac_items(records, base_item)
+        items = self.make_pgstac_items(records, base_item)  # type: ignore[arg-type]
        df = to_geodataframe(items)
        filesystem = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(az_fs))
        df.to_parquet(output_path, index=False, filesystem=filesystem)
@@ -184,8 +185,8 @@ def export_partition_for_endpoints(
        storage_options: dict[str, Any],
        part_number: int | None = None,
        total: int | None = None,
-        rewrite=False,
-        skip_empty_partitions=False,
+        rewrite: bool = False,
+        skip_empty_partitions: bool = False,
    ) -> str | None:
        """
        Export results for a pair of endpoints.
@@ -221,8 +222,8 @@ def export_collection(
        output_protocol: str,
        output_path: str,
        storage_options: dict[str, Any],
-        rewrite=False,
-        skip_empty_partitions=False,
+        rewrite: bool = False,
+        skip_empty_partitions: bool = False,
    ) -> list[str | None]:
        base_query = textwrap.dedent(
            f"""\
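
The added `assert db.connection is not None` is the usual idiom for narrowing an Optional attribute: after the assert, mypy treats `db.connection` as non-None, and at runtime a missing connection fails fast with an `AssertionError` instead of a less obvious `AttributeError`. A generic illustration with stand-in classes (not pypgstac's actual API):

    from typing import Optional

    class Connection:
        def execute(self, sql: str) -> None:
            print(f"executing: {sql}")

    class DB:
        def __init__(self) -> None:
            self.connection: Optional[Connection] = None

        def connect(self) -> None:
            self.connection = Connection()

    db = DB()
    db.connect()
    assert db.connection is not None  # narrows Optional[Connection] to Connection
    db.connection.execute("set statement_timeout = 300000;")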
2 changes: 1 addition & 1 deletion stac_geoparquet/stac_geoparquet.py
@@ -24,7 +24,7 @@
SELF_LINK_COLUMN = "self_link"


-def _fix_array(v):
+def _fix_array(v: Any) -> Any:
    if isinstance(v, np.ndarray):
        v = v.tolist()

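
`_fix_array` now satisfies the stricter mypy settings with an explicit `Any -> Any` signature; the conversion itself matters because numpy arrays are not JSON-serializable, while plain lists are. A self-contained sketch of the idea (the trailing `return v` is assumed, since the diff truncates the body):

    import json
    from typing import Any

    import numpy as np

    def fix_array(v: Any) -> Any:
        # json.dumps cannot handle np.ndarray; .tolist() converts it.
        if isinstance(v, np.ndarray):
            v = v.tolist()
        return v

    print(json.dumps(fix_array(np.array([1, 2, 3]))))  # [1, 2, 3]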
8 changes: 5 additions & 3 deletions stac_geoparquet/to_arrow.py
@@ -4,7 +4,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union, Generator

import ciso8601
import numpy as np
@@ -14,7 +14,9 @@
import shapely.geometry


-def _chunks(lst: Sequence[Dict[str, Any]], n: int):
+def _chunks(
+    lst: Sequence[Dict[str, Any]], n: int
+) -> Generator[Sequence[Dict[str, Any]], None, None]:
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
@@ -67,7 +69,7 @@ def parse_stac_ndjson_to_arrow(
    *,
    chunk_size: int = 8192,
    schema: Optional[pa.Schema] = None,
-):
+) -> pa.Table:
    # Define outside of if/else to make mypy happy
    items: List[dict] = []

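
The new `Generator[Sequence[Dict[str, Any]], None, None]` annotation spells out the yield, send, and return types of `_chunks`; only the yield type is meaningful here. A quick standalone check of the chunking behavior:

    from typing import Any, Dict, Generator, Sequence

    def chunks(
        lst: Sequence[Dict[str, Any]], n: int
    ) -> Generator[Sequence[Dict[str, Any]], None, None]:
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    items = [{"id": i} for i in range(5)]
    print([len(c) for c in chunks(items, 2)])  # [2, 2, 1]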
