diff --git a/CHANGELOG.md b/CHANGELOG.md index 31c938d..f474881 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,13 +12,6 @@ Keep it human-readable, your future self will thank you! ### Added -- fix: bug in color when plotting isolated nodes - -### Added - -- Add anemoi-transform link to documentation - -### Added - ci: hpc-config, CODEOWNERS (#49) - feat: New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30) - feat: New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30) @@ -27,10 +20,20 @@ Keep it human-readable, your future self will thank you! - feat: New method update_graph(graph) in the GraphCreator class. (#60) - feat: New class StretchedTriNodes to create a stretched mesh. (#51) - feat: Expanded MultiScaleEdges to support multi-scale connections in stretched graphs. (#51) +- fix: bug in color when plotting isolated nodes (#63) +- Add anemoi-transform link to documentation (#59) +- Added `CutOutMask` class to create a mask for a cutout. (#68) +- Added `MissingZarrVariable` and `NotMissingZarrVariable` classes to create a mask for missing zarr variables. (#68) - feat: Add CONTRIBUTORS.md file. (#72) ### Changed - ci: small fixes and updates pre-commit, downsteam-ci (#49) +- Update CODEOWNERS (#61) +- ci: extened python versions to include 3.11 and 3.12 (#66) +- Update copyright notice (#67) + +### Removed +- Remove `CutOutZarrDatasetNodes` class. (#68) - Update CODEOWNERS - Fix pre-commit regex - ci: extened python versions to include 3.11 and 3.12 diff --git a/src/anemoi/graphs/generate/tri_icosahedron.py b/src/anemoi/graphs/generate/tri_icosahedron.py index c4b11ce..72cd5cb 100644 --- a/src/anemoi/graphs/generate/tri_icosahedron.py +++ b/src/anemoi/graphs/generate/tri_icosahedron.py @@ -175,8 +175,8 @@ def add_edges_to_nx_graph( # Limit area of mesh points. if area_mask_builder is not None: - aoi_mask = area_mask_builder.get_mask(r_vertices_rad) - valid_nodes = np.where(aoi_mask)[0] + area_mask = area_mask_builder.get_mask(r_vertices_rad) + valid_nodes = np.where(area_mask)[0] else: valid_nodes = None diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index ac65f0f..ef48d45 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,3 @@ -from .builders.from_file import CutOutZarrDatasetNodes from .builders.from_file import LimitedAreaNPZFileNodes from .builders.from_file import NPZFileNodes from .builders.from_file import ZarrDatasetNodes @@ -17,7 +16,6 @@ "HexNodes", "HEALPixNodes", "LimitedAreaHEALPixNodes", - "CutOutZarrDatasetNodes", "LimitedAreaNPZFileNodes", "LimitedAreaTriNodes", "LimitedAreaHexNodes", diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index e9e00b2..1498e6d 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -12,9 +12,11 @@ import logging from abc import ABC from abc import abstractmethod +from typing import Type import numpy as np import torch +from anemoi.datasets import open_dataset from scipy.spatial import SphericalVoronoi from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage @@ -25,14 +27,15 @@ LOGGER = logging.getLogger(__name__) -class BaseWeights(ABC, NormaliserMixin): +class BaseNodeAttribute(ABC, NormaliserMixin): """Base class for the weights of the nodes.""" - def __init__(self, norm: str | None = None) -> None: + def __init__(self, norm: str | None = None, dtype: str = "float32") -> None: self.norm = norm + self.dtype = dtype @abstractmethod - def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ... + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: ... def post_process(self, values: np.ndarray) -> torch.Tensor: """Post-process the values.""" @@ -41,10 +44,10 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: norm_values = self.normalise(values) - return torch.tensor(norm_values, dtype=torch.float32) + return torch.tensor(norm_values.astype(self.dtype)) - def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor: - """Get the node weights. + def compute(self, graph: HeteroData, nodes_name: str, **kwargs) -> torch.Tensor: + """Get the nodes attribute. Parameters ---------- @@ -52,22 +55,20 @@ def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch. Graph. nodes_name : str Name of the nodes. - args : tuple - Additional arguments. kwargs : dict Additional keyword arguments. Returns ------- torch.Tensor - Weights associated to the nodes. + Attributes associated to the nodes. """ nodes = graph[nodes_name] - weights = self.get_raw_values(nodes, *args, **kwargs) - return self.post_process(weights) + attributes = self.get_raw_values(nodes, **kwargs) + return self.post_process(attributes) -class UniformWeights(BaseWeights): +class UniformWeights(BaseNodeAttribute): """Implements a uniform weight for the nodes. Methods @@ -76,27 +77,25 @@ class UniformWeights(BaseWeights): Compute the area attributes for each node. """ - def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: """Compute the weights. Parameters ---------- nodes : NodeStorage Nodes of the graph. - args : tuple - Additional arguments. kwargs : dict Additional keyword arguments. Returns ------- np.ndarray - Weights. + Attributes. """ return np.ones(nodes.num_nodes) -class AreaWeights(BaseWeights): +class AreaWeights(BaseNodeAttribute): """Implements the area of the nodes as the weights. Attributes @@ -114,12 +113,18 @@ class AreaWeights(BaseWeights): Compute the area attributes for each node. """ - def __init__(self, norm: str | None = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])) -> None: - super().__init__(norm) + def __init__( + self, + norm: str | None = None, + radius: float = 1.0, + centre: np.ndarray = np.array([0, 0, 0]), + dtype: str = "float32", + ) -> None: + super().__init__(norm, dtype) self.radius = radius self.centre = centre - def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: """Compute the area associated to each node. It uses Voronoi diagrams to compute the area of each node. @@ -128,15 +133,13 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: ---------- nodes : NodeStorage Nodes of the graph. - args : tuple - Additional arguments. kwargs : dict Additional keyword arguments. Returns ------- np.ndarray - Weights. + Attributes. """ latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes))) @@ -148,3 +151,83 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: np.array(area_weights).sum(), ) return area_weights + + +class BooleanBaseNodeAttribute(BaseNodeAttribute, ABC): + """Base class for boolean node attributes.""" + + def __init__(self) -> None: + super().__init__(norm=None, dtype="bool") + + +class NonmissingZarrVariable(BooleanBaseNodeAttribute): + """Mask of valid (not missing) values of a Zarr dataset variable. + + It reads a variable from a Zarr dataset and returns a boolean mask of nonmissing values in the first timestep. + + Attributes + ---------- + variable : str + Variable to read from the Zarr dataset. + norm : str + Normalization of the weights. + + Methods + ------- + compute(self, graph, nodes_name) + Compute the attribute for each node. + """ + + def __init__(self, variable: str) -> None: + super().__init__() + self.variable = variable + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + assert ( + nodes["node_type"] == "ZarrDatasetNodes" + ), f"{self.__class__.__name__} can only be used with ZarrDatasetNodes." + ds = open_dataset(nodes["_dataset"], select=self.variable)[0].squeeze() + return ~np.isnan(ds) + + +class CutOutMask(BooleanBaseNodeAttribute): + """Cut out mask.""" + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + assert isinstance(nodes["_dataset"], dict), "The 'dataset' attribute must be a dictionary." + assert "cutout" in nodes["_dataset"], "The 'dataset' attribute must contain a 'cutout' key." + num_lam, num_other = open_dataset(nodes["_dataset"]).grids + return np.array([True] * num_lam + [False] * num_other, dtype=bool) + + +class BooleanOperation(BooleanBaseNodeAttribute, ABC): + """Base class for boolean operations.""" + + def __init__(self, masks: list[str | Type[BooleanBaseNodeAttribute]]) -> None: + super().__init__() + self.masks = masks + + @staticmethod + def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **kwargs) -> np.array: + if isinstance(mask, str): + attributes = nodes[mask] + assert ( + attributes.dtype == "bool" + ), f"The mask attribute '{mask}' must be a boolean but is {attributes.dtype}." + return attributes + + return mask.get_raw_values(nodes, **kwargs) + + +class BooleanAndMask(BooleanOperation): + """Boolean AND mask.""" + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + return np.logical_and.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]) + + +class BooleanOrMask(BooleanOperation): + """Boolean OR mask.""" + + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + return np.logical_or.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]) diff --git a/src/anemoi/graphs/nodes/builders/base.py b/src/anemoi/graphs/nodes/builders/base.py index 7b7c59f..a141ebf 100644 --- a/src/anemoi/graphs/nodes/builders/base.py +++ b/src/anemoi/graphs/nodes/builders/base.py @@ -28,10 +28,12 @@ class BaseNodeBuilder(ABC): ---------- name : str name of the nodes, key for the nodes in the HeteroData graph object. - aoi_mask_builder : KNNAreaMaskBuilder + area_mask_builder : KNNAreaMaskBuilder The area of interest mask builder, if any. Defaults to None. """ + hidden_attributes: set[str] = set() + def __init__(self, name: str) -> None: self.name = name self.area_mask_builder = None @@ -68,6 +70,9 @@ def register_attributes(self, graph: HeteroData, config: DotDict | None = None) HeteroData The graph with the registered attributes. """ + for hidden_attr in self.hidden_attributes: + graph[self.name][f"_{hidden_attr}"] = getattr(self, hidden_attr) + for attr_name, attr_config in config.items(): graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py index e0b7e9d..786ae76 100644 --- a/src/anemoi/graphs/nodes/builders/from_file.py +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -15,7 +15,8 @@ import numpy as np import torch from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict +from omegaconf import DictConfig +from omegaconf import OmegaConf from torch_geometric.data import HeteroData from anemoi.graphs.generate.masks import KNNAreaMaskBuilder @@ -29,7 +30,7 @@ class ZarrDatasetNodes(BaseNodeBuilder): Attributes ---------- - dataset : zarr.core.Array + dataset : str | DictConfig The dataset. Methods @@ -44,10 +45,11 @@ class ZarrDatasetNodes(BaseNodeBuilder): Update the graph with new nodes and attributes. """ - def __init__(self, dataset: DotDict, name: str) -> None: + def __init__(self, dataset: DictConfig, name: str) -> None: LOGGER.info("Reading the dataset from %s.", dataset) - self.dataset = open_dataset(dataset) + self.dataset = dataset if isinstance(dataset, str) else OmegaConf.to_container(dataset) super().__init__(name) + self.hidden_attributes = BaseNodeBuilder.hidden_attributes | {"dataset"} def get_coordinates(self) -> torch.Tensor: """Get the coordinates of the nodes. @@ -57,28 +59,8 @@ def get_coordinates(self) -> torch.Tensor: torch.Tensor of shape (num_nodes, 2) A 2D tensor with the coordinates, in radians. """ - return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) - - -class CutOutZarrDatasetNodes(ZarrDatasetNodes): - """Nodes from Zarr dataset.""" - - def __init__( - self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all" - ) -> None: - dataset_config = { - "cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}], - "adjust": adjust, - } - super().__init__(dataset_config, name) - self.n_cutout, self.n_other = self.dataset.grids - - def register_attributes(self, graph: HeteroData, config: DotDict) -> None: - # this is a mask to cutout the LAM area - graph[self.name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape( - (-1, 1) - ) - return super().register_attributes(graph, config) + dataset = open_dataset(self.dataset) + return self.reshape_coords(dataset.latitudes, dataset.longitudes) class NPZFileNodes(BaseNodeBuilder): diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py index 5c2ab51..1f961e0 100644 --- a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -16,7 +16,6 @@ import networkx as nx import numpy as np import torch -from anemoi.utils.config import DotDict from torch_geometric.data import HeteroData from anemoi.graphs.generate.hex_icosahedron import create_hex_nodes @@ -48,6 +47,12 @@ def __init__( self.resolutions = resolution super().__init__(name) + self.hidden_attributes = BaseNodeBuilder.hidden_attributes | { + "resolutions", + "nx_graph", + "node_ordering", + "area_mask_builder", + } def get_coordinates(self) -> torch.Tensor: """Get the coordinates of the nodes. @@ -61,14 +66,7 @@ def get_coordinates(self) -> torch.Tensor: return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) @abstractmethod - def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: ... - - def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: - graph[self.name]["_resolutions"] = self.resolutions - graph[self.name]["_nx_graph"] = self.nx_graph - graph[self.name]["_node_ordering"] = self.node_ordering - graph[self.name]["_area_mask_builder"] = self.area_mask_builder - return super().register_attributes(graph, config) + def create_nodes(self) -> tuple[nx.DiGraph, np.ndarray, list[int]]: ... class LimitedAreaIcosahedralNodes(IcosahedralNodes): diff --git a/tests/nodes/test_cutout_nodes.py b/tests/nodes/test_cutout_nodes.py index 32135a9..8ce6613 100644 --- a/tests/nodes/test_cutout_nodes.py +++ b/tests/nodes/test_cutout_nodes.py @@ -9,6 +9,7 @@ import pytest import torch +from omegaconf import OmegaConf from torch_geometric.data import HeteroData from anemoi.graphs.nodes.attributes import AreaWeights @@ -17,27 +18,21 @@ def test_init(mocker, mock_zarr_dataset_cutout): - """Test CutOutZarrDatasetNodes initialization.""" + """Test ZarrDatasetNodes initialization with cutout.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + node_builder = from_file.ZarrDatasetNodes( + OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}), name="test_nodes" ) assert isinstance(node_builder, from_file.BaseNodeBuilder) - assert isinstance(node_builder, from_file.CutOutZarrDatasetNodes) - - -def test_fail_init(): - """Test CutOutZarrDatasetNodes initialization with invalid resolution.""" - with pytest.raises(TypeError): - from_file.CutOutZarrDatasetNodes("global_dataset.zarr", name="test_nodes") + assert isinstance(node_builder, from_file.ZarrDatasetNodes) def test_register_nodes(mocker, mock_zarr_dataset_cutout): - """Test CutOutZarrDatasetNodes register correctly the nodes.""" + """Test ZarrDatasetNodes register correctly the nodes with cutout operation.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + node_builder = from_file.ZarrDatasetNodes( + OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}), name="test_nodes" ) graph = HeteroData() @@ -45,16 +40,16 @@ def test_register_nodes(mocker, mock_zarr_dataset_cutout): assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) - assert graph["test_nodes"].x.shape == (node_builder.dataset.num_nodes, 2) - assert graph["test_nodes"].node_type == "CutOutZarrDatasetNodes" + assert graph["test_nodes"].x.shape == (mock_zarr_dataset_cutout.num_nodes, 2) + assert graph["test_nodes"].node_type == "ZarrDatasetNodes" @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class): - """Test CutOutZarrDatasetNodes register correctly the weights.""" + """Test ZarrDatasetNodes register correctly the weights with cutout operation.""" mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) - node_builder = from_file.CutOutZarrDatasetNodes( - forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + node_builder = from_file.ZarrDatasetNodes( + OmegaConf.create({"cutout": ["lam.zarr", "global.zarr"]}), name="test_nodes" ) config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index ed4590a..b8656b7 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -26,10 +26,11 @@ def test_init(mocker, mock_zarr_dataset): assert isinstance(node_builder, from_file.ZarrDatasetNodes) -def test_fail_init(): - """Test ZarrDatasetNodes initialization with invalid resolution.""" +def test_fail(): + """Test ZarrDatasetNodes with invalid dataset.""" + node_builder = from_file.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") with pytest.raises(zarr.errors.PathNotFoundError): - from_file.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") + node_builder.update_graph(HeteroData()) def test_register_nodes(mocker, mock_zarr_dataset): @@ -42,7 +43,7 @@ def test_register_nodes(mocker, mock_zarr_dataset): assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) - assert graph["test_nodes"].x.shape == (node_builder.dataset.num_nodes, 2) + assert graph["test_nodes"].x.shape == (mock_zarr_dataset.num_nodes, 2) assert graph["test_nodes"].node_type == "ZarrDatasetNodes"