Skip to content

Commit

Permalink
Merge branch 'develop' into feature/log-warning-area-weights
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX authored Nov 22, 2024
2 parents dc4c074 + d5a35f0 commit 6b8ddce
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 80 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD)

### Added
### Added
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
- feat: Add support for `post_processors` in the recipe. (#71)
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
- feat: Support for multiple edge builders between two sets of nodes (#70)

## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08

Expand Down Expand Up @@ -49,6 +51,8 @@ Keep it human-readable, your future self will thank you!
- ci: extened python versions to include 3.11 and 3.12
- Update copyright notice
- Fix `__version__` import in init
- The `edge_builder` field in the recipe is renamed to `edge_builders`. It now receives a list of edge builders. (#70)
- The `{source|target}_mask_attr_name` field is moved to inside the edge builder definition. (#70)

## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03

Expand Down
29 changes: 22 additions & 7 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
from itertools import chain
from pathlib import Path
from warnings import warn

import torch
from anemoi.utils.config import DotDict
Expand Down Expand Up @@ -55,13 +56,27 @@ def update_graph(self, graph: HeteroData) -> HeteroData:
)

for edges_cfg in self.config.get("edges", {}):
graph = instantiate(
edges_cfg.edge_builder,
edges_cfg.source_name,
edges_cfg.target_name,
source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None),
target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None),
).update_graph(graph, edges_cfg.get("attributes", {}))

if "edge_builder" in edges_cfg:
warn(
"This format will be deprecated. The key 'edge_builder' is renamed to 'edge_builders' and takes a list of edge builders. In addition, the source_mask_attr_name & target_mask_attr_name fields are moved under the each edge builder.",
DeprecationWarning,
stacklevel=2,
)

edge_builder_cfg = edges_cfg.get("edge_builder")
if edge_builder_cfg is not None:
edge_builder_cfg.source_mask_attr_name = edges_cfg.get("source_mask_attr_name")
edge_builder_cfg.target_mask_attr_name = edges_cfg.get("target_mask_attr_name")
edges_cfg.edge_builders = [edge_builder_cfg]

for edge_builder_cfg in edges_cfg.edge_builders:
edge_builder = instantiate(
edge_builder_cfg, source_name=edges_cfg.source_name, target_name=edges_cfg.target_name
)
graph = edge_builder.register_edges(graph)

graph = edge_builder.register_attributes(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down
37 changes: 24 additions & 13 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import StretchedTriNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes
from anemoi.graphs.utils import concat_edges
from anemoi.graphs.utils import get_grid_reference_distance

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -98,8 +99,19 @@ def register_edges(self, graph: HeteroData) -> HeteroData:
HeteroData
The graph with the registered edges.
"""
graph[self.name].edge_index = self.get_edge_index(graph)
graph[self.name].edge_type = type(self).__name__
edge_index = self.get_edge_index(graph)
edge_type = type(self).__name__

if "edge_index" in graph[self.name]:
# Expand current edge indices
graph[self.name].edge_index = concat_edges(graph[self.name].edge_index, edge_index)
if edge_type not in graph[self.name].edge_type:
graph[self.name].edge_type = graph[self.name].edge_type + "," + edge_type
return graph

# Register new edge indices
graph[self.name].edge_index = edge_index
graph[self.name].edge_type = edge_type
return graph

def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
Expand Down Expand Up @@ -164,12 +176,14 @@ def get_node_coordinates(
def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.target_mask_attr_name is not None:
target_mask = target_nodes[self.target_mask_attr_name].squeeze()
target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0]))
assert adj_matrix.shape[0] == target_mask.sum()
target_mapper = dict(zip(list(range(adj_matrix.shape[0])), np.where(target_mask)[0]))
adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row)

if self.source_mask_attr_name is not None:
source_mask = source_nodes[self.source_mask_attr_name].squeeze()
source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0]))
assert adj_matrix.shape[1] == source_mask.sum()
source_mapper = dict(zip(list(range(adj_matrix.shape[1])), np.where(source_mask)[0]))
adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col)

if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None:
Expand Down Expand Up @@ -394,7 +408,6 @@ def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
assert x_hops > 0, "Number of x_hops must be positive"
self.x_hops = x_hops
self.node_type = None

def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph(
Expand Down Expand Up @@ -428,25 +441,23 @@ def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage:
return nodes

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]:
if source_nodes.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]:
source_nodes = self.add_edges_from_tri_nodes(source_nodes)
elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]:
elif source_nodes.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]:
source_nodes = self.add_edges_from_hex_nodes(source_nodes)
elif self.node_type == StretchedTriNodes.__name__:
elif source_nodes.node_type == StretchedTriNodes.__name__:
source_nodes = self.add_edges_from_stretched_tri_nodes(source_nodes)
else:
raise ValueError(f"Invalid node type {self.node_type}")
raise ValueError(f"Invalid node type {source_nodes.node_type}")

adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")

return adjmat

def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
self.node_type = graph[self.source_name].node_type
node_type = graph[self.source_name].node_type
valid_node_names = [n.__name__ for n in self.VALID_NODES]
assert (
self.node_type in valid_node_names
), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."
assert node_type in valid_node_names, f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."

return super().update_graph(graph, attrs_config)

Expand Down
32 changes: 25 additions & 7 deletions src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Type
from typing import Union

import numpy as np
import torch
Expand All @@ -26,6 +27,8 @@

LOGGER = logging.getLogger(__name__)

MaskAttributeType = Union[str, Type["BooleanBaseNodeAttribute"]]


class BaseNodeAttribute(ABC, NormaliserMixin):
"""Base class for the weights of the nodes."""
Expand Down Expand Up @@ -220,12 +223,12 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
class BooleanOperation(BooleanBaseNodeAttribute, ABC):
"""Base class for boolean operations."""

def __init__(self, masks: list[str | Type[BooleanBaseNodeAttribute]]) -> None:
def __init__(self, masks: MaskAttributeType | list[MaskAttributeType]) -> None:
super().__init__()
self.masks = masks
self.masks = masks if isinstance(masks, list) else [masks]

@staticmethod
def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **kwargs) -> np.array:
def get_mask_values(mask: MaskAttributeType, nodes: NodeStorage, **kwargs) -> np.array:
if isinstance(mask, str):
attributes = nodes[mask]
assert (
Expand All @@ -235,16 +238,31 @@ def get_mask_values(mask: str | Type[BaseNodeAttribute], nodes: NodeStorage, **k

return mask.get_raw_values(nodes, **kwargs)

@abstractmethod
def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray: ...

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
mask_values = [BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks]
return self.reduce_op(mask_values)


class BooleanNot(BooleanOperation):
"""Boolean NOT mask."""

def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray:
assert len(self.masks) == 1, f"The {self.__class__.__name__} can only be aplied to one mask."
return ~masks[0]


class BooleanAndMask(BooleanOperation):
"""Boolean AND mask."""

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
return np.logical_and.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks])
def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray:
return np.logical_and.reduce(masks)


class BooleanOrMask(BooleanOperation):
"""Boolean OR mask."""

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
return np.logical_or.reduce([BooleanOperation.get_mask_values(mask, nodes, **kwargs) for mask in self.masks])
def reduce_op(self, masks: list[np.ndarray]) -> np.ndarray:
return np.logical_or.reduce(masks)
57 changes: 9 additions & 48 deletions src/anemoi/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,61 +63,22 @@ def get_grid_reference_distance(coords_rad: torch.Tensor, mask: torch.Tensor | N
return dists[dists > 0].max()


def add_margin(lats: np.ndarray, lons: np.ndarray, margin: float) -> tuple[np.ndarray, np.ndarray]:
"""Add a margin to the convex hull of the points considered.
For each point (lat, lon) add 8 points around it, each at a distance of `margin` from the original point.
Arguments
---------
lats : np.ndarray
Latitudes of the points considered.
lons : np.ndarray
Longitudes of the points considered.
margin : float
The margin to add to the convex hull.
Returns
-------
latitudes : np.ndarray
Latitudes of the points considered, including the margin.
longitudes : np.ndarray
Longitudes of the points considered, including the margin.
"""
assert margin >= 0, "Margin must be non-negative"
if margin == 0:
return lats, lons

latitudes, longitudes = [], []
for lat_sign in [-1, 0, 1]:
for lon_sign in [-1, 0, 1]:
latitudes.append(lats + lat_sign * margin)
longitudes.append(lons + lon_sign * margin)

return np.concatenate(latitudes), np.concatenate(longitudes)


def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int:
"""Index position of vector.
Get the index position of a vector in a matrix.
def concat_edges(edge_indices1: torch.Tensor, edge_indices2: torch.Tensor) -> torch.Tensor:
"""Concat edges
Parameters
----------
vector : torch.Tensor of shape (N, )
Vector to get its position in the matrix.
tensor : torch.Tensor of shape (M, N,)
Tensor in which the position is searched.
edge_indices1: torch.Tensor
Edge indices of the first set of edges. Shape: (2, num_edges1)
edge_indices2: torch.Tensor
Edge indices of the second set of edges. Shape: (2, num_edges2)
Returns
-------
int
Index position of `vector` in `tensor`. -1 if `vector` is not in `tensor`.
torch.Tensor
Concatenated edge indices.
"""
mask = torch.all(tensor == vector, axis=1)
if mask.any():
return int(torch.where(mask)[0])
return -1
return torch.unique(torch.cat([edge_indices1, edge_indices2], axis=1), dim=1)


def haversine_distance(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
Expand Down
7 changes: 3 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,9 @@ def config_file(tmp_path) -> tuple[str, str]:
{
"source_name": "test_nodes",
"target_name": "test_nodes",
"edge_builder": {
"_target_": "anemoi.graphs.edges.KNNEdges",
"num_nearest_neighbours": 3,
},
"edge_builders": [
{"_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3},
],
"attributes": {
"dist_norm": {"_target_": "anemoi.graphs.edges.attributes.EdgeLength"},
"edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"},
Expand Down
18 changes: 18 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import torch

from anemoi.graphs.utils import concat_edges


def test_concat_edges():
edge_indices1 = torch.tensor([[0, 1, 2, 3], [-1, -2, -3, -4]], dtype=torch.int64)
edge_indices2 = torch.tensor(np.array([[0, 4], [-1, -5]]), dtype=torch.int64)
no_edges = torch.tensor([[], []], dtype=torch.int64)

result1 = concat_edges(edge_indices1, edge_indices2)
result2 = concat_edges(no_edges, edge_indices2)

expected1 = torch.tensor([[0, 1, 2, 3, 4], [-1, -2, -3, -4, -5]], dtype=torch.int64)

assert torch.allclose(result1, expected1)
assert torch.allclose(result2, edge_indices2)

0 comments on commit 6b8ddce

Please sign in to comment.