Skip to content

Commit

Permalink
user beter typing
Browse files Browse the repository at this point in the history
  • Loading branch information
savente93 committed Oct 30, 2023
1 parent e3f3495 commit f9061a7
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 29 deletions.
16 changes: 9 additions & 7 deletions hydromt/data_adapter/geodataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,26 @@
import warnings
from datetime import datetime
from os.path import basename, join
from pathlib import Path
from typing import Literal, NewType, Optional, Tuple, Union
from typing import Literal, Optional, Union

import numpy as np
import pyproj
from pystac import Asset as StacAsset
from pystac import Catalog as StacCatalog
from pystac import Item as StacItem

from hydromt.typing import (
GeoDataframeSource,
TotalBounds,
)

from .. import gis_utils, io
from .data_adapter import DataAdapter

logger = logging.getLogger(__name__)

__all__ = ["GeoDataFrameAdapter", "GeoDataframeSource"]

GeoDataframeSource = NewType("GeoDataframeSource", Union[str, Path])


class GeoDataFrameAdapter(DataAdapter):

Expand Down Expand Up @@ -408,7 +410,7 @@ def _set_metadata(self, gdf):

return gdf

def get_bbox(self, detect=True) -> Tuple[Tuple[float, float, float, float], int]:
def get_bbox(self, detect=True) -> TotalBounds:
"""Return the bounding box and espg code of the dataset.
if the bounding box is not set and detect is True,
Expand Down Expand Up @@ -438,7 +440,7 @@ def get_bbox(self, detect=True) -> Tuple[Tuple[float, float, float, float], int]
def detect_bbox(
self,
gdf=None,
) -> Tuple[Tuple[float, float, float, float], int]:
) -> TotalBounds:
"""Detect the bounding box and crs of the dataset.
If no dataset is provided, it will be fetched acodring to the settings in the
Expand Down Expand Up @@ -503,7 +505,7 @@ def to_stac_catalog(
bbox, crs = self.get_bbox(detect=True)
bbox = list(bbox)
props = {**self.meta, "crs": crs}
except Exception as e:
except (IndexError, KeyError, pyproj.exceptions.CRSError) as e:
if on_error == "skip":
logger.warning(
"Skipping {name} during stac conversion because"
Expand Down
24 changes: 14 additions & 10 deletions hydromt/data_adapter/geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import warnings
from datetime import datetime
from os.path import basename, join
from pathlib import Path
from typing import Literal, NewType, Optional, Tuple, Union
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -15,6 +14,12 @@
from pystac import Catalog as StacCatalog
from pystac import Item as StacItem

from hydromt.typing import (
GeoDatasetSource,
TimeRange,
TotalBounds,
)

from .. import gis_utils, io
from ..raster import GEO_MAP_COORD
from .data_adapter import DataAdapter
Expand All @@ -23,8 +28,6 @@

__all__ = ["GeoDatasetAdapter", "GeoDatasetSource"]

GeoDatasetSource = NewType("GeoDatasetSource", Union[str, Path])


class GeoDatasetAdapter(DataAdapter):

Expand Down Expand Up @@ -475,7 +478,7 @@ def _apply_unit_conversion(self, ds, logger=logger):
ds[name].attrs.update(attrs) # set original attributes
return ds

def get_bbox(self, detect=True) -> Tuple[Tuple[float, float, float, float], int]:
def get_bbox(self, detect=True) -> TotalBounds:
"""Return the bounding box and espg code of the dataset.
if the bounding box is not set and detect is True,
Expand All @@ -496,13 +499,14 @@ def get_bbox(self, detect=True) -> Tuple[Tuple[float, float, float, float], int]
The ESPG code of the CRS of the coordinates returned in bbox
"""
bbox = self.extent.get("bbox", None)
crs = self.crs
if bbox is None and detect:
bbox, crs = self.detect_bbox()

crs = self.crs

return bbox, crs

def get_time_range(self, detect=True):
def get_time_range(self, detect=True) -> TimeRange:
"""Detect the time range of the dataset.
if the time range is not set and detect is True,
Expand Down Expand Up @@ -531,7 +535,7 @@ def get_time_range(self, detect=True):
def detect_bbox(
self,
ds=None,
) -> Tuple[Tuple[float, float, float, float], int]:
) -> TotalBounds:
"""Detect the bounding box and crs of the dataset.
If no dataset is provided, it will be fetched according to the settings in the
Expand Down Expand Up @@ -562,7 +566,7 @@ def detect_bbox(
bounds = ds.vector.bounds
return bounds, crs

def detect_time_range(self, ds=None) -> Tuple[np.datetime64, np.datetime64]:
def detect_time_range(self, ds=None) -> TimeRange:
"""Detect the temporal range of the dataset.
If no dataset is provided, it will be fetched according to the settings in the
Expand Down Expand Up @@ -623,7 +627,7 @@ def to_stac_catalog(
start_dt = pd.to_datetime(start_dt)
end_dt = pd.to_datetime(end_dt)
props = {**self.meta, "crs": crs}
except Exception as e:
except (IndexError, KeyError, pyproj.exceptions.CRSError) as e:
if on_error == "skip":
logger.warning(
"Skipping {name} during stac conversion because"
Expand Down
21 changes: 12 additions & 9 deletions hydromt/data_adapter/rasterdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import os
import warnings
from datetime import datetime
from os import PathLike
from os.path import basename, join
from typing import Dict, Literal, NewType, Optional, Tuple, Union, cast
from typing import Dict, Literal, Optional, Tuple, Union, cast

import geopandas as gpd
import numpy as np
Expand All @@ -20,6 +19,12 @@
from pystac import Item as StacItem
from rasterio.errors import RasterioIOError

from hydromt.typing import (
RasterDatasetSource,
TimeRange,
TotalBounds,
)

from .. import gis_utils, io
from ..raster import GEO_MAP_COORD
from .caching import cache_vrt_tiles
Expand All @@ -29,8 +34,6 @@

__all__ = ["RasterDatasetAdapter", "RasterDatasetSource"]

RasterDatasetSource = NewType("RasterDatasetSource", Union[str, PathLike])


class RasterDatasetAdapter(DataAdapter):

Expand Down Expand Up @@ -674,7 +677,7 @@ def _parse_zoom_level(
logger.debug(f"Parsed zoom_level {zl} ({dst_res:.2f})")
return zl

def get_bbox(self, detect=True) -> Tuple[Tuple[float, float, float, float], int]:
def get_bbox(self, detect=True) -> TotalBounds:
"""Return the bounding box and espg code of the dataset.
if the bounding box is not set and detect is True,
Expand All @@ -701,7 +704,7 @@ def get_bbox(self, detect=True) -> Tuple[Tuple[float, float, float, float], int]

return bbox, crs

def get_time_range(self, detect=True) -> Tuple[np.datetime64, np.datetime64]:
def get_time_range(self, detect=True) -> TimeRange:
"""Detect the time range of the dataset.
if the time range is not set and detect is True,
Expand Down Expand Up @@ -730,7 +733,7 @@ def get_time_range(self, detect=True) -> Tuple[np.datetime64, np.datetime64]:
def detect_bbox(
self,
ds=None,
) -> Tuple[Tuple[float, float, float, float], int]:
) -> TotalBounds:
"""Detect the bounding box and crs of the dataset.
If no dataset is provided, it will be fetched according to the settings in the
Expand Down Expand Up @@ -761,7 +764,7 @@ def detect_bbox(

return bounds, crs

def detect_time_range(self, ds=None) -> Tuple[np.datetime64, np.datetime64]:
def detect_time_range(self, ds=None) -> TimeRange:
"""Detect the temporal range of the dataset.
If no dataset is provided, it will be fetched accodring to the settings in the
Expand Down Expand Up @@ -821,7 +824,7 @@ def to_stac_catalog(
start_dt = pd.to_datetime(start_dt)
end_dt = pd.to_datetime(end_dt)
props = {**self.meta, "crs": crs}
except Exception as e:
except (IndexError, KeyError, pyproj.exceptions.CRSError) as e:
if on_error == "skip":
logger.warning(
"Skipping {name} during stac conversion because"
Expand Down
22 changes: 22 additions & 0 deletions hydromt/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Type aliases used by hydromt."""

from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Tuple, Union

GeoDataframeSource = Union[str, Path]
GeoDatasetSource = Union[str, Path]
RasterDatasetSource = Union[str, Path]
Bbox = Tuple[float, float, float, float]
Crs = int
TotalBounds = Tuple[Bbox, Crs]
TimeRange = Tuple[datetime, datetime]


class ErrorHandleMethod(Enum):
"""Strategies for error handling withing hydromt."""

RAISE = 1
SKIP = 2
COERCE = 3
6 changes: 3 additions & 3 deletions tests/test_data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def test_to_stac_geodataframe(geodf, tmpdir):
gdf_stac_catalog.add_item(gds_stac_item)
outcome = cast(StacCatalog, adapter.to_stac_catalog(on_error="coerce"))
assert gdf_stac_catalog.to_dict() == outcome.to_dict() # type: ignore
del adapter.crs # manually create an invalid adapter by deleting the crs
adapter.crs = -3.14 # manually create an invalid adapter by deleting the crs
assert adapter.to_stac_catalog("skip") is None


Expand Down Expand Up @@ -535,7 +535,7 @@ def test_to_stac_raster():
outcome = cast(StacCatalog, adapter.to_stac_catalog(on_error="raise"))

assert raster_stac_catalog.to_dict() == outcome.to_dict() # type: ignore
del adapter.crs # manually create an invalid adapter by deleting the crs
adapter.crs = -3.14 # manually create an invalid adapter by deleting the crs
assert adapter.to_stac_catalog("skip") is None


Expand Down Expand Up @@ -568,7 +568,7 @@ def test_to_stac_geodataset(geoda, tmpdir):

outcome = cast(StacCatalog, adapter.to_stac_catalog(on_error="coerce"))
assert gds_stac_catalog.to_dict() == outcome.to_dict() # type: ignore
del adapter.crs # manually create an invalid adapter by deleting the crs
adapter.crs = -3.14 # manually create an invalid adapter by deleting the crs
assert adapter.to_stac_catalog("skip") is None


Expand Down

0 comments on commit f9061a7

Please sign in to comment.