diff --git a/CHANGELOG.md b/CHANGELOG.md index ceb1634..46a11ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,10 +10,12 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD) -### Added +### Added - feat: Define node sets and edges based on an ICON icosahedral mesh (#53) - feat: Add support for `post_processors` in the recipe. (#71) - feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71) +- feat: Define node sets and edges based on an ICON icosahedral mesh (#53) +- feat: Support for multiple edge builders between two sets of nodes (#70) ## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08 @@ -49,6 +51,8 @@ Keep it human-readable, your future self will thank you! - ci: extened python versions to include 3.11 and 3.12 - Update copyright notice - Fix `__version__` import in init +- The `edge_builder` field in the recipe is renamed to `edge_builders`. It now receives a list of edge builders. (#70) +- The `{source|target}_mask_attr_name` fields are moved inside the edge builder definition.
(#70) ## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03 diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 5d37f4d..74d23eb 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -12,6 +12,7 @@ import logging from itertools import chain from pathlib import Path +from warnings import warn import torch from anemoi.utils.config import DotDict @@ -55,13 +56,27 @@ def update_graph(self, graph: HeteroData) -> HeteroData: ) for edges_cfg in self.config.get("edges", {}): - graph = instantiate( - edges_cfg.edge_builder, - edges_cfg.source_name, - edges_cfg.target_name, - source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None), - target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None), - ).update_graph(graph, edges_cfg.get("attributes", {})) + + if "edge_builder" in edges_cfg: + warn( + "This format will be deprecated. The key 'edge_builder' is renamed to 'edge_builders' and takes a list of edge builders. 
In addition, the source_mask_attr_name & target_mask_attr_name fields are moved under the each edge builder.", + DeprecationWarning, + stacklevel=2, + ) + + edge_builder_cfg = edges_cfg.get("edge_builder") + if edge_builder_cfg is not None: + edge_builder_cfg.source_mask_attr_name = edges_cfg.get("source_mask_attr_name") + edge_builder_cfg.target_mask_attr_name = edges_cfg.get("target_mask_attr_name") + edges_cfg.edge_builders = [edge_builder_cfg] + + for edge_builder_cfg in edges_cfg.edge_builders: + edge_builder = instantiate( + edge_builder_cfg, source_name=edges_cfg.source_name, target_name=edges_cfg.target_name + ) + graph = edge_builder.register_edges(graph) + + graph = edge_builder.register_attributes(graph, edges_cfg.get("attributes", {})) return graph diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 92c3003..936aab9 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -33,6 +33,7 @@ from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes from anemoi.graphs.nodes.builders.from_refined_icosahedron import StretchedTriNodes from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes +from anemoi.graphs.utils import concat_edges from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -98,8 +99,19 @@ def register_edges(self, graph: HeteroData) -> HeteroData: HeteroData The graph with the registered edges. 
""" - graph[self.name].edge_index = self.get_edge_index(graph) - graph[self.name].edge_type = type(self).__name__ + edge_index = self.get_edge_index(graph) + edge_type = type(self).__name__ + + if "edge_index" in graph[self.name]: + # Expand current edge indices + graph[self.name].edge_index = concat_edges(graph[self.name].edge_index, edge_index) + if edge_type not in graph[self.name].edge_type: + graph[self.name].edge_type = graph[self.name].edge_type + "," + edge_type + return graph + + # Register new edge indices + graph[self.name].edge_index = edge_index + graph[self.name].edge_type = edge_type return graph def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: @@ -164,12 +176,14 @@ def get_node_coordinates( def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage): if self.target_mask_attr_name is not None: target_mask = target_nodes[self.target_mask_attr_name].squeeze() - target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0])) + assert adj_matrix.shape[0] == target_mask.sum() + target_mapper = dict(zip(list(range(adj_matrix.shape[0])), np.where(target_mask)[0])) adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row) if self.source_mask_attr_name is not None: source_mask = source_nodes[self.source_mask_attr_name].squeeze() - source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0])) + assert adj_matrix.shape[1] == source_mask.sum() + source_mapper = dict(zip(list(range(adj_matrix.shape[1])), np.where(source_mask)[0])) adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col) if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None: @@ -394,7 +408,6 @@ def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): assert isinstance(x_hops, int), "Number of x_hops must be an integer" assert x_hops > 0, "Number of x_hops must be positive" self.x_hops = x_hops - self.node_type = None def 
add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph( @@ -428,25 +441,23 @@ def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage: return nodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: + if source_nodes.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: source_nodes = self.add_edges_from_tri_nodes(source_nodes) - elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: + elif source_nodes.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: source_nodes = self.add_edges_from_hex_nodes(source_nodes) - elif self.node_type == StretchedTriNodes.__name__: + elif source_nodes.node_type == StretchedTriNodes.__name__: source_nodes = self.add_edges_from_stretched_tri_nodes(source_nodes) else: - raise ValueError(f"Invalid node type {self.node_type}") + raise ValueError(f"Invalid node type {source_nodes.node_type}") adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") return adjmat def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: - self.node_type = graph[self.source_name].node_type + node_type = graph[self.source_name].node_type valid_node_names = [n.__name__ for n in self.VALID_NODES] - assert ( - self.node_type in valid_node_names - ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." + assert node_type in valid_node_names, f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." 
return super().update_graph(graph, attrs_config) diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index 4a7ad10..db4ffa1 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -13,6 +13,7 @@ from abc import ABC from abc import abstractmethod from typing import Type +from typing import Union import numpy as np import torch @@ -26,6 +27,8 @@ LOGGER = logging.getLogger(__name__) +MaskAttributeType = Union[str, Type["BooleanBaseNodeAttribute"]] + class BaseNodeAttribute(ABC, NormaliserMixin): """Base class for the weights of the nodes.""" @@ -213,12 +216,12 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: class BooleanOperation(BooleanBaseNodeAttribute, ABC): """Base class for boolean operations.""" - def __init__(self, masks: list[str | Type[BooleanBaseNodeAttribute]]) -> None: + def __init__(self, masks: MaskAttributeType | list[MaskAttributeType]) -> None: super().__init__() - self.masks = masks + self.masks = masks if isinstance(masks, list) else [masks] @staticmethod - def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **kwargs) -> np.array: + def get_mask_values(mask: MaskAttributeType, nodes: NodeStorage, **kwargs) -> np.array: if isinstance(mask, str): attributes = nodes[mask] assert ( @@ -228,16 +231,31 @@ def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **k return mask.get_raw_values(nodes, **kwargs) + @abstractmethod + def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray: ... 
+ + def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: + mask_values = [BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks] + return self.reduce_op(mask_values) + + +class BooleanNot(BooleanOperation): + """Boolean NOT mask.""" + + def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray: + assert len(self.masks) == 1, f"The {self.__class__.__name__} can only be aplied to one mask." + return ~masks[0] + class BooleanAndMask(BooleanOperation): """Boolean AND mask.""" - def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: - return np.logical_and.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]) + def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray: + return np.logical_and.reduce(masks) class BooleanOrMask(BooleanOperation): """Boolean OR mask.""" - def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray: - return np.logical_or.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]) + def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray: + return np.logical_or.reduce(masks) diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index a68d6e7..5aa4b79 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -63,61 +63,22 @@ def get_grid_reference_distance(coords_rad: torch.Tensor, mask: torch.Tensor | N return dists[dists > 0].max() -def add_margin(lats: np.ndarray, lons: np.ndarray, margin: float) -> tuple[np.ndarray, np.ndarray]: - """Add a margin to the convex hull of the points considered. - - For each point (lat, lon) add 8 points around it, each at a distance of `margin` from the original point. - - Arguments - --------- - lats : np.ndarray - Latitudes of the points considered. - lons : np.ndarray - Longitudes of the points considered. - margin : float - The margin to add to the convex hull. 
- - Returns - ------- - latitudes : np.ndarray - Latitudes of the points considered, including the margin. - longitudes : np.ndarray - Longitudes of the points considered, including the margin. - """ - assert margin >= 0, "Margin must be non-negative" - if margin == 0: - return lats, lons - - latitudes, longitudes = [], [] - for lat_sign in [-1, 0, 1]: - for lon_sign in [-1, 0, 1]: - latitudes.append(lats + lat_sign * margin) - longitudes.append(lons + lon_sign * margin) - - return np.concatenate(latitudes), np.concatenate(longitudes) - - -def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int: - """Index position of vector. - - Get the index position of a vector in a matrix. +def concat_edges(edge_indices1: torch.Tensor, edge_indices2: torch.Tensor) -> torch.Tensor: + """Concatenate two sets of edge indices, dropping duplicated edges. Parameters ---------- - vector : torch.Tensor of shape (N, ) - Vector to get its position in the matrix. - tensor : torch.Tensor of shape (M, N,) - Tensor in which the position is searched. + edge_indices1: torch.Tensor + Edge indices of the first set of edges. Shape: (2, num_edges1) + edge_indices2: torch.Tensor + Edge indices of the second set of edges. Shape: (2, num_edges2) Returns ------- - int - Index position of `vector` in `tensor`. -1 if `vector` is not in `tensor`. + torch.Tensor + Concatenated edge indices with duplicated edges removed.
""" - mask = torch.all(tensor == vector, axis=1) - if mask.any(): - return int(torch.where(mask)[0]) - return -1 + return torch.unique(torch.cat([edge_indices1, edge_indices2], axis=1), dim=1) def haversine_distance(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray: diff --git a/tests/conftest.py b/tests/conftest.py index b94cdb4..3fe8406 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,10 +100,9 @@ def config_file(tmp_path) -> tuple[str, str]: { "source_name": "test_nodes", "target_name": "test_nodes", - "edge_builder": { - "_target_": "anemoi.graphs.edges.KNNEdges", - "num_nearest_neighbours": 3, - }, + "edge_builders": [ + {"_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3}, + ], "attributes": { "dist_norm": {"_target_": "anemoi.graphs.edges.attributes.EdgeLength"}, "edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"}, diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..8e7a8c5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,18 @@ +import numpy as np +import torch + +from anemoi.graphs.utils import concat_edges + + +def test_concat_edges(): + edge_indices1 = torch.tensor([[0, 1, 2, 3], [-1, -2, -3, -4]], dtype=torch.int64) + edge_indices2 = torch.tensor(np.array([[0, 4], [-1, -5]]), dtype=torch.int64) + no_edges = torch.tensor([[], []], dtype=torch.int64) + + result1 = concat_edges(edge_indices1, edge_indices2) + result2 = concat_edges(no_edges, edge_indices2) + + expected1 = torch.tensor([[0, 1, 2, 3, 4], [-1, -2, -3, -4, -5]], dtype=torch.int64) + + assert torch.allclose(result1, expected1) + assert torch.allclose(result2, edge_indices2)