diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml
index 3a26669..7836f71 100644
--- a/.github/workflows/continuous-integration.yml
+++ b/.github/workflows/continuous-integration.yml
@@ -28,4 +28,7 @@ jobs:
         run: pytest tests -v
 
       - name: Lint
-        run: pre-commit run --all-files
\ No newline at end of file
+        run: pre-commit run --all-files
+
+      - name: Type check
+        run: mypy .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 39984c0..cedf8ac 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,19 +18,3 @@ repos:
     hooks:
       - id: flake8
         language_version: python3
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.9.0
-    hooks:
-      - id: mypy
-        # Override default --ignore-missing-imports
-        args: []
-        additional_dependencies:
-          # Type stubs
-          - types-PyYAML
-          - types-requests
-          - types-python-dateutil
-          # Typed libraries
-          - numpy
-          - pystac
-          - azure-data-tables
-          - pytest
diff --git a/pyproject.toml b/pyproject.toml
index 0efdd6c..580f8c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,9 @@ test = [
     "pre-commit",
     "stac-geoparquet[pgstac]",
     "stac-geoparquet[pc]",
+    "types-python-dateutil",
+    "types-requests",
+    "mypy",
 ]
 
 
@@ -72,3 +75,7 @@ module = [
 ]
 
 ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "stac_geoparquet.*"
+disallow_untyped_defs = true
\ No newline at end of file
diff --git a/stac_geoparquet/cli.py b/stac_geoparquet/cli.py
index 73c3978..bbb2403 100644
--- a/stac_geoparquet/cli.py
+++ b/stac_geoparquet/cli.py
@@ -2,12 +2,14 @@ import logging
 import sys
 import os
 
+from typing import List, Optional
+
 from stac_geoparquet import pc_runner
 
 logger = logging.getLogger("stac_geoparquet.pgstac_reader")
 
 
-def parse_args(args=None):
+def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--output-protocol",
@@ -47,7 +49,7 @@
     return parser.parse_args(args)
 
 
-def setup_logging():
+def setup_logging() -> None:
     import logging
     import warnings
     import rich.logging
@@ -88,10 +90,10 @@
 }
 
 
-def main(args=None):
+def main(inp: Optional[List[str]] = None) -> int:
     import azure.data.tables
 
-    args = parse_args(args)
+    args = parse_args(inp)
 
     skip = set(SKIP)
     if args.extra_skip:
@@ -112,7 +114,7 @@
         "credential": args.storage_options_credential,
     }
 
-    def f(config):
+    def f(config: pc_runner.CollectionConfig) -> None:
         config.export_collection(
             args.connection_info,
             args.output_protocol,
diff --git a/stac_geoparquet/from_arrow.py b/stac_geoparquet/from_arrow.py
index d079c19..7cf9d85 100644
--- a/stac_geoparquet/from_arrow.py
+++ b/stac_geoparquet/from_arrow.py
@@ -1,7 +1,8 @@
 """Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""
 
+import os
 import json
-from typing import Iterable, List
+from typing import Iterable, List, Union
 
 import numpy as np
 import pyarrow as pa
@@ -9,7 +10,7 @@
 import shapely
 
 
-def stac_table_to_ndjson(table: pa.Table, dest: str):
+def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
     """Write a STAC Table to a newline-delimited JSON file."""
     with open(dest, "w") as f:
         for item_dict in stac_table_to_items(table):
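Note on the `from_arrow.py` hunk above: widening `dest` from `str` to `Union[str, os.PathLike[str]]` lets callers pass a `pathlib.Path` as well as a plain string, and `open()` already accepts both. A minimal sketch of the pattern under mypy (`write_ndjson` is a made-up name for illustration, not the repo's function):

```python
import os
import pathlib
from typing import Union


def write_ndjson(dest: Union[str, os.PathLike[str]]) -> None:
    # open() is typed to accept str | os.PathLike, so no cast is needed.
    with open(dest, "w") as f:
        f.write("{}\n")


write_ndjson("items.ndjson")                # OK: str
write_ndjson(pathlib.Path("items.ndjson"))  # OK: Path implements os.PathLike[str]
```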
diff --git a/stac_geoparquet/pc_runner.py b/stac_geoparquet/pc_runner.py
index cdf37c3..d65be25 100644
--- a/stac_geoparquet/pc_runner.py
+++ b/stac_geoparquet/pc_runner.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+from typing import Any
 
 import azure.data.tables
 import requests
@@ -79,7 +80,7 @@
 }
 
 
-def build_render_config(render_params, assets):
+def build_render_config(render_params: dict[str, Any], assets: dict[str, Any]) -> str:
     flat = []
     if assets:
         for asset in assets:
@@ -93,7 +94,9 @@ def build_render_config(render_params, assets):
     return urllib.parse.urlencode(flat)
 
 
-def generate_configs_from_storage_table(table_client: azure.data.tables.TableClient):
+def generate_configs_from_storage_table(
+    table_client: azure.data.tables.TableClient,
+) -> dict[str, CollectionConfig]:
     configs = {}
     for entity in table_client.list_entities():
         collection_id = entity["RowKey"]
@@ -109,7 +112,7 @@ def generate_configs_from_storage_table(table_client: azure.data.tables.TableCli
     return configs
 
 
-def generate_configs_from_api(url):
+def generate_configs_from_api(url: str) -> dict[str, CollectionConfig]:
     configs = {}
     r = requests.get(url)
     r.raise_for_status()
@@ -131,7 +134,7 @@ def generate_configs_from_api(url):
 
 def merge_configs(
     table_configs: dict[str, CollectionConfig], api_configs: dict[str, CollectionConfig]
-):
+) -> dict[str, CollectionConfig]:
     # what a mess. Get partitioning config from the API, render from the table.
     configs = {}
     for k in table_configs.keys() | api_configs.keys():
@@ -142,9 +145,12 @@ def merge_configs(
         if api_config:
             config.partition_frequency = api_config.partition_frequency
         configs[k] = config
+    return configs
 
 
-def get_configs(table_client):
+def get_configs(
+    table_client: azure.data.tables.TableClient,
+) -> dict[str, CollectionConfig]:
     table_configs = generate_configs_from_storage_table(table_client)
     api_configs = generate_configs_from_api(
         "https://planetarycomputer.microsoft.com/api/stac/v1/collections"
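Note on the `pc_runner.py` hunk above: the `merge_configs` change is not annotation-only — it adds a `return configs` that was previously missing, so the function silently returned `None`. Declaring the return type is what exposes this class of bug: with `-> dict[...]` on the signature, mypy reports `error: Missing return statement` for any path that falls off the end. A toy reproduction (hypothetical names, assuming Python 3.9+ for the builtin `dict[...]` syntax):

```python
def merge(a: dict[str, int], b: dict[str, int]) -> dict[str, int]:
    merged: dict[str, int] = {}
    for k in a.keys() | b.keys():
        # Prefer b's value, fall back to a's, then to 0.
        merged[k] = b.get(k, a.get(k, 0))
    # Without the line below, mypy fails with:
    #   error: Missing return statement  [return]
    return merged
```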
diff --git a/stac_geoparquet/pgstac_reader.py b/stac_geoparquet/pgstac_reader.py
index a065ab8..724d099 100644
--- a/stac_geoparquet/pgstac_reader.py
+++ b/stac_geoparquet/pgstac_reader.py
@@ -46,7 +46,7 @@ class CollectionConfig:
     should_inject_dynamic_properties: bool = True
     render_config: str | None = None
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         self._collection: pystac.Collection | None = None
 
     @property
@@ -146,8 +146,8 @@ def export_partition(
         output_protocol: str,
         output_path: str,
         storage_options: dict[str, Any] | None = None,
-        rewrite=False,
-        skip_empty_partitions=False,
+        rewrite: bool = False,
+        skip_empty_partitions: bool = False,
     ) -> str | None:
         storage_options = storage_options or {}
         az_fs = fsspec.filesystem(output_protocol, **storage_options)
@@ -157,6 +157,7 @@
 
         db = pypgstac.db.PgstacDB(conninfo)
         with db:
+            assert db.connection is not None
             db.connection.execute("set statement_timeout = 300000;")
             # logger.debug("Reading base item")
             # TODO: proper escaping
@@ -169,7 +170,7 @@
             logger.debug("No records found for query %s.", query)
             return None
 
-        items = self.make_pgstac_items(records, base_item)
+        items = self.make_pgstac_items(records, base_item)  # type: ignore[arg-type]
         df = to_geodataframe(items)
         filesystem = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(az_fs))
         df.to_parquet(output_path, index=False, filesystem=filesystem)
@@ -184,8 +185,8 @@ def export_partition_for_endpoints(
         storage_options: dict[str, Any],
         part_number: int | None = None,
         total: int | None = None,
-        rewrite=False,
-        skip_empty_partitions=False,
+        rewrite: bool = False,
+        skip_empty_partitions: bool = False,
     ) -> str | None:
         """
         Export results for a pair of endpoints.
@@ -221,8 +222,8 @@ def export_collection(
         output_protocol: str,
         output_path: str,
         storage_options: dict[str, Any],
-        rewrite=False,
-        skip_empty_partitions=False,
+        rewrite: bool = False,
+        skip_empty_partitions: bool = False,
     ) -> list[str | None]:
         base_query = textwrap.dedent(
             f"""\
diff --git a/stac_geoparquet/stac_geoparquet.py b/stac_geoparquet/stac_geoparquet.py
index 74f1872..ccea058 100644
--- a/stac_geoparquet/stac_geoparquet.py
+++ b/stac_geoparquet/stac_geoparquet.py
@@ -24,7 +24,7 @@
 SELF_LINK_COLUMN = "self_link"
 
 
-def _fix_array(v):
+def _fix_array(v: Any) -> Any:
     if isinstance(v, np.ndarray):
         v = v.tolist()
 
diff --git a/stac_geoparquet/to_arrow.py b/stac_geoparquet/to_arrow.py
index f2ef0d9..7321607 100644
--- a/stac_geoparquet/to_arrow.py
+++ b/stac_geoparquet/to_arrow.py
@@ -4,7 +4,7 @@
 from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union, Generator
 
 import ciso8601
 import numpy as np
@@ -14,7 +14,9 @@
 import shapely.geometry
 
 
-def _chunks(lst: Sequence[Dict[str, Any]], n: int):
+def _chunks(
+    lst: Sequence[Dict[str, Any]], n: int
+) -> Generator[Sequence[Dict[str, Any]], None, None]:
     """Yield successive n-sized chunks from lst."""
     for i in range(0, len(lst), n):
         yield lst[i : i + n]
@@ -67,7 +69,7 @@ def parse_stac_ndjson_to_arrow(
     *,
     chunk_size: int = 8192,
     schema: Optional[pa.Schema] = None,
-):
+) -> pa.Table:
     # Define outside of if/else to make mypy happy
     items: List[dict] = []
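Note on the `to_arrow.py` hunk above: `Generator[Sequence[Dict[str, Any]], None, None]` spells out the yield, send, and return types; since `_chunks` only yields, the shorter `Iterator[Sequence[...]]` spelling would type-check identically. A standalone sketch of that equivalent spelling (generic over `T`, which the repo's helper doesn't need):

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunks(lst: Sequence[T], n: int) -> Iterator[Sequence[T]]:
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


batches: List[Sequence[int]] = list(chunks(list(range(10)), 4))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```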